feat!: add missing api key error

BREAKING CHANGE: `APIError` is renamed to `ApiError`.
This commit is contained in:
Ivan Gabriele
2024-03-04 04:21:03 +01:00
parent b0a3f10c9f
commit 1deab88251
10 changed files with 64 additions and 47 deletions

View File

@@ -1,4 +1,4 @@
use crate::v1::error::APIError;
use crate::v1::error::ApiError;
use minreq::Response;
use crate::v1::{
@@ -7,6 +7,7 @@ use crate::v1::{
},
constants::{EmbedModel, Model, API_URL_BASE},
embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse},
error::ClientError,
model_list::ModelListResponse,
};
@@ -24,7 +25,10 @@ impl Client {
max_retries: Option<u32>,
timeout: Option<u32>,
) -> Self {
let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap());
let api_key = api_key.unwrap_or(
std::env::var("MISTRAL_API_KEY")
.unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)),
);
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
let max_retries = max_retries.unwrap_or(5);
let timeout = timeout.unwrap_or(120);
@@ -53,7 +57,7 @@ impl Client {
request
}
pub fn get(&self, path: &str) -> Result<Response, APIError> {
pub fn get(&self, path: &str) -> Result<Response, ApiError> {
let url = format!("{}{}", self.endpoint, path);
let request = self.build_request(minreq::get(url));
@@ -65,7 +69,7 @@ impl Client {
if (200..=299).contains(&response.status_code) {
Ok(response)
} else {
Err(APIError {
Err(ApiError {
message: format!(
"{}: {}",
response.status_code,
@@ -82,7 +86,7 @@ impl Client {
&self,
path: &str,
params: &T,
) -> Result<Response, APIError> {
) -> Result<Response, ApiError> {
// print!("{:?}", params);
let url = format!("{}{}", self.endpoint, path);
@@ -96,7 +100,7 @@ impl Client {
if (200..=299).contains(&response.status_code) {
Ok(response)
} else {
Err(APIError {
Err(ApiError {
message: format!(
"{}: {}",
response.status_code,
@@ -114,7 +118,7 @@ impl Client {
model: Model,
messages: Vec<ChatCompletionMessage>,
options: Option<ChatCompletionParams>,
) -> Result<ChatCompletionResponse, APIError> {
) -> Result<ChatCompletionResponse, ApiError> {
let request = ChatCompletionRequest::new(model, messages, options);
let response = self.post("/chat/completions", &request)?;
@@ -130,7 +134,7 @@ impl Client {
model: EmbedModel,
input: Vec<String>,
options: Option<EmbeddingRequestOptions>,
) -> Result<EmbeddingResponse, APIError> {
) -> Result<EmbeddingResponse, ApiError> {
let request = EmbeddingRequest::new(model, input, options);
let response = self.post("/embeddings", &request)?;
@@ -141,7 +145,7 @@ impl Client {
}
}
pub fn list_models(&self) -> Result<ModelListResponse, APIError> {
pub fn list_models(&self) -> Result<ModelListResponse, ApiError> {
let response = self.get("/models")?;
let result = response.json::<ModelListResponse>();
match result {
@@ -150,8 +154,8 @@ impl Client {
}
}
fn new_error(&self, err: minreq::Error) -> APIError {
APIError {
fn new_error(&self, err: minreq::Error) -> ApiError {
ApiError {
message: err.to_string(),
}
}

View File

@@ -2,14 +2,18 @@ use std::error::Error;
use std::fmt;
#[derive(Debug)]
pub struct APIError {
pub struct ApiError {
pub message: String,
}
impl fmt::Display for APIError {
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "APIError: {}", self.message)
write!(f, "ApiError: {}", self.message)
}
}
impl Error for ApiError {}
impl Error for APIError {}
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
ApiKeyError,
}