feat: add client.list_models_async() method

This commit is contained in:
Ivan Gabriele
2024-03-04 06:32:01 +01:00
parent 75788b9395
commit b69f7c617c
3 changed files with 183 additions and 142 deletions

View File

@@ -33,7 +33,7 @@ Rust client for the Mistral AI API.
- [x] Embedding
- [ ] Embedding (async)
- [x] List models
- [ ] List models (async)
- [x] List models (async)
- [ ] Function Calling
- [ ] Function Calling (async)
@@ -175,4 +175,16 @@ fn main() {
### List models (async)
_In progress._
```rs
use mistralai_client::v1::client::Client;
#[tokio::main]
async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).await.unwrap();
let result = client.list_models().unwrap();
println!("First Model ID: {:?}", result.data[0].id);
// => "First Model ID: open-mistral-7b"
}
```

View File

@@ -41,146 +41,6 @@ impl Client {
})
}
pub fn build_request_sync(
&self,
request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("User-Agent", user_agent);
request_builder
}
pub fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("User-Agent", user_agent);
request_builder
}
pub fn get_sync(&self, path: &str) -> Result<reqwest::blocking::Response, ApiError> {
let client_sync = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request = self.build_request_sync(client_sync.get(url));
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().unwrap();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
pub async fn get_async(&self, path: &str) -> Result<reqwest::Response, ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.get(url);
let request = self.build_request_async(request_builder);
let result = request.send().await.map_err(|e| self.to_api_error(e));
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().await.unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
pub fn post_sync<T: serde::ser::Serialize + std::fmt::Debug>(
&self,
path: &str,
params: &T,
) -> Result<reqwest::blocking::Response, ApiError> {
let reqwest_client = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_sync(request_builder);
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
pub async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
&self,
path: &str,
params: &T,
) -> Result<reqwest::Response, ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_async(request_builder);
let result = request.send().await.map_err(|e| self.to_api_error(e));
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().await.unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
pub fn chat(
&self,
model: Model,
@@ -238,6 +98,155 @@ impl Client {
}
}
pub async fn list_models_async(&self) -> Result<ModelListResponse, ApiError> {
let response = self.get_async("/models").await?;
let result = response.json::<ModelListResponse>().await;
match result {
Ok(response) => Ok(response),
Err(error) => Err(self.to_api_error(error)),
}
}
fn build_request_sync(
&self,
request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("User-Agent", user_agent);
request_builder
}
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("User-Agent", user_agent);
request_builder
}
fn get_sync(&self, path: &str) -> Result<reqwest::blocking::Response, ApiError> {
let client_sync = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request = self.build_request_sync(client_sync.get(url));
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().unwrap();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
async fn get_async(&self, path: &str) -> Result<reqwest::Response, ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.get(url);
let request = self.build_request_async(request_builder);
let result = request.send().await.map_err(|e| self.to_api_error(e));
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().await.unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
fn post_sync<T: serde::ser::Serialize + std::fmt::Debug>(
&self,
path: &str,
params: &T,
) -> Result<reqwest::blocking::Response, ApiError> {
let reqwest_client = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_sync(request_builder);
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
&self,
path: &str,
params: &T,
) -> Result<reqwest::Response, ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_async(request_builder);
let result = request.send().await.map_err(|e| self.to_api_error(e));
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let text = response.text().await.unwrap_or_default();
Err(ApiError {
message: format!("{}: {}", status, text),
})
}
}
Err(error) => Err(ApiError {
message: error.to_string(),
}),
}
}
fn to_api_error(&self, err: ReqwestError) -> ApiError {
ApiError {
message: err.to_string(),

View File

@@ -0,0 +1,20 @@
use jrest::expect;
use mistralai_client::v1::client::Client;
#[tokio::test]
async fn test_client_list_models_async() {
let client = Client::new(None, None, None, None).unwrap();
let response = client.list_models_async().await.unwrap();
expect!(response.object).to_be("list".to_string());
expect!(response.data.len()).to_be_greater_than(0);
// let open_mistral_7b_data_item = response
// .data
// .iter()
// .find(|item| item.id == "open-mistral-7b")
// .unwrap();
// expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string());
}