From 1dd59f67048c10458ab0382af8fdfe4ed21c82fa Mon Sep 17 00:00:00 2001 From: Ivan Gabriele Date: Mon, 4 Mar 2024 04:57:48 +0100 Subject: [PATCH] feat: add client.chat_async() method --- Cargo.toml | 2 + README.md | 46 +++++++++++++++++++++ src/v1/client.rs | 64 +++++++++++++++++++++++++----- tests/v1_client_chat_async_test.rs | 38 ++++++++++++++++++ 4 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 tests/v1_client_chat_async_test.rs diff --git a/Cargo.toml b/Cargo.toml index 0a60c77..37c08f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,11 @@ repository = "https://github.com/ivangabriele/mistralai-client-rs" [dependencies] minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] } +reqwest = { version = "0.11.24", features = ["json"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" thiserror = "1.0.57" +tokio = { version = "1.36.0", features = ["full"] } [dev-dependencies] jrest = "0.2.3" diff --git a/README.md b/README.md index 47dcdb9..6c6a0bb 100644 --- a/README.md +++ b/README.md @@ -16,19 +16,26 @@ Rust client for the Mistral AI API. - [As a client argument](#as-a-client-argument) - [Usage](#usage) - [Chat without streaming](#chat-without-streaming) + - [Chat without streaming (async)](#chat-without-streaming-async) - [Chat with streaming](#chat-with-streaming) - [Embeddings](#embeddings) + - [Embeddings (async)](#embeddings-async) - [List models](#list-models) + - [List models (async)](#list-models-async) --- ## Supported APIs - [x] Chat without streaming +- [x] Chat without streaming (async) - [ ] Chat with streaming - [x] Embedding +- [ ] Embedding (async) - [x] List models +- [ ] List models (async) - [ ] Function Calling +- [ ] Function Calling (async) ## Installation @@ -90,6 +97,37 @@ fn main() { } ``` +### Chat without streaming (async) + +```rs +use mistralai_client::v1::{ + chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequestOptions}, + client::Client, + constants::Model, +}; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatCompletionMessage { + role: ChatCompletionMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionRequestOptions { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let result = client.chat(model, messages, Some(options)).await.unwrap(); + println!("Assistant: {}", result.choices[0].message.content); + // => "Assistant: Tower. [...]" +} +``` + ### Chat with streaming _In progress._ @@ -116,6 +154,10 @@ fn main() { } ``` +### Embeddings (async) + +_In progress._ + ### List models ```rs @@ -130,3 +172,7 @@ fn main() { // => "First Model ID: open-mistral-7b" } ``` + +### List models (async) + +_In progress._ diff --git a/src/v1/client.rs b/src/v1/client.rs index 7f32f2f..d6f9389 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -1,5 +1,6 @@ use crate::v1::error::ApiError; use minreq::Response; +use reqwest::{Client as ReqwestClient, Error as ReqwestError}; use crate::v1::{ chat_completion::{ @@ -57,7 +58,7 @@ impl Client { request } - pub fn get(&self, path: &str) -> Result { + pub fn get_sync(&self, path: &str) -> Result { let url = format!("{}{}", self.endpoint, path); let request = self.build_request(minreq::get(url)); @@ -78,11 +79,11 @@ impl Client { }) } } - Err(error) => Err(self.new_error(error)), + Err(error) => Err(self.new_minreq_error(error)), } } - pub fn post( + pub fn post_sync( &self, path: &str, params: &T, @@ -109,7 +110,7 @@ impl Client { }) } } - Err(error) => Err(self.new_error(error)), + Err(error) => Err(self.new_minreq_error(error)), } } @@ -121,11 +122,46 @@ impl Client { ) -> Result { let request = ChatCompletionRequest::new(model, messages, options); - let response = self.post("/chat/completions", &request)?; + let response = self.post_sync("/chat/completions", &request)?; let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_error(error)), + Err(error) => Err(self.new_minreq_error(error)), + } + } + + pub async fn chat_async( + &self, + model: Model, + messages: Vec, + options: Option, + ) -> Result { + let request = ChatCompletionRequest::new(model, messages, options); + let client = ReqwestClient::new(); + + let response = client + .post(format!("{}{}", self.endpoint, "/chat/completions")) + .json(&request) + .bearer_auth(&self.api_key) + .header( + "User-Agent", + format!("mistralai-client-rs/{}", env!("CARGO_PKG_VERSION")), + ) + .send() + .await + .map_err(|e| self.new_reqwest_error(e))?; + + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| self.new_reqwest_error(e)) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) } } @@ -137,24 +173,30 @@ impl Client { ) -> Result { let request = EmbeddingRequest::new(model, input, options); - let response = self.post("/embeddings", &request)?; + let response = self.post_sync("/embeddings", &request)?; let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_error(error)), + Err(error) => Err(self.new_minreq_error(error)), } } pub fn list_models(&self) -> Result { - let response = self.get("/models")?; + let response = self.get_sync("/models")?; let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_error(error)), + Err(error) => Err(self.new_minreq_error(error)), } } - fn new_error(&self, err: minreq::Error) -> ApiError { + fn new_minreq_error(&self, err: minreq::Error) -> ApiError { + ApiError { + message: err.to_string(), + } + } + + fn new_reqwest_error(&self, err: ReqwestError) -> ApiError { ApiError { message: err.to_string(), } diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs new file mode 100644 index 0000000..2c8eb72 --- /dev/null +++ b/tests/v1_client_chat_async_test.rs @@ -0,0 +1,38 @@ +use jrest::expect; +use mistralai_client::v1::{ + chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionParams}, + client::Client, + constants::Model, +}; + +#[tokio::test] +async fn test_client_chat_async() { + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatCompletionMessage { + role: ChatCompletionMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let response = client + .chat_async(model, messages, Some(options)) + .await + .unwrap(); + + expect!(response.model).to_be(Model::OpenMistral7b); + expect!(response.object).to_be("chat.completion".to_string()); + expect!(response.choices.len()).to_be(1); + expect!(response.choices[0].index).to_be(0); + expect!(response.choices[0].message.role.clone()).to_be(ChatCompletionMessageRole::assistant); + expect!(response.choices[0].message.content.clone()) + .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); + expect!(response.usage.prompt_tokens).to_be_greater_than(0); + expect!(response.usage.completion_tokens).to_be_greater_than(0); + expect!(response.usage.total_tokens).to_be_greater_than(0); +}