feat!: add missing api key error

BREAKING CHANGE: `APIError` is renamed to `ApiError`.
This commit is contained in:
Ivan Gabriele
2024-03-04 04:21:03 +01:00
parent b0a3f10c9f
commit 1deab88251
10 changed files with 64 additions and 47 deletions

View File

@@ -1,3 +1,3 @@
# This key is only used for development purposes. # This key is only used for development purposes.
# You'll only need one if you want to contribute to this library. # You'll only need one if you want to contribute to this library.
MISTRAL_API_KEY= export MISTRAL_API_KEY=

View File

@@ -27,8 +27,15 @@ Then run:
git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork
cd ./mistralai-client-rs cd ./mistralai-client-rs
cargo build cargo build
cp .env.example .env
``` ```
Then edit the `.env` file to set your `MISTRAL_API_KEY`.
> [!NOTE]
> All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens.
> So you would have to run them thousands of times to even reach a single dollar of usage.
### Optional requirements ### Optional requirements
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`. - [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`.
@@ -51,5 +58,4 @@ Help us keep this project open and inclusive. Please read and follow our [Code o
## Commit Message Format ## Commit Message Format
This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification and This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
specificaly the [Angular Commit Message Guidelines](https://github.com/angular/angular/blob/main/CONTRIBUTING.md#commit).

View File

@@ -19,8 +19,6 @@ minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] }
serde = { version = "1.0.197", features = ["derive"] } serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114" serde_json = "1.0.114"
thiserror = "1.0.57" thiserror = "1.0.57"
tokio = { version = "1.36.0", features = ["full"] }
[dev-dependencies] [dev-dependencies]
dotenv = "0.15.0"
jrest = "0.2.3" jrest = "0.2.3"

View File

@@ -1,7 +1,9 @@
SHELL := /bin/bash
.PHONY: test .PHONY: test
define RELEASE_TEMPLATE define RELEASE_TEMPLATE
conventional-changelog -p conventionalcommits -i CHANGELOG.md -s conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s
git add . git add .
git commit -m "docs(changelog): update" git commit -m "docs(changelog): update"
git push origin HEAD git push origin HEAD
@@ -9,13 +11,6 @@ define RELEASE_TEMPLATE
git push origin HEAD --tags git push origin HEAD --tags
endef endef
test:
cargo test --no-fail-fast
test-cover:
cargo tarpaulin --frozen --no-fail-fast --out Xml --skip-clean
test-watch:
cargo watch -x "test -- --nocapture"
release-patch: release-patch:
$(call RELEASE_TEMPLATE,patch) $(call RELEASE_TEMPLATE,patch)
@@ -24,3 +19,10 @@ release-minor:
release-major: release-major:
$(call RELEASE_TEMPLATE,major) $(call RELEASE_TEMPLATE,major)
test:
@source ./.env && cargo test --all-targets --no-fail-fast
test-cover:
cargo tarpaulin --all-targets --frozen --no-fail-fast --out Xml --skip-clean
test-watch:
cargo watch -x "test -- --all-targets --nocapture"

View File

