diff --git a/.env.example b/.env.example index 385fddf..ad6f95e 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,3 @@ # This key is only used for development purposes. # You'll only need one if you want to contribute to this library. -MISTRAL_API_KEY= +export MISTRAL_API_KEY= diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f5d34d9..22b0ea6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,8 +27,15 @@ Then run: git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork cd ./mistralai-client-rs cargo build +cp .env.example .env ``` +Then edit the `.env` file to set your `MISTRAL_API_KEY`. + +> [!NOTE] +> All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens. +> So you would have to run them thousands of times to even reach a single dollar of usage. + ### Optional requirements - [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`. @@ -51,5 +58,4 @@ Help us keep this project open and inclusive. Please read and follow our [Code o ## Commit Message Format -This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification and -specificaly the [Angular Commit Message Guidelines](https://github.com/angular/angular/blob/main/CONTRIBUTING.md#commit). +This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification. diff --git a/Cargo.toml b/Cargo.toml index d428654..0a60c77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,6 @@ minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" thiserror = "1.0.57" -tokio = { version = "1.36.0", features = ["full"] } [dev-dependencies] -dotenv = "0.15.0" jrest = "0.2.3" diff --git a/Makefile b/Makefile index 07b1b2c..ed2623c 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,9 @@ +SHELL := /bin/bash + .PHONY: test define RELEASE_TEMPLATE - conventional-changelog -p conventionalcommits -i CHANGELOG.md -s + conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s git add . git commit -m "docs(changelog): update" git push origin HEAD @@ -9,13 +11,6 @@ define RELEASE_TEMPLATE git push origin HEAD --tags endef -test: - cargo test --no-fail-fast -test-cover: - cargo tarpaulin --frozen --no-fail-fast --out Xml --skip-clean -test-watch: - cargo watch -x "test -- --nocapture" - release-patch: $(call RELEASE_TEMPLATE,patch) @@ -24,3 +19,10 @@ release-minor: release-major: $(call RELEASE_TEMPLATE,major) + +test: + @source ./.env && cargo test --all-targets --no-fail-fast +test-cover: + cargo tarpaulin --all-targets --frozen --no-fail-fast --out Xml --skip-clean +test-watch: + cargo watch -x "test -- --all-targets --nocapture" diff --git a/src/v1/client.rs b/src/v1/client.rs index f454bd2..9248c2c 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -1,4 +1,4 @@ -use crate::v1::error::APIError; +use crate::v1::error::ApiError; use minreq::Response; use crate::v1::{ @@ -7,6 +7,7 @@ use crate::v1::{ }, constants::{EmbedModel, Model, API_URL_BASE}, embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse}, + error::ClientError, model_list::ModelListResponse, }; @@ -24,7 +25,10 @@ impl Client { max_retries: Option, timeout: Option, ) -> Self { - let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap()); + let api_key = api_key.unwrap_or( + std::env::var("MISTRAL_API_KEY") + .unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)), + ); let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); let max_retries = max_retries.unwrap_or(5); let timeout = timeout.unwrap_or(120); @@ -53,7 +57,7 @@ impl Client { request } - pub fn get(&self, path: &str) -> Result { + pub fn get(&self, path: &str) -> Result { let url = format!("{}{}", self.endpoint, path); let request = self.build_request(minreq::get(url)); @@ -65,7 +69,7 @@ impl Client { if (200..=299).contains(&response.status_code) { Ok(response) } else { - Err(APIError { + Err(ApiError { message: format!( "{}: {}", response.status_code, @@ -82,7 +86,7 @@ impl Client { &self, path: &str, params: &T, - ) -> Result { + ) -> Result { // print!("{:?}", params); let url = format!("{}{}", self.endpoint, path); @@ -96,7 +100,7 @@ impl Client { if (200..=299).contains(&response.status_code) { Ok(response) } else { - Err(APIError { + Err(ApiError { message: format!( "{}: {}", response.status_code, @@ -114,7 +118,7 @@ impl Client { model: Model, messages: Vec, options: Option, - ) -> Result { + ) -> Result { let request = ChatCompletionRequest::new(model, messages, options); let response = self.post("/chat/completions", &request)?; @@ -130,7 +134,7 @@ impl Client { model: EmbedModel, input: Vec, options: Option, - ) -> Result { + ) -> Result { let request = EmbeddingRequest::new(model, input, options); let response = self.post("/embeddings", &request)?; @@ -141,7 +145,7 @@ impl Client { } } - pub fn list_models(&self) -> Result { + pub fn list_models(&self) -> Result { let response = self.get("/models")?; let result = response.json::(); match result { @@ -150,8 +154,8 @@ impl Client { } } - fn new_error(&self, err: minreq::Error) -> APIError { - APIError { + fn new_error(&self, err: minreq::Error) -> ApiError { + ApiError { message: err.to_string(), } } diff --git a/src/v1/error.rs b/src/v1/error.rs index b2d2bf0..74f3893 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -2,14 +2,18 @@ use std::error::Error; use std::fmt; #[derive(Debug)] -pub struct APIError { +pub struct ApiError { pub message: String, } - -impl fmt::Display for APIError { +impl fmt::Display for ApiError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "APIError: {}", self.message) + write!(f, "ApiError: {}", self.message) } } +impl Error for ApiError {} -impl Error for APIError {} +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")] + ApiKeyError, +} diff --git a/tests/v1_chat_completion_test.rs b/tests/v1_client_chat_test.rs similarity index 93% rename from tests/v1_chat_completion_test.rs rename to tests/v1_client_chat_test.rs index 60133f6..d7cc9b6 100644 --- a/tests/v1_chat_completion_test.rs +++ b/tests/v1_client_chat_test.rs @@ -6,12 +6,7 @@ use mistralai_client::v1::{ }; #[test] -fn test_chat_completion() { - extern crate dotenv; - - use dotenv::dotenv; - dotenv().ok(); - +fn test_client_chat() { let client = Client::new(None, None, None, None); let model = Model::OpenMistral7b; diff --git a/tests/v1_embeddings_test.rs b/tests/v1_client_embeddings_test.rs similarity index 91% rename from tests/v1_embeddings_test.rs rename to tests/v1_client_embeddings_test.rs index a0be427..2c21ba6 100644 --- a/tests/v1_embeddings_test.rs +++ b/tests/v1_client_embeddings_test.rs @@ -2,12 +2,7 @@ use jrest::expect; use mistralai_client::v1::{client::Client, constants::EmbedModel}; #[test] -fn test_embeddings() { - extern crate dotenv; - - use dotenv::dotenv; - dotenv().ok(); - +fn test_client_embeddings() { let client: Client = Client::new(None, None, None, None); let model = EmbedModel::MistralEmbed; diff --git a/tests/v1_list_models_test.rs b/tests/v1_client_list_models_test.rs similarity index 85% rename from tests/v1_list_models_test.rs rename to tests/v1_client_list_models_test.rs index b2ea09e..299866a 100644 --- a/tests/v1_list_models_test.rs +++ b/tests/v1_client_list_models_test.rs @@ -2,12 +2,7 @@ use jrest::expect; use mistralai_client::v1::client::Client; #[test] -fn test_list_models() { - extern crate dotenv; - - use dotenv::dotenv; - dotenv().ok(); - +fn test_client_list_models() { let client = Client::new(None, None, None, None); let response = client.list_models().unwrap(); diff --git a/tests/v1_client_test.rs b/tests/v1_client_new_test.rs similarity index 75% rename from tests/v1_client_test.rs rename to tests/v1_client_new_test.rs index 91a0c39..e461447 100644 --- a/tests/v1_client_test.rs +++ b/tests/v1_client_new_test.rs @@ -4,6 +4,7 @@ use mistralai_client::v1::client::Client; #[test] fn test_client_new_with_none_params() { let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::remove_var("MISTRAL_API_KEY"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); let client = Client::new(None, None, None, None); @@ -24,6 +25,7 @@ fn test_client_new_with_none_params() { #[test] fn test_client_new_with_all_params() { let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::remove_var("MISTRAL_API_KEY"); std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); let api_key = Some("test_api_key_from_param".to_string()); @@ -50,3 +52,19 @@ fn test_client_new_with_all_params() { None => std::env::remove_var("MISTRAL_API_KEY"), } } + +#[test] +#[should_panic] +fn test_client_new_with_missing_api_key() { + let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::remove_var("MISTRAL_API_KEY"); + + let _client = Client::new(None, None, None, None); + + match maybe_original_mistral_api_key { + Some(original_mistral_api_key) => { + std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) + } + None => std::env::remove_var("MISTRAL_API_KEY"), + } +}