diff --git a/Cargo.toml b/Cargo.toml index 7eff3fa..0a4cedf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ futures = "0.3.30" reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" +strum = "0.26.1" +strum_macros = "0.26.1" thiserror = "1.0.57" tokio = { version = "1.36.0", features = ["full"] } diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 52b3cba..701764b 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -11,7 +11,7 @@ pub struct ChatMessage { pub content: String, } -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, strum_macros::Display, Eq, PartialEq, Deserialize, Serialize)] #[allow(non_camel_case_types)] pub enum ChatMessageRole { assistant, diff --git a/src/v1/client.rs b/src/v1/client.rs index 99d5c64..b4ec8c1 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -25,16 +25,39 @@ pub struct Client { } impl Client { + /// Constructs a new `Client`. + /// + /// # Arguments + /// + /// * `api_key` - An optional API key. + /// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable. + /// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided. + /// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`. + /// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`. + /// + /// # Examples + /// + /// ``` + /// use mistralai_client::v1::client::Client; + /// + /// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60)); + /// assert!(client.is_ok()); + /// ``` + /// + /// # Errors + /// + /// This method fails whenever neither the `api_key` is provided + /// nor the `MISTRAL_API_KEY` environment variable is set. pub fn new( api_key: Option, endpoint: Option, max_retries: Option, timeout: Option, ) -> Result { - let api_key = api_key.unwrap_or(match std::env::var("MISTRAL_API_KEY") { - Ok(api_key_from_env) => api_key_from_env, - Err(_) => return Err(ClientError::ApiKeyError), - }); + let api_key = match api_key { + Some(api_key_from_param) => api_key_from_param, + None => std::env::var("MISTRAL_API_KEY").map_err(|_| ClientError::MissingApiKey)?, + }; let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); let max_retries = max_retries.unwrap_or(5); let timeout = timeout.unwrap_or(120); @@ -63,6 +86,34 @@ impl Client { } } + /// Asynchronously sends a chat completion request and returns the response. + /// + /// # Arguments + /// + /// * `model` - The model to use for the chat completion. + /// * `messages` - A vector of `ChatMessage` to send as part of the chat. + /// * `options` - Optional `ChatCompletionParams` to customize the request. + /// + /// # Examples + /// + /// ``` + /// use mistralai_client::v1::{ + /// chat_completion::{ChatMessage, ChatMessageRole}, + /// client::Client, + /// constants::Model, + /// }; + /// + /// #[tokio::main] + /// async fn main() { + /// let client = Client::new(None, None, None, None).unwrap(); + /// let messages = vec![ChatMessage { + /// role: ChatMessageRole::user, + /// content: "Hello, world!".to_string(), + /// }]; + /// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap(); + /// println!("{}: {}", response.choices[0].message.role, response.choices[0].message.content); + /// } + /// ``` pub async fn chat_async( &self, model: Model, diff --git a/src/v1/error.rs b/src/v1/error.rs index 73d219b..ef42391 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -15,7 +15,7 @@ impl Error for ApiError {} #[derive(Debug, PartialEq, thiserror::Error)] pub enum ClientError { #[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")] - ApiKeyError, + MissingApiKey, #[error("Failed to read the response text.")] - ReadResponseTextError, + UnreadableResponseText, } diff --git a/tests/v1_client_new_test.rs b/tests/v1_client_new_test.rs index ca69e41..a7a43a7 100644 --- a/tests/v1_client_new_test.rs +++ b/tests/v1_client_new_test.rs @@ -26,6 +26,37 @@ fn test_client_new_with_none_params() { fn test_client_new_with_all_params() { let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); std::env::remove_var("MISTRAL_API_KEY"); + + let api_key = Some("test_api_key_from_param".to_string()); + let endpoint = Some("https://example.org".to_string()); + let max_retries = Some(10); + let timeout = Some(20); + + let client = Client::new( + api_key.clone(), + endpoint.clone(), + max_retries.clone(), + timeout.clone(), + ) + .unwrap(); + + expect!(client.api_key).to_be(api_key.unwrap()); + expect!(client.endpoint).to_be(endpoint.unwrap()); + expect!(client.max_retries).to_be(max_retries.unwrap()); + expect!(client.timeout).to_be(timeout.unwrap()); + + match maybe_original_mistral_api_key { + Some(original_mistral_api_key) => { + std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) + } + None => std::env::remove_var("MISTRAL_API_KEY"), + } +} + +#[test] +fn test_client_new_with_api_key_as_both_env_and_param() { + let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::remove_var("MISTRAL_API_KEY"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); let api_key = Some("test_api_key_from_param".to_string()); @@ -62,8 +93,8 @@ fn test_client_new_with_missing_api_key() { let call = || Client::new(None, None, None, None); match call() { - Ok(_) => panic!("Expected `ClientError::ApiKeyError` but got Ok.`"), - Err(error) => assert_eq!(error, ClientError::ApiKeyError), + Ok(_) => panic!("Expected `ClientError::MissingApiKey` but got Ok.`"), + Err(error) => assert_eq!(error, ClientError::MissingApiKey), } match maybe_original_mistral_api_key {