From 75788b9395f05f35ee5755dcc956feda26203604 Mon Sep 17 00:00:00 2001 From: Ivan Gabriele Date: Mon, 4 Mar 2024 06:28:41 +0100 Subject: [PATCH] refactor: migrate to reqwest-only --- Cargo.toml | 3 +- src/v1/client.rs | 180 +++++++++++++++++++++++++++++------------------ src/v1/error.rs | 2 + 3 files changed, 114 insertions(+), 71 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b9547d5..4b50884 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,7 @@ readme = "README.md" repository = "https://github.com/ivangabriele/mistralai-client-rs" [dependencies] -minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] } -reqwest = { version = "0.11.24", features = ["json"] } +reqwest = { version = "0.11.24", features = ["json", "blocking"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" thiserror = "1.0.57" diff --git a/src/v1/client.rs b/src/v1/client.rs index d6f9389..1313b78 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -1,6 +1,5 @@ use crate::v1::error::ApiError; -use minreq::Response; -use reqwest::{Client as ReqwestClient, Error as ReqwestError}; +use reqwest::Error as ReqwestError; use crate::v1::{ chat_completion::{ @@ -42,44 +41,85 @@ impl Client { }) } - pub fn build_request(&self, request: minreq::Request) -> minreq::Request { - let authorization = format!("Bearer {}", self.api_key); + pub fn build_request_sync( + &self, + request: reqwest::blocking::RequestBuilder, + ) -> reqwest::blocking::RequestBuilder { let user_agent = format!( "ivangabriele/mistralai-client-rs/{}", env!("CARGO_PKG_VERSION") ); - let request = request - .with_header("Authorization", authorization) - .with_header("Accept", "application/json") - .with_header("Content-Type", "application/json") - .with_header("User-Agent", user_agent); + let request_builder = request + .bearer_auth(&self.api_key) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("User-Agent", user_agent); - request + request_builder } - pub fn get_sync(&self, path: &str) -> Result { + pub fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + let user_agent = format!( + "ivangabriele/mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ); + + let request_builder = request + .bearer_auth(&self.api_key) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("User-Agent", user_agent); + + request_builder + } + + pub fn get_sync(&self, path: &str) -> Result { + let client_sync = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); - let request = self.build_request(minreq::get(url)); + let request = self.build_request_sync(client_sync.get(url)); let result = request.send(); match result { Ok(response) => { - print!("{:?}", response.as_str().unwrap()); - - if (200..=299).contains(&response.status_code) { + if response.status().is_success() { Ok(response) } else { + let status = response.status(); + let text = response.text().unwrap(); Err(ApiError { - message: format!( - "{}: {}", - response.status_code, - response.as_str().unwrap() - ), + message: format!("{}: {}", status, text), }) } } - Err(error) => Err(self.new_minreq_error(error)), + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + + pub async fn get_async(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.get(url); + let request = self.build_request_async(request_builder); + + let result = request.send().await.map_err(|e| self.to_api_error(e)); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), } } @@ -87,30 +127,57 @@ impl Client { &self, path: &str, params: &T, - ) -> Result { - // print!("{:?}", params); - + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); - let request = self.build_request(minreq::post(url)); + let request_builder = reqwest_client.post(url).json(params); + let request = self.build_request_sync(request_builder); - let result = request.with_json(params).unwrap().send(); + let result = request.send(); match result { Ok(response) => { - print!("{:?}", response.as_str().unwrap()); - - if (200..=299).contains(&response.status_code) { + if response.status().is_success() { Ok(response) } else { + let status = response.status(); + let text = response.text().unwrap_or_default(); Err(ApiError { - message: format!( - "{}: {}", - response.status_code, - response.as_str().unwrap() - ), + message: format!("{}: {}", status, text), }) } } - Err(error) => Err(self.new_minreq_error(error)), + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + + pub async fn post_async( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.post(url).json(params); + let request = self.build_request_async(request_builder); + + let result = request.send().await.map_err(|e| self.to_api_error(e)); + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), } } @@ -126,7 +193,7 @@ impl Client { let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_minreq_error(error)), + Err(error) => Err(self.to_api_error(error)), } } @@ -137,31 +204,12 @@ impl Client { options: Option, ) -> Result { let request = ChatCompletionRequest::new(model, messages, options); - let client = ReqwestClient::new(); - let response = client - .post(format!("{}{}", self.endpoint, "/chat/completions")) - .json(&request) - .bearer_auth(&self.api_key) - .header( - "User-Agent", - format!("mistralai-client-rs/{}", env!("CARGO_PKG_VERSION")), - ) - .send() - .await - .map_err(|e| self.new_reqwest_error(e))?; - - if response.status().is_success() { - response - .json::() - .await - .map_err(|e| self.new_reqwest_error(e)) - } else { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - Err(ApiError { - message: format!("{}: {}", status, text), - }) + let response = self.post_async("/chat/completions", &request).await?; + let result = response.json::().await; + match result { + Ok(response) => Ok(response), + Err(error) => Err(self.to_api_error(error)), } } @@ -177,7 +225,7 @@ impl Client { let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_minreq_error(error)), + Err(error) => Err(self.to_api_error(error)), } } @@ -186,17 +234,11 @@ impl Client { let result = response.json::(); match result { Ok(response) => Ok(response), - Err(error) => Err(self.new_minreq_error(error)), + Err(error) => Err(self.to_api_error(error)), } } - fn new_minreq_error(&self, err: minreq::Error) -> ApiError { - ApiError { - message: err.to_string(), - } - } - - fn new_reqwest_error(&self, err: ReqwestError) -> ApiError { + fn to_api_error(&self, err: ReqwestError) -> ApiError { ApiError { message: err.to_string(), } diff --git a/src/v1/error.rs b/src/v1/error.rs index c5f12cf..73d219b 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -16,4 +16,6 @@ impl Error for ApiError {} pub enum ClientError { #[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")] ApiKeyError, + #[error("Failed to read the response text.")] + ReadResponseTextError, }