feat!: add missing api key error
BREAKING CHANGE: `APIError` is renamed to `ApiError`.
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
# This key is only used for development purposes.
|
# This key is only used for development purposes.
|
||||||
# You'll only need one if you want to contribute to this library.
|
# You'll only need one if you want to contribute to this library.
|
||||||
MISTRAL_API_KEY=
|
export MISTRAL_API_KEY=
|
||||||
|
|||||||
@@ -27,8 +27,15 @@ Then run:
|
|||||||
git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork
|
git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork
|
||||||
cd ./mistralai-client-rs
|
cd ./mistralai-client-rs
|
||||||
cargo build
|
cargo build
|
||||||
|
cp .env.example .env
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Then edit the `.env` file to set your `MISTRAL_API_KEY`.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens.
|
||||||
|
> So you would have to run them thousands of times to even reach a single dollar of usage.
|
||||||
|
|
||||||
### Optional requirements
|
### Optional requirements
|
||||||
|
|
||||||
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`.
|
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`.
|
||||||
@@ -51,5 +58,4 @@ Help us keep this project open and inclusive. Please read and follow our [Code o
|
|||||||
|
|
||||||
## Commit Message Format
|
## Commit Message Format
|
||||||
|
|
||||||
This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification and
|
This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
|
||||||
specificaly the [Angular Commit Message Guidelines](https://github.com/angular/angular/blob/main/CONTRIBUTING.md#commit).
|
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] }
|
|||||||
serde = { version = "1.0.197", features = ["derive"] }
|
serde = { version = "1.0.197", features = ["derive"] }
|
||||||
serde_json = "1.0.114"
|
serde_json = "1.0.114"
|
||||||
thiserror = "1.0.57"
|
thiserror = "1.0.57"
|
||||||
tokio = { version = "1.36.0", features = ["full"] }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenv = "0.15.0"
|
|
||||||
jrest = "0.2.3"
|
jrest = "0.2.3"
|
||||||
|
|||||||
18
Makefile
18
Makefile
@@ -1,7 +1,9 @@
|
|||||||
|
SHELL := /bin/bash
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
|
|
||||||
define RELEASE_TEMPLATE
|
define RELEASE_TEMPLATE
|
||||||
conventional-changelog -p conventionalcommits -i CHANGELOG.md -s
|
conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s
|
||||||
git add .
|
git add .
|
||||||
git commit -m "docs(changelog): update"
|
git commit -m "docs(changelog): update"
|
||||||
git push origin HEAD
|
git push origin HEAD
|
||||||
@@ -9,13 +11,6 @@ define RELEASE_TEMPLATE
|
|||||||
git push origin HEAD --tags
|
git push origin HEAD --tags
|
||||||
endef
|
endef
|
||||||
|
|
||||||
test:
|
|
||||||
cargo test --no-fail-fast
|
|
||||||
test-cover:
|
|
||||||
cargo tarpaulin --frozen --no-fail-fast --out Xml --skip-clean
|
|
||||||
test-watch:
|
|
||||||
cargo watch -x "test -- --nocapture"
|
|
||||||
|
|
||||||
release-patch:
|
release-patch:
|
||||||
$(call RELEASE_TEMPLATE,patch)
|
$(call RELEASE_TEMPLATE,patch)
|
||||||
|
|
||||||
@@ -24,3 +19,10 @@ release-minor:
|
|||||||
|
|
||||||
release-major:
|
release-major:
|
||||||
$(call RELEASE_TEMPLATE,major)
|
$(call RELEASE_TEMPLATE,major)
|
||||||
|
|
||||||
|
test:
|
||||||
|
@source ./.env && cargo test --all-targets --no-fail-fast
|
||||||
|
test-cover:
|
||||||
|
cargo tarpaulin --all-targets --frozen --no-fail-fast --out Xml --skip-clean
|
||||||
|
test-watch:
|
||||||
|
cargo watch -x "test -- --all-targets --nocapture"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use crate::v1::error::APIError;
|
use crate::v1::error::ApiError;
|
||||||
use minreq::Response;
|
use minreq::Response;
|
||||||
|
|
||||||
use crate::v1::{
|
use crate::v1::{
|
||||||
@@ -7,6 +7,7 @@ use crate::v1::{
|
|||||||
},
|
},
|
||||||
constants::{EmbedModel, Model, API_URL_BASE},
|
constants::{EmbedModel, Model, API_URL_BASE},
|
||||||
embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse},
|
embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse},
|
||||||
|
error::ClientError,
|
||||||
model_list::ModelListResponse,
|
model_list::ModelListResponse,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -24,7 +25,10 @@ impl Client {
|
|||||||
max_retries: Option<u32>,
|
max_retries: Option<u32>,
|
||||||
timeout: Option<u32>,
|
timeout: Option<u32>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap());
|
let api_key = api_key.unwrap_or(
|
||||||
|
std::env::var("MISTRAL_API_KEY")
|
||||||
|
.unwrap_or_else(|_| panic!("{}", ClientError::ApiKeyError)),
|
||||||
|
);
|
||||||
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
|
let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string());
|
||||||
let max_retries = max_retries.unwrap_or(5);
|
let max_retries = max_retries.unwrap_or(5);
|
||||||
let timeout = timeout.unwrap_or(120);
|
let timeout = timeout.unwrap_or(120);
|
||||||
@@ -53,7 +57,7 @@ impl Client {
|
|||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, path: &str) -> Result<Response, APIError> {
|
pub fn get(&self, path: &str) -> Result<Response, ApiError> {
|
||||||
let url = format!("{}{}", self.endpoint, path);
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
let request = self.build_request(minreq::get(url));
|
let request = self.build_request(minreq::get(url));
|
||||||
|
|
||||||
@@ -65,7 +69,7 @@ impl Client {
|
|||||||
if (200..=299).contains(&response.status_code) {
|
if (200..=299).contains(&response.status_code) {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
} else {
|
} else {
|
||||||
Err(APIError {
|
Err(ApiError {
|
||||||
message: format!(
|
message: format!(
|
||||||
"{}: {}",
|
"{}: {}",
|
||||||
response.status_code,
|
response.status_code,
|
||||||
@@ -82,7 +86,7 @@ impl Client {
|
|||||||
&self,
|
&self,
|
||||||
path: &str,
|
path: &str,
|
||||||
params: &T,
|
params: &T,
|
||||||
) -> Result<Response, APIError> {
|
) -> Result<Response, ApiError> {
|
||||||
// print!("{:?}", params);
|
// print!("{:?}", params);
|
||||||
|
|
||||||
let url = format!("{}{}", self.endpoint, path);
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
@@ -96,7 +100,7 @@ impl Client {
|
|||||||
if (200..=299).contains(&response.status_code) {
|
if (200..=299).contains(&response.status_code) {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
} else {
|
} else {
|
||||||
Err(APIError {
|
Err(ApiError {
|
||||||
message: format!(
|
message: format!(
|
||||||
"{}: {}",
|
"{}: {}",
|
||||||
response.status_code,
|
response.status_code,
|
||||||
@@ -114,7 +118,7 @@ impl Client {
|
|||||||
model: Model,
|
model: Model,
|
||||||
messages: Vec<ChatCompletionMessage>,
|
messages: Vec<ChatCompletionMessage>,
|
||||||
options: Option<ChatCompletionParams>,
|
options: Option<ChatCompletionParams>,
|
||||||
) -> Result<ChatCompletionResponse, APIError> {
|
) -> Result<ChatCompletionResponse, ApiError> {
|
||||||
let request = ChatCompletionRequest::new(model, messages, options);
|
let request = ChatCompletionRequest::new(model, messages, options);
|
||||||
|
|
||||||
let response = self.post("/chat/completions", &request)?;
|
let response = self.post("/chat/completions", &request)?;
|
||||||
@@ -130,7 +134,7 @@ impl Client {
|
|||||||
model: EmbedModel,
|
model: EmbedModel,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<EmbeddingRequestOptions>,
|
options: Option<EmbeddingRequestOptions>,
|
||||||
) -> Result<EmbeddingResponse, APIError> {
|
) -> Result<EmbeddingResponse, ApiError> {
|
||||||
let request = EmbeddingRequest::new(model, input, options);
|
let request = EmbeddingRequest::new(model, input, options);
|
||||||
|
|
||||||
let response = self.post("/embeddings", &request)?;
|
let response = self.post("/embeddings", &request)?;
|
||||||
@@ -141,7 +145,7 @@ impl Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_models(&self) -> Result<ModelListResponse, APIError> {
|
pub fn list_models(&self) -> Result<ModelListResponse, ApiError> {
|
||||||
let response = self.get("/models")?;
|
let response = self.get("/models")?;
|
||||||
let result = response.json::<ModelListResponse>();
|
let result = response.json::<ModelListResponse>();
|
||||||
match result {
|
match result {
|
||||||
@@ -150,8 +154,8 @@ impl Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_error(&self, err: minreq::Error) -> APIError {
|
fn new_error(&self, err: minreq::Error) -> ApiError {
|
||||||
APIError {
|
ApiError {
|
||||||
message: err.to_string(),
|
message: err.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,18 @@ use std::error::Error;
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct APIError {
|
pub struct ApiError {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
}
|
}
|
||||||
|
impl fmt::Display for ApiError {
|
||||||
impl fmt::Display for APIError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "APIError: {}", self.message)
|
write!(f, "ApiError: {}", self.message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl Error for ApiError {}
|
||||||
|
|
||||||
impl Error for APIError {}
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ClientError {
|
||||||
|
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
|
||||||
|
ApiKeyError,
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,12 +6,7 @@ use mistralai_client::v1::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_completion() {
|
fn test_client_chat() {
|
||||||
extern crate dotenv;
|
|
||||||
|
|
||||||
use dotenv::dotenv;
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None);
|
let client = Client::new(None, None, None, None);
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::OpenMistral7b;
|
||||||
@@ -2,12 +2,7 @@ use jrest::expect;
|
|||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_embeddings() {
|
fn test_client_embeddings() {
|
||||||
extern crate dotenv;
|
|
||||||
|
|
||||||
use dotenv::dotenv;
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
let client: Client = Client::new(None, None, None, None);
|
let client: Client = Client::new(None, None, None, None);
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = EmbedModel::MistralEmbed;
|
||||||
@@ -2,12 +2,7 @@ use jrest::expect;
|
|||||||
use mistralai_client::v1::client::Client;
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_list_models() {
|
fn test_client_list_models() {
|
||||||
extern crate dotenv;
|
|
||||||
|
|
||||||
use dotenv::dotenv;
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None);
|
let client = Client::new(None, None, None, None);
|
||||||
|
|
||||||
let response = client.list_models().unwrap();
|
let response = client.list_models().unwrap();
|
||||||
@@ -4,6 +4,7 @@ use mistralai_client::v1::client::Client;
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_client_new_with_none_params() {
|
fn test_client_new_with_none_params() {
|
||||||
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
||||||
|
std::env::remove_var("MISTRAL_API_KEY");
|
||||||
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
|
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None);
|
let client = Client::new(None, None, None, None);
|
||||||
@@ -24,6 +25,7 @@ fn test_client_new_with_none_params() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_client_new_with_all_params() {
|
fn test_client_new_with_all_params() {
|
||||||
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
||||||
|
std::env::remove_var("MISTRAL_API_KEY");
|
||||||
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
|
std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env");
|
||||||
|
|
||||||
let api_key = Some("test_api_key_from_param".to_string());
|
let api_key = Some("test_api_key_from_param".to_string());
|
||||||
@@ -50,3 +52,19 @@ fn test_client_new_with_all_params() {
|
|||||||
None => std::env::remove_var("MISTRAL_API_KEY"),
|
None => std::env::remove_var("MISTRAL_API_KEY"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_client_new_with_missing_api_key() {
|
||||||
|
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
||||||
|
std::env::remove_var("MISTRAL_API_KEY");
|
||||||
|
|
||||||
|
let _client = Client::new(None, None, None, None);
|
||||||
|
|
||||||
|
match maybe_original_mistral_api_key {
|
||||||
|
Some(original_mistral_api_key) => {
|
||||||
|
std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key)
|
||||||
|
}
|
||||||
|
None => std::env::remove_var("MISTRAL_API_KEY"),
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user