From b69f7c617c15dd63abb61d004636512916d766bb Mon Sep 17 00:00:00 2001 From: Ivan Gabriele Date: Mon, 4 Mar 2024 06:32:01 +0100 Subject: [PATCH] feat: add client.list_models_async() method --- README.md | 16 +- src/v1/client.rs | 289 +++++++++++----------- tests/v1_client_list_models_async_test.rs | 20 ++ 3 files changed, 183 insertions(+), 142 deletions(-) create mode 100644 tests/v1_client_list_models_async_test.rs diff --git a/README.md b/README.md index 6c6a0bb..aebbe59 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Rust client for the Mistral AI API. - [x] Embedding - [ ] Embedding (async) - [x] List models -- [ ] List models (async) +- [x] List models (async) - [ ] Function Calling - [ ] Function Calling (async) @@ -175,4 +175,16 @@ fn main() { ### List models (async) -_In progress._ +```rs +use mistralai_client::v1::client::Client; + +#[tokio::main] +async fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None).await.unwrap(); + + let result = client.list_models().unwrap(); + println!("First Model ID: {:?}", result.data[0].id); + // => "First Model ID: open-mistral-7b" +} +``` diff --git a/src/v1/client.rs b/src/v1/client.rs index 1313b78..4ec2ce4 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -41,146 +41,6 @@ impl Client { }) } - pub fn build_request_sync( - &self, - request: reqwest::blocking::RequestBuilder, - ) -> reqwest::blocking::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request - .bearer_auth(&self.api_key) - .header("Accept", "application/json") - .header("Content-Type", "application/json") - .header("User-Agent", user_agent); - - request_builder - } - - pub fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request - .bearer_auth(&self.api_key) - .header("Accept", "application/json") - .header("Content-Type", "application/json") - .header("User-Agent", user_agent); - - request_builder - } - - pub fn get_sync(&self, path: &str) -> Result { - let client_sync = reqwest::blocking::Client::new(); - let url = format!("{}{}", self.endpoint, path); - let request = self.build_request_sync(client_sync.get(url)); - - let result = request.send(); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let status = response.status(); - let text = response.text().unwrap(); - Err(ApiError { - message: format!("{}: {}", status, text), - }) - } - } - Err(error) => Err(ApiError { - message: error.to_string(), - }), - } - } - - pub async fn get_async(&self, path: &str) -> Result { - let reqwest_client = reqwest::Client::new(); - let url = format!("{}{}", self.endpoint, path); - let request_builder = reqwest_client.get(url); - let request = self.build_request_async(request_builder); - - let result = request.send().await.map_err(|e| self.to_api_error(e)); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), - }) - } - } - Err(error) => Err(ApiError { - message: error.to_string(), - }), - } - } - - pub fn post_sync( - &self, - path: &str, - params: &T, - ) -> Result { - let reqwest_client = reqwest::blocking::Client::new(); - let url = format!("{}{}", self.endpoint, path); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_sync(request_builder); - - let result = request.send(); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let status = response.status(); - let text = response.text().unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), - }) - } - } - Err(error) => Err(ApiError { - message: error.to_string(), - }), - } - } - - pub async fn post_async( - &self, - path: &str, - params: &T, - ) -> Result { - let reqwest_client = reqwest::Client::new(); - let url = format!("{}{}", self.endpoint, path); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_async(request_builder); - - let result = request.send().await.map_err(|e| self.to_api_error(e)); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), - }) - } - } - Err(error) => Err(ApiError { - message: error.to_string(), - }), - } - } - pub fn chat( &self, model: Model, @@ -238,6 +98,155 @@ impl Client { } } + pub async fn list_models_async(&self) -> Result { + let response = self.get_async("/models").await?; + let result = response.json::().await; + match result { + Ok(response) => Ok(response), + Err(error) => Err(self.to_api_error(error)), + } + } + + fn build_request_sync( + &self, + request: reqwest::blocking::RequestBuilder, + ) -> reqwest::blocking::RequestBuilder { + let user_agent = format!( + "ivangabriele/mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ); + + let request_builder = request + .bearer_auth(&self.api_key) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("User-Agent", user_agent); + + request_builder + } + + fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + let user_agent = format!( + "ivangabriele/mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ); + + let request_builder = request + .bearer_auth(&self.api_key) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("User-Agent", user_agent); + + request_builder + } + + fn get_sync(&self, path: &str) -> Result { + let client_sync = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request = self.build_request_sync(client_sync.get(url)); + + let result = request.send(); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().unwrap(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + + async fn get_async(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.get(url); + let request = self.build_request_async(request_builder); + + let result = request.send().await.map_err(|e| self.to_api_error(e)); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + + fn post_sync( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.post(url).json(params); + let request = self.build_request_sync(request_builder); + + let result = request.send(); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + + async fn post_async( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.post(url).json(params); + let request = self.build_request_async(request_builder); + + let result = request.send().await.map_err(|e| self.to_api_error(e)); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + fn to_api_error(&self, err: ReqwestError) -> ApiError { ApiError { message: err.to_string(), diff --git a/tests/v1_client_list_models_async_test.rs b/tests/v1_client_list_models_async_test.rs new file mode 100644 index 0000000..757f40c --- /dev/null +++ b/tests/v1_client_list_models_async_test.rs @@ -0,0 +1,20 @@ +use jrest::expect; +use mistralai_client::v1::client::Client; + +#[tokio::test] +async fn test_client_list_models_async() { + let client = Client::new(None, None, None, None).unwrap(); + + let response = client.list_models_async().await.unwrap(); + + expect!(response.object).to_be("list".to_string()); + expect!(response.data.len()).to_be_greater_than(0); + + // let open_mistral_7b_data_item = response + // .data + // .iter() + // .find(|item| item.id == "open-mistral-7b") + // .unwrap(); + + // expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string()); +}