use futures::stream::StreamExt; use futures::Stream; use reqwest::Error as ReqwestError; use serde_json::from_str; use crate::v1::error::ApiError; use crate::v1::{ chat_completion::{ ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, }, constants::{EmbedModel, Model, API_URL_BASE}, embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse}, error::ClientError, model_list::ModelListResponse, }; use super::chat_completion::ChatCompletionStreamChunk; pub struct Client { pub api_key: String, pub endpoint: String, pub max_retries: u32, pub timeout: u32, } impl Client { pub fn new( api_key: Option, endpoint: Option, max_retries: Option, timeout: Option, ) -> Result { let api_key = api_key.unwrap_or(match std::env::var("MISTRAL_API_KEY") { Ok(api_key_from_env) => api_key_from_env, Err(_) => return Err(ClientError::ApiKeyError), }); let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); let max_retries = max_retries.unwrap_or(5); let timeout = timeout.unwrap_or(120); Ok(Self { api_key, endpoint, max_retries, timeout, }) } pub fn chat( &self, model: Model, messages: Vec, options: Option, ) -> Result { let request = ChatCompletionRequest::new(model, messages, false, options); let response = self.post_sync("/chat/completions", &request)?; let result = response.json::(); match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } pub async fn chat_async( &self, model: Model, messages: Vec, options: Option, ) -> Result { let request = ChatCompletionRequest::new(model, messages, false, options); let response = self.post_async("/chat/completions", &request).await?; let result = response.json::().await; match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } pub async fn chat_stream( &self, model: Model, messages: Vec, options: Option, ) -> Result>, ApiError> { let request = ChatCompletionRequest::new(model, messages, true, options); let response = self .post_stream("/chat/completions", &request) .await .map_err(|e| ApiError { message: e.to_string(), })?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); return Err(ApiError { message: format!("{}: {}", status, text), }); } let deserialized_stream = response .bytes_stream() .map(|item| -> Result { match item { Ok(bytes) => { let text = String::from_utf8(bytes.to_vec()).map_err(|e| ApiError { message: e.to_string(), })?; let text_trimmed = text.trim_start_matches("data: "); from_str(&text_trimmed).map_err(|e| ApiError { message: e.to_string(), }) } Err(e) => Err(ApiError { message: e.to_string(), }), } }); Ok(deserialized_stream) } pub fn embeddings( &self, model: EmbedModel, input: Vec, options: Option, ) -> Result { let request = EmbeddingRequest::new(model, input, options); let response = self.post_sync("/embeddings", &request)?; let result = response.json::(); match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } pub async fn embeddings_async( &self, model: EmbedModel, input: Vec, options: Option, ) -> Result { let request = EmbeddingRequest::new(model, input, options); let response = self.post_async("/embeddings", &request).await?; let result = response.json::().await; match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } pub fn list_models(&self) -> Result { let response = self.get_sync("/models")?; let result = response.json::(); match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } pub async fn list_models_async(&self) -> Result { let response = self.get_async("/models").await?; let result = response.json::().await; match result { Ok(response) => Ok(response), Err(error) => Err(self.to_api_error(error)), } } fn build_request_sync( &self, request: reqwest::blocking::RequestBuilder, ) -> reqwest::blocking::RequestBuilder { let user_agent = format!( "ivangabriele/mistralai-client-rs/{}", env!("CARGO_PKG_VERSION") ); let request_builder = request .bearer_auth(&self.api_key) .header("Accept", "application/json") .header("Content-Type", "application/json") .header("User-Agent", user_agent); request_builder } fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { let user_agent = format!( "ivangabriele/mistralai-client-rs/{}", env!("CARGO_PKG_VERSION") ); let request_builder = request .bearer_auth(&self.api_key) .header("Accept", "application/json") .header("Content-Type", "application/json") .header("User-Agent", user_agent); request_builder } fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { let user_agent = format!( "ivangabriele/mistralai-client-rs/{}", env!("CARGO_PKG_VERSION") ); let request_builder = request .bearer_auth(&self.api_key) .header("Accept", "text/event-stream") .header("Content-Type", "application/json") .header("User-Agent", user_agent); request_builder } fn get_sync(&self, path: &str) -> Result { let reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); let request = self.build_request_sync(reqwest_client.get(url)); let result = request.send(); match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { let status = response.status(); let text = response.text().unwrap(); Err(ApiError { message: format!("{}: {}", status, text), }) } } Err(error) => Err(ApiError { message: error.to_string(), }), } } async fn get_async(&self, path: &str) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); let request_builder = reqwest_client.get(url); let request = self.build_request_async(request_builder); let result = request.send().await; match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { let status = response.status(); let text = response.text().await.unwrap_or_default(); Err(ApiError { message: format!("{}: {}", status, text), }) } } Err(error) => Err(ApiError { message: error.to_string(), }), } } fn post_sync( &self, path: &str, params: &T, ) -> Result { let reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_sync(request_builder); let result = request.send(); match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { let status = response.status(); let text = response.text().unwrap_or_default(); Err(ApiError { message: format!("{}: {}", status, text), }) } } Err(error) => Err(ApiError { message: error.to_string(), }), } } async fn post_async( &self, path: &str, params: &T, ) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_async(request_builder); let result = request.send().await; match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { let status = response.status(); let text = response.text().await.unwrap_or_default(); Err(ApiError { message: format!("{}: {}", status, text), }) } } Err(error) => Err(ApiError { message: error.to_string(), }), } } async fn post_stream( &self, path: &str, params: &T, ) -> Result { let reqwest_client = reqwest::Client::new(); let url = format!("{}{}", self.endpoint, path); let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_stream(request_builder); let result = request.send().await; match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { let status = response.status(); let text = response.text().await.unwrap_or_default(); Err(ApiError { message: format!("{}: {}", status, text), }) } } Err(error) => Err(ApiError { message: error.to_string(), }), } } fn to_api_error(&self, err: ReqwestError) -> ApiError { ApiError { message: err.to_string(), } } }