diff --git a/README.md b/README.md index 8b6edbd..47dcdb9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ use mistralai_client::v1::client::Client; fn main() { let api_key = "your_api_key"; - let client = Client::new(Some(api_key), None, None, None); + let client = Client::new(Some(api_key), None, None, None).unwrap(); } ``` @@ -71,7 +71,7 @@ use mistralai_client::v1::{ fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. - let client = Client::new(None, None, None, None); + let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; let messages = vec![ChatCompletionMessage { @@ -101,7 +101,7 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel}; fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. - let client: Client = Client::new(None, None, None, None); + let client: Client = Client::new(None, None, None, None).unwrap(); let model = EmbedModel::MistralEmbed; let input = vec!["Embed this sentence.", "As well as this one."] @@ -123,7 +123,7 @@ use mistralai_client::v1::client::Client; fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. - let client = Client::new(None, None, None, None); + let client = Client::new(None, None, None, None).unwrap(); let result = client.list_models().unwrap(); println!("First Model ID: {:?}", result.data[0].id); diff --git a/src/v1/client.rs b/src/v1/client.rs index 9248c2c..7f32f2f 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -24,21 +24,21 @@ impl Client { endpoint: Option, max_retries: Option, timeout: Option, - ) -> Self { - let api_key = api_key.unwrap_or( - std::env::var("MISTRAL_API_KEY") - .unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)), - ); + ) -> Result { + let api_key = api_key.unwrap_or(match std::env::var("MISTRAL_API_KEY") { + Ok(api_key_from_env) => api_key_from_env, + Err(_) => return Err(ClientError::ApiKeyError), + }); let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); let max_retries = max_retries.unwrap_or(5); let timeout = timeout.unwrap_or(120); - Self { + Ok(Self { api_key, endpoint, max_retries, timeout, - } + }) } pub fn build_request(&self, request: minreq::Request) -> minreq::Request { diff --git a/src/v1/error.rs b/src/v1/error.rs index 74f3893..c5f12cf 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -12,7 +12,7 @@ impl fmt::Display for ApiError { } impl Error for ApiError {} -#[derive(Debug, thiserror::Error)] +#[derive(Debug, PartialEq, thiserror::Error)] pub enum ClientError { #[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")] ApiKeyError, diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index d7cc9b6..6ae3909 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -7,7 +7,7 @@ use mistralai_client::v1::{ #[test] fn test_client_chat() { - let client = Client::new(None, None, None, None); + let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; let messages = vec![ChatCompletionMessage { diff --git a/tests/v1_client_embeddings_test.rs b/tests/v1_client_embeddings_test.rs index 2c21ba6..bb32fa4 100644 --- a/tests/v1_client_embeddings_test.rs +++ b/tests/v1_client_embeddings_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel}; #[test] fn test_client_embeddings() { - let client: Client = Client::new(None, None, None, None); + let client: Client = Client::new(None, None, None, None).unwrap(); let model = EmbedModel::MistralEmbed; let input = vec!["Embed this sentence.", "As well as this one."] diff --git a/tests/v1_client_list_models_test.rs b/tests/v1_client_list_models_test.rs index 299866a..6a6e8ef 100644 --- a/tests/v1_client_list_models_test.rs +++ b/tests/v1_client_list_models_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::client::Client; #[test] fn test_client_list_models() { - let client = Client::new(None, None, None, None); + let client = Client::new(None, None, None, None).unwrap(); let response = client.list_models().unwrap(); diff --git a/tests/v1_client_new_test.rs b/tests/v1_client_new_test.rs index e461447..ca69e41 100644 --- a/tests/v1_client_new_test.rs +++ b/tests/v1_client_new_test.rs @@ -1,5 +1,5 @@ use jrest::expect; -use mistralai_client::v1::client::Client; +use mistralai_client::v1::{client::Client, error::ClientError}; #[test] fn test_client_new_with_none_params() { @@ -7,7 +7,7 @@ fn test_client_new_with_none_params() { std::env::remove_var("MISTRAL_API_KEY"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); - let client = Client::new(None, None, None, None); + let client = Client::new(None, None, None, None).unwrap(); expect!(client.api_key).to_be("test_api_key_from_env".to_string()); expect!(client.endpoint).to_be("https://api.mistral.ai/v1".to_string()); @@ -38,7 +38,8 @@ fn test_client_new_with_all_params() { endpoint.clone(), max_retries.clone(), timeout.clone(), - ); + ) + .unwrap(); expect!(client.api_key).to_be(api_key.unwrap()); expect!(client.endpoint).to_be(endpoint.unwrap()); @@ -54,12 +55,16 @@ fn test_client_new_with_all_params() { } #[test] -#[should_panic] fn test_client_new_with_missing_api_key() { let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); std::env::remove_var("MISTRAL_API_KEY"); - let _client = Client::new(None, None, None, None); + let call = || Client::new(None, None, None, None); + + match call() { + Ok(_) => panic!("Expected `ClientError::ApiKeyError` but got Ok.`"), + Err(error) => assert_eq!(error, ClientError::ApiKeyError), + } match maybe_original_mistral_api_key { Some(original_mistral_api_key) => {