feat!: add missing api key error

BREAKING CHANGE: `APIError` is renamed to `ApiError`.
This commit is contained in:
Ivan Gabriele
2024-03-04 04:21:03 +01:00
parent b0a3f10c9f
commit 1deab88251
10 changed files with 64 additions and 47 deletions

View File

@@ -1,3 +1,3 @@
# This key is only used for development purposes.
# You'll only need one if you want to contribute to this library.
MISTRAL_API_KEY=
export MISTRAL_API_KEY=

View File

@@ -27,8 +27,15 @@ Then run:
git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork
cd ./mistralai-client-rs
cargo build
cp .env.example .env
```
Then edit the `.env` file to set your `MISTRAL_API_KEY`.
> [!NOTE]
> All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens.
> So you would have to run them thousands of times to even reach a single dollar of usage.
### Optional requirements
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`.
@@ -51,5 +58,4 @@ Help us keep this project open and inclusive. Please read and follow our [Code o
## Commit Message Format
This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification and
specificaly the [Angular Commit Message Guidelines](https://github.com/angular/angular/blob/main/CONTRIBUTING.md#commit).
This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.

View File

@@ -19,8 +19,6 @@ minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
thiserror = "1.0.57"
tokio = { version = "1.36.0", features = ["full"] }
[dev-dependencies]
dotenv = "0.15.0"
jrest = "0.2.3"

View File

@@ -1,7 +1,9 @@
SHELL := /bin/bash
.PHONY: test
define RELEASE_TEMPLATE
conventional-changelog -p conventionalcommits -i CHANGELOG.md -s
conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s
git add .
git commit -m "docs(changelog): update"
git push origin HEAD
@@ -9,13 +11,6 @@ define RELEASE_TEMPLATE
git push origin HEAD --tags
endef
test:
cargo test --no-fail-fast
test-cover:
cargo tarpaulin --frozen --no-fail-fast --out Xml --skip-clean
test-watch:
cargo watch -x "test -- --nocapture"
release-patch:
$(call RELEASE_TEMPLATE,patch)
@@ -24,3 +19,10 @@ release-minor:
release-major:
$(call RELEASE_TEMPLATE,major)
test:
@source ./.env && cargo test --all-targets --no-fail-fast
test-cover:
cargo tarpaulin --all-targets --frozen --no-fail-fast --out Xml --skip-clean
test-watch:
cargo watch -x "test -- --all-targets --nocapture"

View File

@@ -1,4 +1,4 @@
use crate::v1::error::APIError;
use crate::v1::error::ApiError;
use minreq::Response;
use crate::v1::{
@@ -7,6 +7,7 @@ use crate::v1::{
},
constants::{EmbedModel, Model, API_URL_BASE},
embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse},
error::ClientError,
model_list::ModelListResponse,
};
@@ -24,7 +25,10 @@ impl Client {
max_retries: Option<u32>,
timeout: Option<u32>,
) -> Self {
let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap());
let api_key = api_key.unwrap_or(
std::env::var("MISTRAL_API_KEY")
.unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)),
);
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
let max_retries = max_retries.unwrap_or(5);
let timeout = timeout.unwrap_or(120);
@@ -53,7 +57,7 @@ impl Client {
request
}
pub fn get(&self, path: &str) -> Result<Response, APIError> {
pub fn get(&self, path: &str) -> Result<Response, ApiError> {
let url = format!("{}{}", self.endpoint, path);
let request = self.build_request(minreq::get(url));
@@ -65,7 +69,7 @@ impl Client {
if (200..=299).contains(&response.status_code) {
Ok(response)
} else {
Err(APIError {
Err(ApiError {
message: format!(
"{}: {}",
response.status_code,
@@ -82,7 +86,7 @@ impl Client {
&self,
path: &str,
params: &T,
) -> Result<Response, APIError> {
) -> Result<Response, ApiError> {
// print!("{:?}", params);
let url = format!("{}{}", self.endpoint, path);
@@ -96,7 +100,7 @@ impl Client {
if (200..=299).contains(&response.status_code) {
Ok(response)
} else {
Err(APIError {
Err(ApiError {
message: format!(
"{}: {}",
response.status_code,
@@ -114,7 +118,7 @@ impl Client {
model: Model,
messages: Vec<ChatCompletionMessage>,
options: Option<ChatCompletionParams>,
) -> Result<ChatCompletionResponse, APIError> {
) -> Result<ChatCompletionResponse, ApiError> {
let request = ChatCompletionRequest::new(model, messages, options);
let response = self.post("/chat/completions", &request)?;
@@ -130,7 +134,7 @@ impl Client {
model: EmbedModel,
input: Vec<String>,
options: Option<EmbeddingRequestOptions>,
) -> Result<EmbeddingResponse, APIError> {
) -> Result<EmbeddingResponse, ApiError> {
let request = EmbeddingRequest::new(model, input, options);
let response = self.post("/embeddings", &request)?;
@@ -141,7 +145,7 @@ impl Client {
}
}
pub fn list_models(&self) -> Result<ModelListResponse, APIError> {
pub fn list_models(&self) -> Result<ModelListResponse, ApiError> {
let response = self.get("/models")?;
let result = response.json::<ModelListResponse>();
match result {
@@ -150,8 +154,8 @@ impl Client {
}
}
fn new_error(&self, err: minreq::Error) -> APIError {
APIError {
fn new_error(&self, err: minreq::Error) -> ApiError {
ApiError {
message: err.to_string(),
}
}

View File

@@ -2,14 +2,18 @@ use std::error::Error;
use std::fmt;
#[derive(Debug)]
pub struct APIError {
pub struct ApiError {
pub message: String,
}
impl fmt::Display for APIError {
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "APIError: {}", self.message)
write!(f, "ApiError: {}", self.message)
}
}
impl Error for ApiError {}
impl Error for APIError {}
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
ApiKeyError,
}

View File

@@ -6,12 +6,7 @@ use mistralai_client::v1::{
};
#[test]
fn test_chat_completion() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
fn test_client_chat() {
let client = Client::new(None, None, None, None);
let model = Model::OpenMistral7b;

View File

@@ -2,12 +2,7 @@ use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel};
#[test]
fn test_embeddings() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
fn test_client_embeddings() {
let client: Client = Client::new(None, None, None, None);
let model = EmbedModel::MistralEmbed;

View File

@@ -2,12 +2,7 @@ use jrest::expect;
use mistralai_client::v1::client::Client;
#[test]
fn test_list_models() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
fn test_client_list_models() {
let client = Client::new(None, None, None, None);
let response = client.list_models().unwrap();

View File

@@ -4,6 +4,7 @@ use mistralai_client::v1::client::Client;
#[test]
fn test_client_new_with_none_params() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
let client = Client::new(None, None, None, None);
@@ -24,6 +25,7 @@ fn test_client_new_with_none_params() {
#[test]
fn test_client_new_with_all_params() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
let api_key = Some("test_api_key_from_param".to_string());
@@ -50,3 +52,19 @@ fn test_client_new_with_all_params() {
None => std::env::remove_var("MISTRAL_API_KEY"),
}
}
#[test]
#[should_panic]
fn test_client_new_with_missing_api_key() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
let _client = Client::new(None, None, None, None);
match maybe_original_mistral_api_key {
Some(original_mistral_api_key) => {
std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key)
}
None => std::env::remove_var("MISTRAL_API_KEY"),
}
}