@@ -1,4 +1,4 @@
use crate::v1::error::APIError; use crate::v1::error::ApiError;
use minreq::Response; use minreq::Response;
use crate::v1::{ use crate::v1::{
@@ -7,6 +7,7 @@ use crate::v1::{
}, },
constants::{EmbedModel, Model, API_URL_BASE}, constants::{EmbedModel, Model, API_URL_BASE},
embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse}, embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse},
error::ClientError,
model_list::ModelListResponse, model_list::ModelListResponse,
}; };
@@ -24,7 +25,10 @@ impl Client {
max_retries: Option<u32>, max_retries: Option<u32>,
timeout: Option<u32>, timeout: Option<u32>,
) -> Self { ) -> Self {
let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap()); let api_key = api_key.unwrap_or(
std::env::var("MISTRAL_API_KEY")
.unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)),
);
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
let max_retries = max_retries.unwrap_or(5); let max_retries = max_retries.unwrap_or(5);
let timeout = timeout.unwrap_or(120); let timeout = timeout.unwrap_or(120);
@@ -53,7 +57,7 @@ impl Client {
request request
} }
pub fn get(&self, path: &str) -> Result<Response, APIError> { pub fn get(&self, path: &str) -> Result<Response, ApiError> {
let url = format!("{}{}", self.endpoint, path); let url = format!("{}{}", self.endpoint, path);
let request = self.build_request(minreq::get(url)); let request = self.build_request(minreq::get(url));
@@ -65,7 +69,7 @@ impl Client {
if (200..=299).contains(&response.status_code) { if (200..=299).contains(&response.status_code) {
Ok(response) Ok(response)
} else { } else {
Err(APIError { Err(ApiError {
message: format!( message: format!(
"{}: {}", "{}: {}",
response.status_code, response.status_code,
@@ -82,7 +86,7 @@ impl Client {
&self, &self,
path: &str, path: &str,
params: &T, params: &T,
) -> Result<Response, APIError> { ) -> Result<Response, ApiError> {
// print!("{:?}", params); // print!("{:?}", params);
let url = format!("{}{}", self.endpoint, path); let url = format!("{}{}", self.endpoint, path);
@@ -96,7 +100,7 @@ impl Client {
if (200..=299).contains(&response.status_code) { if (200..=299).contains(&response.status_code) {
Ok(response) Ok(response)
} else { } else {
Err(APIError { Err(ApiError {
message: format!( message: format!(
"{}: {}", "{}: {}",
response.status_code, response.status_code,
@@ -114,7 +118,7 @@ impl Client {
model: Model, model: Model,
messages: Vec<ChatCompletionMessage>, messages: Vec<ChatCompletionMessage>,
options: Option<ChatCompletionParams>, options: Option<ChatCompletionParams>,
) -> Result<ChatCompletionResponse, APIError> { ) -> Result<ChatCompletionResponse, ApiError> {
let request = ChatCompletionRequest::new(model, messages, options); let request = ChatCompletionRequest::new(model, messages, options);
let response = self.post("/chat/completions", &request)?; let response = self.post("/chat/completions", &request)?;
@@ -130,7 +134,7 @@ impl Client {
model: EmbedModel, model: EmbedModel,
input: Vec<String>, input: Vec<String>,
options: Option<EmbeddingRequestOptions>, options: Option<EmbeddingRequestOptions>,
) -> Result<EmbeddingResponse, APIError> { ) -> Result<EmbeddingResponse, ApiError> {
let request = EmbeddingRequest::new(model, input, options); let request = EmbeddingRequest::new(model, input, options);
let response = self.post("/embeddings", &request)?; let response = self.post("/embeddings", &request)?;
@@ -141,7 +145,7 @@ impl Client {
} }
} }
pub fn list_models(&self) -> Result<ModelListResponse, APIError> { pub fn list_models(&self) -> Result<ModelListResponse, ApiError> {
let response = self.get("/models")?; let response = self.get("/models")?;
let result = response.json::<ModelListResponse>(); let result = response.json::<ModelListResponse>();
match result { match result {
@@ -150,8 +154,8 @@ impl Client {
} }
} }
fn new_error(&self, err: minreq::Error) -> APIError { fn new_error(&self, err: minreq::Error) -> ApiError {
APIError { ApiError {
message: err.to_string(), message: err.to_string(),
} }
} }

View File

@@ -2,14 +2,18 @@ use std::error::Error;
use std::fmt; use std::fmt;
#[derive(Debug)] #[derive(Debug)]
pub struct APIError { pub struct ApiError {
pub message: String, pub message: String,
} }
impl fmt::Display for ApiError {
impl fmt::Display for APIError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "APIError: {}", self.message) write!(f, "ApiError: {}", self.message)
} }
} }
impl Error for ApiError {}
impl Error for APIError {} #[derive(Debug, thiserror::Error)]
pub enum ClientError {
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
ApiKeyError,
}

View File

@@ -6,12 +6,7 @@ use mistralai_client::v1::{
}; };
#[test] #[test]
fn test_chat_completion() { fn test_client_chat() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
let client = Client::new(None, None, None, None); let client = Client::new(None, None, None, None);
let model = Model::OpenMistral7b; let model = Model::OpenMistral7b;

View File

@@ -2,12 +2,7 @@ use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel}; use mistralai_client::v1::{client::Client, constants::EmbedModel};
#[test] #[test]
fn test_embeddings() { fn test_client_embeddings() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
let client: Client = Client::new(None, None, None, None); let client: Client = Client::new(None, None, None, None);
let model = EmbedModel::MistralEmbed; let model = EmbedModel::MistralEmbed;

View File

@@ -2,12 +2,7 @@ use jrest::expect;
use mistralai_client::v1::client::Client; use mistralai_client::v1::client::Client;
#[test] #[test]
fn test_list_models() { fn test_client_list_models() {
extern crate dotenv;
use dotenv::dotenv;
dotenv().ok();
let client = Client::new(None, None, None, None); let client = Client::new(None, None, None, None);
let response = client.list_models().unwrap(); let response = client.list_models().unwrap();

View File

@@ -4,6 +4,7 @@ use mistralai_client::v1::client::Client;
#[test] #[test]
fn test_client_new_with_none_params() { fn test_client_new_with_none_params() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
let client = Client::new(None, None, None, None); let client = Client::new(None, None, None, None);
@@ -24,6 +25,7 @@ fn test_client_new_with_none_params() {
#[test] #[test]
fn test_client_new_with_all_params() { fn test_client_new_with_all_params() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
let api_key = Some("test_api_key_from_param".to_string()); let api_key = Some("test_api_key_from_param".to_string());
@@ -50,3 +52,19 @@ fn test_client_new_with_all_params() {
None => std::env::remove_var("MISTRAL_API_KEY"), None => std::env::remove_var("MISTRAL_API_KEY"),
} }
} }
#[test]
#[should_panic]
fn test_client_new_with_missing_api_key() {
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
std::env::remove_var("MISTRAL_API_KEY");
let _client = Client::new(None, None, None, None);
match maybe_original_mistral_api_key {
Some(original_mistral_api_key) => {
std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key)
}
None => std::env::remove_var("MISTRAL_API_KEY"),
}
}