fix!: fix failure when api key as param and not env

BREAKING CHANGE:

- Rename `ClientError.ApiKeyError` to `MissingApiKey`.
- Rename `ClientError.ReadResponseTextError` to `ClientError.UnreadableResponseText`.
This commit is contained in:
Ivan Gabriele
2024-03-04 21:11:40 +01:00
parent 5217fcfb94
commit ef5d475e2d
5 changed files with 93 additions and 9 deletions

View File

@@ -19,6 +19,8 @@ futures = "0.3.30"
reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
strum = "0.26.1"
strum_macros = "0.26.1"
thiserror = "1.0.57"
tokio = { version = "1.36.0", features = ["full"] }

View File

@@ -11,7 +11,7 @@ pub struct ChatMessage {
pub content: String,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[derive(Clone, Debug, strum_macros::Display, Eq, PartialEq, Deserialize, Serialize)]
#[allow(non_camel_case_types)]
pub enum ChatMessageRole {
assistant,

View File

@@ -25,16 +25,39 @@ pub struct Client {
}
impl Client {
/// Constructs a new `Client`.
///
/// # Arguments
///
/// * `api_key` - An optional API key.
/// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
/// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
/// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
/// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::client::Client;
///
/// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
/// assert!(client.is_ok());
/// ```
///
/// # Errors
///
/// This method fails whenever neither the `api_key` is provided
/// nor the `MISTRAL_API_KEY` environment variable is set.
pub fn new(
api_key: Option<String>,
endpoint: Option<String>,
max_retries: Option<u32>,
timeout: Option<u32>,
) -> Result<Self, ClientError> {
let api_key = api_key.unwrap_or(match std::env::var("MISTRAL_API_KEY") {
Ok(api_key_from_env) => api_key_from_env,
Err(_) => return Err(ClientError::ApiKeyError),
});
let api_key = match api_key {
Some(api_key_from_param) => api_key_from_param,
None => std::env::var("MISTRAL_API_KEY").map_err(|_| ClientError::MissingApiKey)?,
};
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
let max_retries = max_retries.unwrap_or(5);
let timeout = timeout.unwrap_or(120);
@@ -63,6 +86,34 @@ impl Client {
}
}
/// Asynchronously sends a chat completion request and returns the response.
///
/// # Arguments
///
/// * `model` - The model to use for the chat completion.
/// * `messages` - A vector of `ChatMessage` to send as part of the chat.
/// * `options` - Optional `ChatCompletionParams` to customize the request.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::{
/// chat_completion::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
///
/// #[tokio::main]
/// async fn main() {
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::user,
/// content: "Hello, world!".to_string(),
/// }];
/// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
/// println!("{}: {}", response.choices[0].message.role, response.choices[0].message.content);
/// }
/// ```
pub async fn chat_async(
&self,
model: Model,

View File

@@ -15,7 +15,7 @@ impl Error for ApiError {}
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum ClientError {
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
ApiKeyError,
MissingApiKey,
#[error("Failed to read the response text.")]
ReadResponseTextError,
UnreadableResponseText,
}

View File

@@ -26,6 +26,37 @@ fn test_client_new_with_none_params() {
fn test_client_new_with_all_params() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
let api_key = Some("test_api_key_from_param".to_string());
let endpoint = Some("https://example.org".to_string());
let max_retries = Some(10);
let timeout = Some(20);
let client = Client::new(
api_key.clone(),
endpoint.clone(),
max_retries.clone(),
timeout.clone(),
)
.unwrap();
expect!(client.api_key).to_be(api_key.unwrap());
expect!(client.endpoint).to_be(endpoint.unwrap());
expect!(client.max_retries).to_be(max_retries.unwrap());
expect!(client.timeout).to_be(timeout.unwrap());
match maybe_original_mistral_api_key {
Some(original_mistral_api_key) => {
std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key)
}
None => std::env::remove_var("MISTRAL_API_KEY"),
}
}
#[test]
fn test_client_new_with_api_key_as_both_env_and_param() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
let api_key = Some("test_api_key_from_param".to_string());
@@ -62,8 +93,8 @@ fn test_client_new_with_missing_api_key() {
let call = || Client::new(None, None, None, None);
match call() {
Ok(_) => panic!("Expected `ClientError::ApiKeyError` but got Ok.`"),
Err(error) => assert_eq!(error, ClientError::ApiKeyError),
Ok(_) => panic!("Expected `ClientError::MissingApiKey` but got Ok.`"),
Err(error) => assert_eq!(error, ClientError::MissingApiKey),
}
match maybe_original_mistral_api_key {