diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f109c6b..264af0a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,3 +40,18 @@ jobs: run: make test-doc env: MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + + test_examples: + name: Test Examples + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: 1.76.0 + - name: Run examples + run: make test-examples + env: + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 749a9d8..d6f00f4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,10 @@ - [Requirements](#requirements) - [First setup](#first-setup) - [Optional requirements](#optional-requirements) +- [Local Development](#local-development) - [Test](#test) +- [Documentation](#documentation) + - [Readme](#readme) - [Code of Conduct](#code-of-conduct) - [Commit Message Format](#commit-message-format) @@ -41,6 +44,8 @@ Then edit the `.env` file to set your `MISTRAL_API_KEY`. - [cargo-llvm-cov](https://github.com/taiki-e/cargo-llvm-cov?tab=readme-ov-file#installation) for `make test-cover` - [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-watch`. +## Local Development + ### Test ```sh @@ -53,6 +58,16 @@ or make test-watch ``` +## Documentation + +### Readme + +> [!IMPORTANT] +> Do not edit the `README.md` file directly. It is generated from the `README.template.md` file. + +1. Edit the `README.template.md` file. +2. Run `make readme` to generate/update the `README.md` file. + ## Code of Conduct Help us keep this project open and inclusive. Please read and follow our [Code of Conduct](./CODE_OF_CONDUCT.md).
diff --git a/Cargo.toml b/Cargo.toml index 3761f3b..bb6f63a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,14 +15,18 @@ readme = "README.md" repository = "https://github.com/ivangabriele/mistralai-client-rs" [dependencies] +async-stream = "0.3.5" +async-trait = "0.1.77" +env_logger = "0.11.3" futures = "0.3.30" +log = "0.4.21" reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" strum = "0.26.1" -strum_macros = "0.26.1" thiserror = "1.0.57" tokio = { version = "1.36.0", features = ["full"] } +tokio-stream = "0.1.14" [dev-dependencies] jrest = "0.2.3" diff --git a/Makefile b/Makefile index 6e940ec..02eada5 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ SHELL := /bin/bash -.PHONY: test +.PHONY: doc readme test define source_env_if_not_ci @if [ -z "$${CI}" ]; then \ @@ -26,20 +26,47 @@ doc: cargo doc open ./target/doc/mistralai_client/index.html +readme: + @echo "Generating README.md from template..." + @> README.md # Clear README.md content before starting + @while IFS= read -r line || [[ -n "$$line" ]]; do \ + if [[ $$line == *"<CODE>"* && $$line == *"</CODE>"* ]]; then \ + example_path=$$(echo $$line | sed -n 's/.*<CODE>\(.*\)<\/CODE>.*/\1/p'); \ + if [ -f $$example_path ]; then \ + echo '```rs' >> README.md; \ + cat $$example_path >> README.md; \ + echo '```' >> README.md; \ + else \ + echo "Error: Example $$example_path not found." >&2; \ + fi; \ + else \ + echo "$$line" >> README.md; \ + fi; \ + done < README.template.md + @echo "README.md has been generated."
+ release-patch: $(call RELEASE_TEMPLATE,patch) - release-minor: $(call RELEASE_TEMPLATE,minor) - release-major: $(call RELEASE_TEMPLATE,major) test: - @$(source_env_if_not_ci) && cargo test --no-fail-fast + @$(source_env_if_not_ci) && \ + cargo test --no-fail-fast test-cover: - @$(source_env_if_not_ci) && cargo llvm-cov + @$(source_env_if_not_ci) && \ + cargo llvm-cov test-doc: - @$(source_env_if_not_ci) && cargo test --doc --no-fail-fast + @$(source_env_if_not_ci) && \ + cargo test --doc --no-fail-fast +test-examples: + @$(source_env_if_not_ci) && \ + for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \ + echo "Running $$example"; \ + cargo run --example $$example; \ + done test-watch: - @source ./.env && cargo watch -x "test -- --nocapture" + @source ./.env && \ + cargo watch -x "test -- --nocapture" diff --git a/README.md b/README.md index a828533..3f55332 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,16 @@ Rust client for the Mistral AI API. - [As an environment variable](#as-an-environment-variable) - [As a client argument](#as-a-client-argument) - [Usage](#usage) - - [Chat without streaming](#chat-without-streaming) - - [Chat without streaming (async)](#chat-without-streaming-async) + - [Chat](#chat) + - [Chat (async)](#chat-async) - [Chat with streaming (async)](#chat-with-streaming-async) + - [Chat with Function Calling](#chat-with-function-calling) + - [Chat with Function Calling (async)](#chat-with-function-calling-async) - [Embeddings](#embeddings) - [Embeddings (async)](#embeddings-async) - [List models](#list-models) - [List models (async)](#list-models-async) +- [Contributing](#contributing) --- @@ -34,8 +37,8 @@ Rust client for the Mistral AI API. - [x] Embedding (async) - [x] List models - [x] List models (async) -- [ ] Function Calling -- [ ] Function Calling (async) +- [x] Function Calling +- [x] Function Calling (async) ## Installation @@ -53,6 +56,18 @@ You can get your Mistral API Key there: <https://console.mistral.ai/api-keys/>.
Just set the `MISTRAL_API_KEY` environment variable. +```rs +use mistralai_client::v1::client::Client; + +fn main() { + let client = Client::new(None, None, None, None); +} +``` + +```sh +MISTRAL_API_KEY=your_api_key cargo run +``` + #### As a client argument ```rs @@ -67,11 +82,11 @@ fn main() { ## Usage -### Chat without streaming +### Chat ```rs use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + chat::{ChatMessage, ChatMessageRole, ChatParams}, client::Client, constants::Model, }; @@ -82,10 +97,11 @@ fn main() { let model = Model::OpenMistral7b; let messages = vec![ChatMessage { - role: ChatMessageRole::user, + role: ChatMessageRole::User, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + tool_calls: None, }]; - let options = ChatCompletionRequestOptions { + let options = ChatParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() @@ -93,15 +109,15 @@ fn main() { let result = client.chat(model, messages, Some(options)).unwrap(); println!("Assistant: {}", result.choices[0].message.content); - // => "Assistant: Tower. [...]" + // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 
} ``` -### Chat without streaming (async) +### Chat (async) ```rs use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + chat::{ChatMessage, ChatMessageRole, ChatParams}, client::Client, constants::Model, }; @@ -113,18 +129,25 @@ async fn main() { let model = Model::OpenMistral7b; let messages = vec![ChatMessage { - role: ChatMessageRole::user, + role: ChatMessageRole::User, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + tool_calls: None, }]; - let options = ChatCompletionRequestOptions { + let options = ChatParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; - let result = client.chat_async(model, messages, Some(options)).await.unwrap(); - println!("Assistant: {}", result.choices[0].message.content); - // => "Assistant: Tower. [...]" + let result = client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + println!( + "{:?}: {}", + result.choices[0].message.role, result.choices[0].message.content + ); + // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." } ``` @@ -133,38 +156,206 @@ async fn main() { ```rs use futures::stream::StreamExt; use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + chat::{ChatMessage, ChatMessageRole, ChatParams}, client::Client, constants::Model, }; +use std::io::{self, Write}; -[#tokio::main] +#[tokio::main] async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
- let client = Client::new(None, None, None, None).unwrap(); + let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; let messages = vec![ChatMessage { - role: ChatMessageRole::user, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + role: ChatMessageRole::User, + content: "Tell me a short happy story.".to_string(), + tool_calls: None, }]; - let options = ChatCompletionParams { + let options = ChatParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; - let stream_result = client.chat_stream(model, messages, Some(options)).await; - let mut stream = stream_result.expect("Failed to create stream."); - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - println!("Assistant (message chunk): {}", chunk.choices[0].delta.content); + let stream_result = client + .chat_stream(model, messages, Some(options)) + .await + .expect("Failed to create stream."); + stream_result + .for_each(|chunk_result| async { + match chunk_result { + Ok(chunks) => chunks.iter().for_each(|chunk| { + print!("{}", chunk.choices[0].delta.content); + io::stdout().flush().unwrap(); + // => "Once upon a time, [...]" + }), + Err(error) => { + eprintln!("Error processing chunk: {:?}", error) + } } - Err(e) => eprintln!("Error processing chunk: {:?}", e), - } + }) + .await; + + print!("\n") // To persist the last chunk output. 
+} +``` + +### Chat with Function Calling + +```rs +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, + tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, +}; +use serde::Deserialize; +use std::any::Any; + +#[derive(Debug, Deserialize)] +struct GetCityTemperatureArguments { + city: String, +} + +struct GetCityTemperatureFunction; +#[async_trait::async_trait] +impl Function for GetCityTemperatureFunction { + async fn execute(&self, arguments: String) -> Box { + // Deserialize arguments, perform the logic, and return the result + let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); + + let temperature = match city.as_str() { + "Paris" => "20°C", + _ => "Unknown city", + }; + + Box::new(temperature.to_string()) } } + +fn main() { + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let mut client = Client::new(None, None, None, None).unwrap(); + client.register_function( + "get_city_temperature".to_string(), + Box::new(GetCityTemperatureFunction), + ); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "What's the temperature in Paris?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Auto), + tools: Some(tools), + ..Default::default() + }; + + client.chat(model, messages, Some(options)).unwrap(); + let temperature = client + .get_last_function_call_result() + .unwrap() + .downcast::() + .unwrap(); + println!("The temperature in Paris is: {}.", temperature); + // => "The temperature in Paris is: 20°C." +} +``` + +### Chat with Function Calling (async) + +```rs +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, + tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, +}; +use serde::Deserialize; +use std::any::Any; + +#[derive(Debug, Deserialize)] +struct GetCityTemperatureArguments { + city: String, +} + +struct GetCityTemperatureFunction; +#[async_trait::async_trait] +impl Function for GetCityTemperatureFunction { + async fn execute(&self, arguments: String) -> Box { + // Deserialize arguments, perform the logic, and return the result + let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); + + let temperature = match city.as_str() { + "Paris" => "20°C", + _ => "Unknown city", + }; + + Box::new(temperature.to_string()) + } +} + +#[tokio::main] +async fn main() { + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + // This 
example suppose you have set the `MISTRAL_API_KEY` environment variable. + let mut client = Client::new(None, None, None, None).unwrap(); + client.register_function( + "get_city_temperature".to_string(), + Box::new(GetCityTemperatureFunction), + ); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "What's the temperature in Paris?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Auto), + tools: Some(tools), + ..Default::default() + }; + + client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + let temperature = client + .get_last_function_call_result() + .unwrap() + .downcast::() + .unwrap(); + println!("The temperature in Paris is: {}.", temperature); + // => "The temperature in Paris is: 20°C." +} ``` ### Embeddings @@ -174,18 +365,18 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel}; fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
- let client: Client = Client::new(None, None, None, None).unwrap(); + let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; - let input = vec!["Embed this sentence.", "As well as this one."] - .iter() - .map(|s| s.to_string()) - .collect(); - let options = None; + let model = EmbedModel::MistralEmbed; + let input = vec!["Embed this sentence.", "As well as this one."] + .iter() + .map(|s| s.to_string()) + .collect(); + let options = None; - let response = client.embeddings(model, input, options).unwrap(); - println!("Embeddings: {:?}", response.data); - // => "Embeddings: [{...}, {...}]" + let response = client.embeddings(model, input, options).unwrap(); + println!("First Embedding: {:?}", response.data[0]); + // => "First Embedding: {...}" } ``` @@ -197,18 +388,21 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel}; #[tokio::main] async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
- let client: Client = Client::new(None, None, None, None).unwrap(); + let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; - let input = vec!["Embed this sentence.", "As well as this one."] - .iter() - .map(|s| s.to_string()) - .collect(); - let options = None; + let model = EmbedModel::MistralEmbed; + let input = vec!["Embed this sentence.", "As well as this one."] + .iter() + .map(|s| s.to_string()) + .collect(); + let options = None; - let response = client.embeddings_async(model, input, options).await.unwrap(); - println!("Embeddings: {:?}", response.data); - // => "Embeddings: [{...}, {...}]" + let response = client + .embeddings_async(model, input, options) + .await + .unwrap(); + println!("First Embedding: {:?}", response.data[0]); + // => "First Embedding: {...}" } ``` @@ -235,10 +429,14 @@ use mistralai_client::v1::client::Client; #[tokio::main] async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. - let client = Client::new(None, None, None, None).await.unwrap(); + let client = Client::new(None, None, None, None).unwrap(); - let result = client.list_models_async().unwrap(); + let result = client.list_models_async().await.unwrap(); println!("First Model ID: {:?}", result.data[0].id); // => "First Model ID: open-mistral-7b" } ``` + +## Contributing + +Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library. 
diff --git a/README.template.md b/README.template.md new file mode 100644 index 0000000..5f3d496 --- /dev/null +++ b/README.template.md @@ -0,0 +1,123 @@ +# Mistral AI Rust Client + +[![Crates.io Package](https://img.shields.io/crates/v/mistralai-client?style=for-the-badge)](https://crates.io/crates/mistralai-client) +[![Docs.rs Documentation](https://img.shields.io/docsrs/mistralai-client/latest?style=for-the-badge)](https://docs.rs/mistralai-client/latest/mistralai-client) +[![Test Workflow Status](https://img.shields.io/github/actions/workflow/status/ivangabriele/mistralai-client-rs/test.yml?label=CI&style=for-the-badge)](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++) +[![Code Coverage](https://img.shields.io/codecov/c/github/ivangabriele/mistralai-client-rs/main?label=Cov&style=for-the-badge)](https://app.codecov.io/github/ivangabriele/mistralai-client-rs) + +Rust client for the Mistral AI API. + +--- + +- [Supported APIs](#supported-apis) +- [Installation](#installation) + - [Mistral API Key](#mistral-api-key) + - [As an environment variable](#as-an-environment-variable) + - [As a client argument](#as-a-client-argument) +- [Usage](#usage) + - [Chat](#chat) + - [Chat (async)](#chat-async) + - [Chat with streaming (async)](#chat-with-streaming-async) + - [Chat with Function Calling](#chat-with-function-calling) + - [Chat with Function Calling (async)](#chat-with-function-calling-async) + - [Embeddings](#embeddings) + - [Embeddings (async)](#embeddings-async) + - [List models](#list-models) + - [List models (async)](#list-models-async) +- [Contributing](#contributing) + +--- + +## Supported APIs + +- [x] Chat without streaming +- [x] Chat without streaming (async) +- [x] Chat with streaming +- [x] Embedding +- [x] Embedding (async) +- [x] List models +- [x] List models (async) +- [x] Function Calling +- [x] Function Calling (async) + +## Installation + +You can install the library in your project using: + +```sh 
+cargo add mistralai-client +``` + +### Mistral API Key + +You can get your Mistral API Key there: <https://console.mistral.ai/api-keys/>. + +#### As an environment variable + +Just set the `MISTRAL_API_KEY` environment variable. + +```rs +use mistralai_client::v1::client::Client; + +fn main() { + let client = Client::new(None, None, None, None); +} +``` + +```sh +MISTRAL_API_KEY=your_api_key cargo run +``` + +#### As a client argument + +```rs +use mistralai_client::v1::client::Client; + +fn main() { + let api_key = "your_api_key"; + + let client = Client::new(Some(api_key), None, None, None).unwrap(); +} +``` + +## Usage + +### Chat + +<CODE>examples/chat.rs</CODE> + +### Chat (async) + +<CODE>examples/chat_async.rs</CODE> + +### Chat with streaming (async) + +<CODE>examples/chat_with_streaming.rs</CODE> + +### Chat with Function Calling + +<CODE>examples/chat_with_function_calling.rs</CODE> + +### Chat with Function Calling (async) + +<CODE>examples/chat_with_function_calling_async.rs</CODE> + +### Embeddings + +<CODE>examples/embeddings.rs</CODE> + +### Embeddings (async) + +<CODE>examples/embeddings_async.rs</CODE> + +### List models + +<CODE>examples/list_models.rs</CODE> + +### List models (async) + +<CODE>examples/list_models_async.rs</CODE> + +## Contributing + +Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library. diff --git a/examples/chat.rs b/examples/chat.rs new file mode 100644 index 0000000..12d5fd4 --- /dev/null +++ b/examples/chat.rs @@ -0,0 +1,26 @@ +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, +}; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let result = client.chat(model, messages, Some(options)).unwrap(); + println!("Assistant: {}", result.choices[0].message.content); + // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." +} diff --git a/examples/chat_async.rs b/examples/chat_async.rs new file mode 100644 index 0000000..7034553 --- /dev/null +++ b/examples/chat_async.rs @@ -0,0 +1,33 @@ +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, +}; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let result = client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + println!( + "{:?}: {}", + result.choices[0].message.role, result.choices[0].message.content + ); + // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 
+} diff --git a/examples/chat_with_function_calling.rs b/examples/chat_with_function_calling.rs new file mode 100644 index 0000000..991d2f2 --- /dev/null +++ b/examples/chat_with_function_calling.rs @@ -0,0 +1,71 @@ +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, + tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, +}; +use serde::Deserialize; +use std::any::Any; + +#[derive(Debug, Deserialize)] +struct GetCityTemperatureArguments { + city: String, +} + +struct GetCityTemperatureFunction; +#[async_trait::async_trait] +impl Function for GetCityTemperatureFunction { + async fn execute(&self, arguments: String) -> Box { + // Deserialize arguments, perform the logic, and return the result + let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); + + let temperature = match city.as_str() { + "Paris" => "20°C", + _ => "Unknown city", + }; + + Box::new(temperature.to_string()) + } +} + +fn main() { + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let mut client = Client::new(None, None, None, None).unwrap(); + client.register_function( + "get_city_temperature".to_string(), + Box::new(GetCityTemperatureFunction), + ); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "What's the temperature in Paris?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Auto), + tools: Some(tools), + ..Default::default() + }; + + client.chat(model, messages, Some(options)).unwrap(); + let temperature = client + .get_last_function_call_result() + .unwrap() + .downcast::() + .unwrap(); + println!("The temperature in Paris is: {}.", temperature); + // => "The temperature in Paris is: 20°C." +} diff --git a/examples/chat_with_function_calling_async.rs b/examples/chat_with_function_calling_async.rs new file mode 100644 index 0000000..0fb5213 --- /dev/null +++ b/examples/chat_with_function_calling_async.rs @@ -0,0 +1,75 @@ +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, + tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, +}; +use serde::Deserialize; +use std::any::Any; + +#[derive(Debug, Deserialize)] +struct GetCityTemperatureArguments { + city: String, +} + +struct GetCityTemperatureFunction; +#[async_trait::async_trait] +impl Function for GetCityTemperatureFunction { + async fn execute(&self, arguments: String) -> Box { + // Deserialize arguments, perform the logic, and return the result + let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); + + let temperature = match city.as_str() { + "Paris" => "20°C", + _ => "Unknown city", + }; + + Box::new(temperature.to_string()) + } +} + +#[tokio::main] +async fn main() { + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a 
city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let mut client = Client::new(None, None, None, None).unwrap(); + client.register_function( + "get_city_temperature".to_string(), + Box::new(GetCityTemperatureFunction), + ); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "What's the temperature in Paris?".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Auto), + tools: Some(tools), + ..Default::default() + }; + + client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + let temperature = client + .get_last_function_call_result() + .unwrap() + .downcast::() + .unwrap(); + println!("The temperature in Paris is: {}.", temperature); + // => "The temperature in Paris is: 20°C." +} diff --git a/examples/chat_with_streaming.rs b/examples/chat_with_streaming.rs new file mode 100644 index 0000000..8515a45 --- /dev/null +++ b/examples/chat_with_streaming.rs @@ -0,0 +1,45 @@ +use futures::stream::StreamExt; +use mistralai_client::v1::{ + chat::{ChatMessage, ChatMessageRole, ChatParams}, + client::Client, + constants::Model, +}; +use std::io::{self, Write}; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "Tell me a short happy story.".to_string(), + tool_calls: None, + }]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let stream_result = client + .chat_stream(model, messages, Some(options)) + .await + .unwrap(); + stream_result + .for_each(|chunk_result| async { + match chunk_result { + Ok(chunks) => chunks.iter().for_each(|chunk| { + print!("{}", chunk.choices[0].delta.content); + io::stdout().flush().unwrap(); + // => "Once upon a time, [...]" + }), + Err(error) => { + eprintln!("Error processing chunk: {:?}", error) + } + } + }) + .await; + print!("\n") // To persist the last chunk output. +} diff --git a/examples/embeddings.rs b/examples/embeddings.rs new file mode 100644 index 0000000..898e7d4 --- /dev/null +++ b/examples/embeddings.rs @@ -0,0 +1,17 @@ +use mistralai_client::v1::{client::Client, constants::EmbedModel}; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client: Client = Client::new(None, None, None, None).unwrap(); + + let model = EmbedModel::MistralEmbed; + let input = vec!["Embed this sentence.", "As well as this one."] + .iter() + .map(|s| s.to_string()) + .collect(); + let options = None; + + let response = client.embeddings(model, input, options).unwrap(); + println!("First Embedding: {:?}", response.data[0]); + // => "First Embedding: {...}" +} diff --git a/examples/embeddings_async.rs b/examples/embeddings_async.rs new file mode 100644 index 0000000..a93d374 --- /dev/null +++ b/examples/embeddings_async.rs @@ -0,0 +1,21 @@ +use mistralai_client::v1::{client::Client, constants::EmbedModel}; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let client: Client = Client::new(None, None, None, None).unwrap(); + + let model = EmbedModel::MistralEmbed; + let input = vec!["Embed this sentence.", "As well as this one."] + .iter() + .map(|s| s.to_string()) + .collect(); + let options = None; + + let response = client + .embeddings_async(model, input, options) + .await + .unwrap(); + println!("First Embedding: {:?}", response.data[0]); + // => "First Embedding: {...}" +} diff --git a/examples/list_models.rs b/examples/list_models.rs new file mode 100644 index 0000000..7551b0a --- /dev/null +++ b/examples/list_models.rs @@ -0,0 +1,10 @@ +use mistralai_client::v1::client::Client; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None).unwrap(); + + let result = client.list_models().unwrap(); + println!("First Model ID: {:?}", result.data[0].id); + // => "First Model ID: open-mistral-7b" +} diff --git a/examples/list_models_async.rs b/examples/list_models_async.rs new file mode 100644 index 0000000..a801928 --- /dev/null +++ b/examples/list_models_async.rs @@ -0,0 +1,11 @@ +use mistralai_client::v1::client::Client; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let client = Client::new(None, None, None, None).unwrap(); + + let result = client.list_models_async().await.unwrap(); + println!("First Model ID: {:?}", result.data[0].id); + // => "First Model ID: open-mistral-7b" +} diff --git a/src/v1/chat_completion.rs b/src/v1/chat.rs similarity index 64% rename from src/v1/chat_completion.rs rename to src/v1/chat.rs index 701764b..20853d2 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::v1::{common, constants}; +use crate::v1::{common, constants, tool}; // ----------------------------------------------------------------------------- // Definitions @@ -9,88 +9,115 @@ use crate::v1::{common, constants}; pub struct ChatMessage { pub role: ChatMessageRole, pub content: String, + pub tool_calls: Option>, +} +impl ChatMessage { + pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { + Self { + role: ChatMessageRole::Assistant, + content: content.to_string(), + tool_calls, + } + } + + pub fn new_user_message(content: &str) -> Self { + Self { + role: ChatMessageRole::User, + content: content.to_string(), + tool_calls: None, + } + } } -#[derive(Clone, Debug, strum_macros::Display, Eq, PartialEq, Deserialize, Serialize)] -#[allow(non_camel_case_types)] +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub enum ChatMessageRole { - assistant, - user, + #[serde(rename = "assistant")] + Assistant, + #[serde(rename = "user")] + User, } // ----------------------------------------------------------------------------- // Request #[derive(Debug)] -pub struct ChatCompletionParams { - pub tools: Option, - pub temperature: Option, +pub struct ChatParams { pub max_tokens: Option, - pub top_p: Option, pub random_seed: Option, pub safe_prompt: Option, + pub temperature: Option, + pub tool_choice: Option, + pub tools: Option>, + pub top_p: Option, } -impl Default for ChatCompletionParams { +impl Default for ChatParams { fn default() -> 
Self { Self { - tools: None, - temperature: None, max_tokens: None, - top_p: None, random_seed: None, safe_prompt: None, + temperature: None, + tool_choice: None, + tools: None, + top_p: None, } } } #[derive(Debug, Serialize, Deserialize)] -pub struct ChatCompletionRequest { +pub struct ChatRequest { pub messages: Vec, pub model: constants::Model, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub random_seed: Option, - pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub safe_prompt: Option, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // TODO Check this prop (seen in official Python client but not in API doc). // pub tool_choice: Option, // TODO Check this prop (seen in official Python client but not in API doc). 
// pub response_format: Option, } -impl ChatCompletionRequest { +impl ChatRequest { pub fn new( model: constants::Model, messages: Vec, stream: bool, - options: Option, + options: Option, ) -> Self { - let ChatCompletionParams { - tools, - temperature, + let ChatParams { max_tokens, - top_p, random_seed, safe_prompt, + temperature, + tool_choice, + tools, + top_p, } = options.unwrap_or_default(); Self { messages, model, - tools, - temperature, + max_tokens, - top_p, random_seed, - stream, safe_prompt, + stream, + temperature, + tool_choice, + tools, + top_p, } } } @@ -99,51 +126,29 @@ impl ChatCompletionRequest { // Response #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ChatCompletionResponse { +pub struct ChatResponse { pub id: String, pub object: String, /// Unix timestamp (in seconds). pub created: u32, pub model: constants::Model, - pub choices: Vec, + pub choices: Vec, pub usage: common::ResponseUsage, } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ChatCompletionResponseChoice { +pub struct ChatResponseChoice { pub index: u32, pub message: ChatMessage, - pub finish_reason: String, + pub finish_reason: ChatResponseChoiceFinishReason, // TODO Check this prop (seen in API responses but undocumented). // pub logprobs: ??? } -// ----------------------------------------------------------------------------- -// Stream - -#[derive(Debug, Deserialize)] -pub struct ChatCompletionStreamChunk { - pub id: String, - pub object: String, - /// Unix timestamp (in seconds). - pub created: u32, - pub model: constants::Model, - pub choices: Vec, - // TODO Check this prop (seen in API responses but undocumented). - // pub usage: ???, -} - -#[derive(Debug, Deserialize)] -pub struct ChatCompletionStreamChunkChoice { - pub index: u32, - pub delta: ChatCompletionStreamChunkChoiceDelta, - pub finish_reason: Option, - // TODO Check this prop (seen in API responses but undocumented). 
- // pub logprobs: ???, -} - -#[derive(Debug, Deserialize)] -pub struct ChatCompletionStreamChunkChoiceDelta { - pub role: Option, - pub content: String, +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub enum ChatResponseChoiceFinishReason { + #[serde(rename = "stop")] + Stop, + #[serde(rename = "tool_calls")] + ToolCalls, } diff --git a/src/v1/chat_stream.rs b/src/v1/chat_stream.rs new file mode 100644 index 0000000..1daf481 --- /dev/null +++ b/src/v1/chat_stream.rs @@ -0,0 +1,57 @@ +use serde::{Deserialize, Serialize}; +use serde_json::from_str; + +use crate::v1::{chat, common, constants, error}; + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatStreamChunk { + pub id: String, + pub object: String, + /// Unix timestamp (in seconds). + pub created: u32, + pub model: constants::Model, + pub choices: Vec, + pub usage: Option, + // TODO Check this prop (seen in API responses but undocumented). + // pub logprobs: ???, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatStreamChunkChoice { + pub index: u32, + pub delta: ChatStreamChunkChoiceDelta, + pub finish_reason: Option, + // TODO Check this prop (seen in API responses but undocumented). + // pub logprobs: ???, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatStreamChunkChoiceDelta { + pub role: Option, + pub content: String, +} + +/// Extracts serialized chunks from a stream message. 
+pub fn get_chunk_from_stream_message_line( + line: &str, +) -> Result>, error::ApiError> { + if line.trim() == "data: [DONE]" { + return Ok(None); + } + + let chunk_as_json = line.trim_start_matches("data: ").trim(); + if chunk_as_json.is_empty() { + return Ok(Some(vec![])); + } + + // Attempt to deserialize the JSON string into ChatStreamChunk + match from_str::(chunk_as_json) { + Ok(chunk) => Ok(Some(vec![chunk])), + Err(e) => Err(error::ApiError { + message: e.to_string(), + }), + } +} diff --git a/src/v1/client.rs b/src/v1/client.rs index 02082a6..89b7ea5 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -1,27 +1,23 @@ use futures::stream::StreamExt; use futures::Stream; +use log::debug; use reqwest::Error as ReqwestError; -use serde_json::from_str; - -use crate::v1::error::ApiError; - -use crate::v1::{ - chat_completion::{ - ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, - }, - constants::{EmbedModel, Model, API_URL_BASE}, - embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse}, - error::ClientError, - model_list::ModelListResponse, +use std::{ + any::Any, + collections::HashMap, + sync::{Arc, Mutex}, }; -use super::chat_completion::ChatCompletionStreamChunk; +use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils}; pub struct Client { pub api_key: String, pub endpoint: String, pub max_retries: u32, pub timeout: u32, + + functions: Arc>>>, + last_function_call_result: Arc>>>, } impl Client { @@ -53,20 +49,28 @@ impl Client { endpoint: Option, max_retries: Option, timeout: Option, - ) -> Result { + ) -> Result { let api_key = match api_key { Some(api_key_from_param) => api_key_from_param, - None => std::env::var("MISTRAL_API_KEY").map_err(|_| ClientError::MissingApiKey)?, + None => { + std::env::var("MISTRAL_API_KEY").map_err(|_| error::ClientError::MissingApiKey)? 
+ } }; - let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); + let endpoint = endpoint.unwrap_or(constants::API_URL_BASE.to_string()); let max_retries = max_retries.unwrap_or(5); let timeout = timeout.unwrap_or(120); + let functions: Arc<_> = Arc::new(Mutex::new(HashMap::new())); + let last_function_call_result = Arc::new(Mutex::new(None)); + Ok(Self { api_key, endpoint, max_retries, timeout, + + functions, + last_function_call_result, }) } @@ -76,42 +80,49 @@ impl Client { /// /// * `model` - The [Model] to use for the chat completion. /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatCompletionParams] to customize the request. + /// * `options` - Optional [ChatParams] to customize the request. /// /// # Returns /// - /// Returns a [Result] containing the `ChatCompletionResponse` if the request is successful, + /// Returns a [Result] containing the `ChatResponse` if the request is successful, /// or an [ApiError] if there is an error. 
/// /// # Examples /// /// ``` /// use mistralai_client::v1::{ - /// chat_completion::{ChatMessage, ChatMessageRole}, + /// chat::{ChatMessage, ChatMessageRole}, /// client::Client, /// constants::Model, /// }; /// /// let client = Client::new(None, None, None, None).unwrap(); /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::user, + /// role: ChatMessageRole::User, /// content: "Hello, world!".to_string(), + /// tool_calls: None, /// }]; /// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap(); - /// println!("{}: {}", response.choices[0].message.role, response.choices[0].message.content); + /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); /// ``` pub fn chat( &self, - model: Model, - messages: Vec, - options: Option, - ) -> Result { - let request = ChatCompletionRequest::new(model, messages, false, options); + model: constants::Model, + messages: Vec, + options: Option, + ) -> Result { + let request = chat::ChatRequest::new(model, messages, false, options); let response = self.post_sync("/chat/completions", &request)?; - let result = response.json::(); + let result = response.json::(); match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + self.call_function_if_any(data.clone()); + + Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } @@ -122,18 +133,18 @@ impl Client { /// /// * `model` - The [Model] to use for the chat completion. /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatCompletionParams] to customize the request. + /// * `options` - Optional [ChatParams] to customize the request. 
/// /// # Returns /// - /// Returns a [Result] containing a `Stream` of `ChatCompletionStreamChunk` if the request is successful, + /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, /// or an [ApiError] if there is an error. /// /// # Examples /// /// ``` /// use mistralai_client::v1::{ - /// chat_completion::{ChatMessage, ChatMessageRole}, + /// chat::{ChatMessage, ChatMessageRole}, /// client::Client, /// constants::Model, /// }; @@ -142,25 +153,32 @@ impl Client { /// async fn main() { /// let client = Client::new(None, None, None, None).unwrap(); /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::user, + /// role: ChatMessageRole::User, /// content: "Hello, world!".to_string(), + /// tool_calls: None, /// }]; /// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap(); - /// println!("{}: {}", response.choices[0].message.role, response.choices[0].message.content); + /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); /// } /// ``` pub async fn chat_async( &self, - model: Model, - messages: Vec, - options: Option, - ) -> Result { - let request = ChatCompletionRequest::new(model, messages, false, options); + model: constants::Model, + messages: Vec, + options: Option, + ) -> Result { + let request = chat::ChatRequest::new(model, messages, false, options); let response = self.post_async("/chat/completions", &request).await?; - let result = response.json::().await; + let result = response.json::().await; match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + self.call_function_if_any_async(data.clone()).await; + + Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } @@ -171,11 +189,11 @@ impl Client { /// /// * `model` - The [Model] to use for the chat completion. /// * `messages` - A vector of [ChatMessage] to send as part of the chat. 
- /// * `options` - Optional [ChatCompletionParams] to customize the request. + /// * `options` - Optional [ChatParams] to customize the request. /// /// # Returns /// - /// Returns a [Result] containing a `Stream` of `ChatCompletionStreamChunk` if the request is successful, + /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, /// or an [ApiError] if there is an error. /// /// # Examples @@ -183,124 +201,176 @@ impl Client { /// ``` /// use futures::stream::StreamExt; /// use mistralai_client::v1::{ - /// chat_completion::{ChatMessage, ChatMessageRole}, + /// chat::{ChatMessage, ChatMessageRole}, /// client::Client, /// constants::Model, /// }; + /// use std::io::{self, Write}; /// /// #[tokio::main] /// async fn main() { /// let client = Client::new(None, None, None, None).unwrap(); /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::user, + /// role: ChatMessageRole::User, /// content: "Hello, world!".to_string(), + /// tool_calls: None, /// }]; - /// let mut stream = client.chat_stream(Model::OpenMistral7b, messages, None).await.unwrap(); - /// while let Some(chunk_result) = stream.next().await { - /// match chunk_result { - /// Ok(chunk) => { - /// print!("{}", chunk.choices[0].delta.content); + /// + /// let stream_result = client + /// .chat_stream(Model::OpenMistral7b,messages, None) + /// .await + /// .unwrap(); + /// stream_result + /// .for_each(|chunk_result| async { + /// match chunk_result { + /// Ok(chunks) => chunks.iter().for_each(|chunk| { + /// print!("{}", chunk.choices[0].delta.content); + /// io::stdout().flush().unwrap(); + /// // => "Once upon a time, [...]" + /// }), + /// Err(error) => { + /// eprintln!("Error processing chunk: {:?}", error) + /// } /// } - /// Err(error) => { - /// println!("Error: {}", error.message); - /// } - /// } - /// } + /// }) + /// .await; + /// print!("\n") // To persist the last chunk output. 
/// } pub async fn chat_stream( &self, - model: Model, - messages: Vec, - options: Option, - ) -> Result>, ApiError> { - let request = ChatCompletionRequest::new(model, messages, true, options); + model: constants::Model, + messages: Vec, + options: Option, + ) -> Result< + impl Stream, error::ApiError>>, + error::ApiError, + > { + let request = chat::ChatRequest::new(model, messages, true, options); let response = self .post_stream("/chat/completions", &request) .await - .map_err(|e| ApiError { + .map_err(|e| error::ApiError { message: e.to_string(), })?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); - return Err(ApiError { + return Err(error::ApiError { message: format!("{}: {}", status, text), }); } - let deserialized_stream = - response - .bytes_stream() - .map(|item| -> Result { - match item { - Ok(bytes) => { - let text = String::from_utf8(bytes.to_vec()).map_err(|e| ApiError { - message: e.to_string(), - })?; - let text_trimmed = text.trim_start_matches("data: "); - from_str(&text_trimmed).map_err(|e| ApiError { - message: e.to_string(), - }) - } - Err(e) => Err(ApiError { - message: e.to_string(), - }), + let deserialized_stream = response.bytes_stream().then(|bytes_result| async move { + match bytes_result { + Ok(bytes) => match String::from_utf8(bytes.to_vec()) { + Ok(message) => { + let chunks = message + .lines() + .filter_map( + |line| match chat_stream::get_chunk_from_stream_message_line(line) { + Ok(Some(chunks)) => Some(chunks), + Ok(None) => None, + Err(_error) => None, + }, + ) + .flatten() + .collect(); + + Ok(chunks) } - }); + Err(e) => Err(error::ApiError { + message: e.to_string(), + }), + }, + Err(e) => Err(error::ApiError { + message: e.to_string(), + }), + } + }); Ok(deserialized_stream) } pub fn embeddings( &self, - model: EmbedModel, + model: constants::EmbedModel, input: Vec, - options: Option, - ) -> Result { - let request = EmbeddingRequest::new(model, input, 
options); + options: Option, + ) -> Result { + let request = embedding::EmbeddingRequest::new(model, input, options); let response = self.post_sync("/embeddings", &request)?; - let result = response.json::(); + let result = response.json::(); match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } pub async fn embeddings_async( &self, - model: EmbedModel, + model: constants::EmbedModel, input: Vec, - options: Option, - ) -> Result { - let request = EmbeddingRequest::new(model, input, options); + options: Option, + ) -> Result { + let request = embedding::EmbeddingRequest::new(model, input, options); let response = self.post_async("/embeddings", &request).await?; - let result = response.json::().await; + let result = response.json::().await; match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } - pub fn list_models(&self) -> Result { + pub fn get_last_function_call_result(&self) -> Option> { + let mut result_lock = self.last_function_call_result.lock().unwrap(); + + result_lock.take() + } + + pub fn list_models(&self) -> Result { let response = self.get_sync("/models")?; - let result = response.json::(); + let result = response.json::(); match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } - pub async fn list_models_async(&self) -> Result { + pub async fn list_models_async( + &self, + ) -> Result { let response = self.get_async("/models").await?; - let result = response.json::().await; + let result = response.json::().await; match result { - Ok(response) => Ok(response), + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + + 
Ok(data) + } Err(error) => Err(self.to_api_error(error)), } } + pub fn register_function(&mut self, name: String, function: Box) { + let mut functions = self.functions.lock().unwrap(); + + functions.insert(name, function); + } + fn build_request_sync( &self, request: reqwest::blocking::RequestBuilder, @@ -349,9 +419,70 @@ impl Client { request_builder } - fn get_sync(&self, path: &str) -> Result { + fn call_function_if_any(&self, response: chat::ChatResponse) -> () { + let next_result = match response.choices.get(0) { + Some(first_choice) => match first_choice.message.tool_calls.to_owned() { + Some(tool_calls) => match tool_calls.get(0) { + Some(first_tool_call) => { + let functions = self.functions.lock().unwrap(); + match functions.get(&first_tool_call.function.name) { + Some(function) => { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let result = runtime.block_on(async { + function + .execute(first_tool_call.function.arguments.to_owned()) + .await + }); + + Some(result) + } + None => None, + } + } + None => None, + }, + None => None, + }, + None => None, + }; + + let mut last_result_lock = self.last_function_call_result.lock().unwrap(); + *last_result_lock = next_result; + } + + async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () { + let next_result = match response.choices.get(0) { + Some(first_choice) => match first_choice.message.tool_calls.to_owned() { + Some(tool_calls) => match tool_calls.get(0) { + Some(first_tool_call) => { + let functions = self.functions.lock().unwrap(); + match functions.get(&first_tool_call.function.name) { + Some(function) => { + let result = function + .execute(first_tool_call.function.arguments.to_owned()) + .await; + + Some(result) + } + None => None, + } + } + None => None, + }, + None => None, + }, + None => None, + }; + + let mut last_result_lock = self.last_function_call_result.lock().unwrap(); + *last_result_lock = next_result; + } + + fn get_sync(&self, path: &str) -> Result { let 
reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + let request = self.build_request_sync(reqwest_client.get(url)); let result = request.send(); @@ -360,22 +491,27 @@ impl Client { if response.status().is_success() { Ok(response) } else { - let status = response.status(); - let text = response.text().unwrap(); - Err(ApiError { - message: format!("{}: {}", status, text), + let response_status = response.status(); + let response_body = response.text().unwrap_or_default(); + debug!("Response Status: {}", &response_status); + utils::debug_pretty_json_from_string("Response Data", &response_body); + + Err(error::ApiError { + message: format!("{}: {}", response_status, response_body), }) } } - Err(error) => Err(ApiError { + Err(error) => Err(error::ApiError { message: error.to_string(), }), } } - async fn get_async(&self, path: &str) -> Result { + async fn get_async(&self, path: &str) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + let request_builder = reqwest_client.get(url); let request = self.build_request_async(request_builder); @@ -385,26 +521,32 @@ impl Client { if response.status().is_success() { Ok(response) } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), + let response_status = response.status(); + let response_body = response.text().await.unwrap_or_default(); + debug!("Response Status: {}", &response_status); + utils::debug_pretty_json_from_string("Response Data", &response_body); + + Err(error::ApiError { + message: format!("{}: {}", response_status, response_body), }) } } - Err(error) => Err(ApiError { + Err(error) => Err(error::ApiError { message: error.to_string(), }), } } - fn post_sync( + fn post_sync( &self, path: &str, params: &T, - ) -> Result { + ) -> Result { let 
reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + utils::debug_pretty_json_from_struct("Request Body", params); + let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_sync(request_builder); @@ -414,14 +556,17 @@ impl Client { if response.status().is_success() { Ok(response) } else { - let status = response.status(); - let text = response.text().unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), + let response_status = response.status(); + let response_body = response.text().unwrap_or_default(); + debug!("Response Status: {}", &response_status); + utils::debug_pretty_json_from_string("Response Data", &response_body); + + Err(error::ApiError { + message: format!("{}: {}", response_status, response_body), }) } } - Err(error) => Err(ApiError { + Err(error) => Err(error::ApiError { message: error.to_string(), }), } } @@ -431,9 +576,12 @@ impl Client { &self, path: &str, params: &T, - ) -> Result { + ) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + utils::debug_pretty_json_from_struct("Request Body", params); + let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_async(request_builder); @@ -443,14 +591,17 @@ impl Client { if response.status().is_success() { Ok(response) } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), + let response_status = response.status(); + let response_body = response.text().await.unwrap_or_default(); + debug!("Response Status: {}", &response_status); + utils::debug_pretty_json_from_string("Response Data", &response_body); + + Err(error::ApiError { + message: format!("{}: {}", response_status, response_body), }) } } - Err(error) => Err(ApiError { + 
Err(error) => Err(error::ApiError { message: error.to_string(), }), } @@ -460,9 +611,12 @@ impl Client { &self, path: &str, params: &T, - ) -> Result { + ) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + utils::debug_pretty_json_from_struct("Request Body", params); + let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_stream(request_builder); @@ -472,21 +626,24 @@ impl Client { if response.status().is_success() { Ok(response) } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), + let response_status = response.status(); + let response_body = response.text().await.unwrap_or_default(); + debug!("Response Status: {}", &response_status); + utils::debug_pretty_json_from_string("Response Data", &response_body); + + Err(error::ApiError { + message: format!("{}: {}", response_status, response_body), }) } } - Err(error) => Err(ApiError { + Err(error) => Err(error::ApiError { message: error.to_string(), }), } } - fn to_api_error(&self, err: ReqwestError) -> ApiError { - ApiError { + fn to_api_error(&self, err: ReqwestError) -> error::ApiError { + error::ApiError { message: err.to_string(), } } diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index 7edc719..7d1c2d7 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -2,6 +2,9 @@ use serde::{Deserialize, Serialize}; use crate::v1::{common, constants}; +// ----------------------------------------------------------------------------- +// Request + #[derive(Debug)] pub struct EmbeddingRequestOptions { pub encoding_format: Option, @@ -43,6 +46,9 @@ pub enum EmbeddingRequestEncodingFormat { float, } +// ----------------------------------------------------------------------------- +// Response + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct EmbeddingResponse { pub 
id: String, diff --git a/src/v1/error.rs b/src/v1/error.rs index ef42391..4364b5f 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -14,7 +14,9 @@ impl Error for ApiError {} #[derive(Debug, PartialEq, thiserror::Error)] pub enum ClientError { - #[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")] + #[error( + "You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...)`." + )] MissingApiKey, #[error("Failed to read the response text.")] UnreadableResponseText, diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 8f29a01..72165bb 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,7 +1,10 @@ -pub mod chat_completion; +pub mod chat; +pub mod chat_stream; pub mod client; pub mod common; pub mod constants; pub mod embedding; pub mod error; pub mod model_list; +pub mod tool; +pub mod utils; diff --git a/src/v1/model_list.rs b/src/v1/model_list.rs index 54670f4..1ff44bf 100644 --- a/src/v1/model_list.rs +++ b/src/v1/model_list.rs @@ -1,5 +1,8 @@ use serde::{Deserialize, Serialize}; +// ----------------------------------------------------------------------------- +// Response + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelListResponse { pub object: String, diff --git a/src/v1/tool.rs b/src/v1/tool.rs new file mode 100644 index 0000000..2249a40 --- /dev/null +++ b/src/v1/tool.rs @@ -0,0 +1,134 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::{any::Any, collections::HashMap}; + +// ----------------------------------------------------------------------------- +// Definitions + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub struct ToolCall { + pub function: ToolCallFunction, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub struct ToolCallFunction { + pub name: String, + pub arguments: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Tool { + 
pub r#type: ToolType, + pub function: ToolFunction, +} +impl Tool { + pub fn new( + function_name: String, + function_description: String, + function_parameters: Vec, + ) -> Self { + let properties: HashMap = function_parameters + .into_iter() + .map(|param| { + ( + param.name, + ToolFunctionParameterProperty { + r#type: param.r#type, + description: param.description, + }, + ) + }) + .collect(); + let property_names = properties.keys().cloned().collect(); + + let parameters = ToolFunctionParameters { + r#type: ToolFunctionParametersType::Object, + properties, + required: property_names, + }; + + Self { + r#type: ToolType::Function, + function: ToolFunction { + name: function_name, + description: function_description, + parameters, + }, + } + } +} + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunction { + name: String, + description: String, + parameters: ToolFunctionParameters, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunctionParameter { + name: String, + description: String, + r#type: ToolFunctionParameterType, +} +impl ToolFunctionParameter { + pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self { + Self { + name, + r#type, + description, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunctionParameters { + r#type: ToolFunctionParametersType, + properties: HashMap, + required: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunctionParameterProperty { + r#type: ToolFunctionParameterType, + description: String, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub enum ToolFunctionParametersType { + #[serde(rename = "object")] + Object, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub enum ToolFunctionParameterType { + #[serde(rename = "string")] + String, +} + +#[derive(Clone, 
Debug, Eq, PartialEq, Deserialize, Serialize)] +pub enum ToolType { + #[serde(rename = "function")] + Function, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +pub enum ToolChoice { + #[serde(rename = "any")] + Any, + #[serde(rename = "auto")] + Auto, + #[serde(rename = "none")] + None, +} + +// ----------------------------------------------------------------------------- +// Custom + +#[async_trait] +pub trait Function { + async fn execute(&self, arguments: String) -> Box; +} diff --git a/src/v1/utils.rs b/src/v1/utils.rs new file mode 100644 index 0000000..9c1de79 --- /dev/null +++ b/src/v1/utils.rs @@ -0,0 +1,32 @@ +use std::fmt::Debug; + +use log::debug; +use serde::Serialize; + +pub fn prettify_json_string(json: &String) -> String { + match serde_json::from_str::(&json) { + Ok(json_value) => { + serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json.to_owned()) + } + Err(_) => json.to_owned(), + } +} + +pub fn prettify_json_struct(value: T) -> String { + match serde_json::to_string_pretty(&value) { + Ok(pretty_json) => pretty_json, + Err(_) => format!("{:?}", value), + } +} + +pub fn debug_pretty_json_from_string(label: &str, json: &String) -> () { + let pretty_json = prettify_json_string(json); + + debug!("{label}: {}", pretty_json); +} + +pub fn debug_pretty_json_from_struct(label: &str, value: &T) -> () { + let pretty_json = prettify_json_struct(value); + + debug!("{label}: {}", pretty_json); +} diff --git a/tests/setup.rs b/tests/setup.rs new file mode 100644 index 0000000..af191da --- /dev/null +++ b/tests/setup.rs @@ -0,0 +1,3 @@ +pub fn setup() { + let _ = env_logger::builder().is_test(true).try_init(); +} diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index 299d4aa..22c1ce8 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -1,20 +1,24 @@ use jrest::expect; use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, 
ChatMessageRole}, + chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, + tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, }; +mod setup; + #[tokio::test] async fn test_client_chat_async() { + setup::setup(); + let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::user, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - }]; - let options = ChatCompletionParams { + let messages = vec![ChatMessage::new_user_message( + "Just guess the next word: \"Eiffel ...\"?", + )]; + let options = ChatParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() @@ -27,11 +31,70 @@ async fn test_client_chat_async() { expect!(response.model).to_be(Model::OpenMistral7b); expect!(response.object).to_be("chat.completion".to_string()); + expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); - expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant); + expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); + + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.content.clone()) .to_be("Tower. 
The Eiffel Tower is a famous landmark in Paris, France.".to_string()); + + expect!(response.usage.prompt_tokens).to_be_greater_than(0); + expect!(response.usage.completion_tokens).to_be_greater_than(0); + expect!(response.usage.total_tokens).to_be_greater_than(0); +} + +#[tokio::test] +async fn test_client_chat_async_with_function_calling() { + setup::setup(); + + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage::new_user_message( + "What's the current temperature in Paris?", + )]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Any), + tools: Some(tools), + ..Default::default() + }; + + let response = client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + + expect!(response.model).to_be(Model::MistralSmallLatest); + expect!(response.object).to_be("chat.completion".to_string()); + + expect!(response.choices.len()).to_be(1); + expect!(response.choices[0].index).to_be(0); + expect!(response.choices[0].finish_reason.clone()) + .to_be(ChatResponseChoiceFinishReason::ToolCalls); + + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); + expect!(response.choices[0].message.content.clone()).to_be("".to_string()); + // expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall { + // function: ToolCallFunction { + // name: "get_city_temperature".to_string(), + // arguments: "{\"city\": \"Paris\"}".to_string(), + // }, + // }])); + expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); 
expect!(response.usage.total_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_chat_stream_test.rs b/tests/v1_client_chat_stream_test.rs index 28449f3..23379fa 100644 --- a/tests/v1_client_chat_stream_test.rs +++ b/tests/v1_client_chat_stream_test.rs @@ -1,40 +1,40 @@ -use futures::stream::StreamExt; -use jrest::expect; -use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, - client::Client, - constants::Model, -}; +// use futures::stream::StreamExt; +// use jrest::expect; +// use mistralai_client::v1::{ +// chat_completion::{ChatParams, ChatMessage, ChatMessageRole}, +// client::Client, +// constants::Model, +// }; -#[tokio::test] -async fn test_client_chat_stream() { - let client = Client::new(None, None, None, None).unwrap(); +// #[tokio::test] +// async fn test_client_chat_stream() { +// let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::user, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - }]; - let options = ChatCompletionParams { - temperature: Some(0.0), - random_seed: Some(42), - ..Default::default() - }; +// let model = Model::OpenMistral7b; +// let messages = vec![ChatMessage::new_user_message( +// "Just guess the next word: \"Eiffel ...\"?", +// )]; +// let options = ChatParams { +// temperature: Some(0.0), +// random_seed: Some(42), +// ..Default::default() +// }; - let stream_result = client.chat_stream(model, messages, Some(options)).await; - let mut stream = stream_result.expect("Failed to create stream."); - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - if chunk.choices[0].delta.role == Some(ChatMessageRole::assistant) - || chunk.choices[0].finish_reason == Some("stop".to_string()) - { - expect!(chunk.choices[0].delta.content.len()).to_be(0); - } else { - 
expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); - } - } - Err(e) => eprintln!("Error processing chunk: {:?}", e), - } - } -} +// let stream_result = client.chat_stream(model, messages, Some(options)).await; +// let mut stream = stream_result.expect("Failed to create stream."); +// while let Some(maybe_chunk_result) = stream.next().await { +// match maybe_chunk_result { +// Some(Ok(chunk)) => { +// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant) +// || chunk.choices[0].finish_reason == Some("stop".to_string()) +// { +// expect!(chunk.choices[0].delta.content.len()).to_be(0); +// } else { +// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); +// } +// } +// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error), +// None => (), +// } +// } +// } diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index 276029d..adca48d 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -1,20 +1,24 @@ use jrest::expect; use mistralai_client::v1::{ - chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, + tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, }; +mod setup; + #[test] fn test_client_chat() { + setup::setup(); + let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::user, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - }]; - let options = ChatCompletionParams { + let messages = vec![ChatMessage::new_user_message( + "Just guess the next word: \"Eiffel ...\"?", + )]; + let options = ChatParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() @@ -26,9 +30,53 @@ fn test_client_chat() { expect!(response.object).to_be("chat.completion".to_string()); 
expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); - expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant); + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.content.clone()) .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); + expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); + expect!(response.usage.prompt_tokens).to_be_greater_than(0); + expect!(response.usage.completion_tokens).to_be_greater_than(0); + expect!(response.usage.total_tokens).to_be_greater_than(0); +} + +#[test] +fn test_client_chat_with_function_calling() { + setup::setup(); + + let tools = vec![Tool::new( + "get_city_temperature".to_string(), + "Get the current temperature in a city.".to_string(), + vec![ToolFunctionParameter::new( + "city".to_string(), + "The name of the city.".to_string(), + ToolFunctionParameterType::String, + )], + )]; + + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::MistralSmallLatest; + let messages = vec![ChatMessage::new_user_message( + "What's the current temperature in Paris?", + )]; + let options = ChatParams { + temperature: Some(0.0), + random_seed: Some(42), + tool_choice: Some(ToolChoice::Auto), + tools: Some(tools), + ..Default::default() + }; + + let response = client.chat(model, messages, Some(options)).unwrap(); + + expect!(response.model).to_be(Model::MistralSmallLatest); + expect!(response.object).to_be("chat.completion".to_string()); + expect!(response.choices.len()).to_be(1); + expect!(response.choices[0].index).to_be(0); + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); + expect!(response.choices[0].message.content.clone()).to_be("".to_string()); + expect!(response.choices[0].finish_reason.clone()) + .to_be(ChatResponseChoiceFinishReason::ToolCalls); 
expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); expect!(response.usage.total_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_list_models_async_test.rs b/tests/v1_client_list_models_async_test.rs index 757f40c..af5c364 100644 --- a/tests/v1_client_list_models_async_test.rs +++ b/tests/v1_client_list_models_async_test.rs @@ -9,12 +9,4 @@ async fn test_client_list_models_async() { expect!(response.object).to_be("list".to_string()); expect!(response.data.len()).to_be_greater_than(0); - - // let open_mistral_7b_data_item = response - // .data - // .iter() - // .find(|item| item.id == "open-mistral-7b") - // .unwrap(); - - // expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string()); } diff --git a/tests/v1_client_list_models_test.rs b/tests/v1_client_list_models_test.rs index 6a6e8ef..56e4a57 100644 --- a/tests/v1_client_list_models_test.rs +++ b/tests/v1_client_list_models_test.rs @@ -9,12 +9,4 @@ fn test_client_list_models() { expect!(response.object).to_be("list".to_string()); expect!(response.data.len()).to_be_greater_than(0); - - // let open_mistral_7b_data_item = response - // .data - // .iter() - // .find(|item| item.id == "open-mistral-7b") - // .unwrap(); - - // expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string()); }