diff --git a/README.md b/README.md index b13e9c2..0cd2baf 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Rust client for the Mistral AI API. - [x] Chat without streaming - [ ] Chat with streaming - [ ] Embedding -- [ ] List models +- [x] List models - [ ] Function Calling ## Installation @@ -63,11 +63,8 @@ fn main() { ### Chat without streaming ```rs -use mistralai::v1::{ - chat_completion::{ - ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequest, - ChatCompletionRequestOptions, - }, +use mistralai_client::v1::{ + chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequestOptions}, client::Client, constants::OPEN_MISTRAL_7B, }; @@ -103,4 +100,15 @@ _In progress._ ### List models -_In progress._ +```rs +use mistralai_client::v1::client::Client; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None); + + let result = client.list_models(model, messages, Some(options)).unwrap(); + println!("First Model ID: {:?}", result.data[0].id); + // => "First Model ID: open-mistral-7b" +} +``` diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index c527bd9..998677c 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -44,9 +44,9 @@ pub struct ChatCompletionRequest { pub stream: Option, #[serde(skip_serializing_if = "Option::is_none")] pub safe_prompt: Option, - // TODO Check that prop (seen in official Python client but not in API doc). + // TODO Check this prop (seen in official Python client but not in API doc). // pub tool_choice: Option, - // TODO Check that prop (seen in official Python client but not in API doc). + // TODO Check this prop (seen in official Python client but not in API doc). // pub response_format: Option, } impl ChatCompletionRequest { @@ -95,7 +95,7 @@ pub struct ChatCompletionChoice { pub index: u32, pub message: ChatCompletionMessage, pub finish_reason: String, - // TODO Check that prop (seen in API responses but undocumented). + // TODO Check this prop (seen in API responses but undocumented). // pub logprobs: ??? } diff --git a/src/v1/client.rs b/src/v1/client.rs index 2e95266..6e8981d 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -2,12 +2,13 @@ use crate::v1::error::APIError; use minreq::Response; use crate::v1::{ - chat_completion::{ChatCompletionRequest, ChatCompletionResponse}, + chat_completion::{ + ChatCompletionMessage, ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse, + }, constants::API_URL_BASE, + list_models::ListModelsResponse, }; -use super::chat_completion::{ChatCompletionMessage, ChatCompletionParams}; - pub struct Client { pub api_key: String, pub endpoint: String, @@ -38,7 +39,7 @@ impl Client { pub fn build_request(&self, request: minreq::Request) -> minreq::Request { let authorization = format!("Bearer {}", self.api_key); let user_agent = format!( - "ivangabriele/mistral-client-rs/{}", + "ivangabriele/mistralai-client-rs/{}", env!("CARGO_PKG_VERSION") ); @@ -53,20 +54,24 @@ impl Client { pub fn get(&self, path: &str) -> Result { let url = format!("{}{}", self.endpoint, path); - let request = self.build_request(minreq::post(url)); + let request = self.build_request(minreq::get(url)); let result = request.send(); match result { - Ok(res) => { - if (200..=299).contains(&res.status_code) { - Ok(res) + Ok(response) => { + if (200..=299).contains(&response.status_code) { + Ok(response) } else { Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + message: format!( + "{}: {}", + response.status_code, + response.as_str().unwrap() + ), }) } } - Err(e) => Err(self.new_error(e)), + Err(error) => Err(self.new_error(error)), } } @@ -83,7 +88,7 @@ impl Client { let result = request.with_json(params).unwrap().send(); match result { Ok(response) => { - print!("{:?}", response.as_str().unwrap()); + // print!("{:?}", response.as_str().unwrap()); if (200..=299).contains(&response.status_code) { Ok(response) @@ -124,24 +129,6 @@ impl Client { } } - // pub fn completion(&self, req: CompletionRequest) -> Result { - // let res = self.post("/completions", &req)?; - // let r = res.json::(); - // match r { - // Ok(r) => Ok(r), - // Err(e) => Err(self.new_error(e)), - // } - // } - - // pub fn embedding(&self, req: EmbeddingRequest) -> Result { - // let res = self.post("/embeddings", &req)?; - // let r = res.json::(); - // match r { - // Ok(r) => Ok(r), - // Err(e) => Err(self.new_error(e)), - // } - // } - pub fn chat( &self, model: String, @@ -158,35 +145,18 @@ impl Client { } } + pub fn list_models(&self) -> Result { + let response = self.get("/models")?; + let result = response.json::(); + match result { + Ok(response) => Ok(response), + Err(error) => Err(self.new_error(error)), + } + } + fn new_error(&self, err: minreq::Error) -> APIError { APIError { message: err.to_string(), } } - - // fn query_params( - // limit: Option, - // order: Option, - // after: Option, - // before: Option, - // mut url: String, - // ) -> String { - // let mut params = vec![]; - // if let Some(limit) = limit { - // params.push(format!("limit={}", limit)); - // } - // if let Some(order) = order { - // params.push(format!("order={}", order)); - // } - // if let Some(after) = after { - // params.push(format!("after={}", after)); - // } - // if let Some(before) = before { - // params.push(format!("before={}", before)); - // } - // if !params.is_empty() { - // url = format!("{}?{}", url, params.join("&")); - // } - // url - // } } diff --git a/src/v1/list_models.rs b/src/v1/list_models.rs new file mode 100644 index 0000000..1d6515e --- /dev/null +++ b/src/v1/list_models.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ListModelsResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ListModelsModel { + pub id: String, + pub object: String, + /// Unix timestamp (in seconds). + pub created: u32, + pub owned_by: String, + pub permission: Vec, + // TODO Check this prop (seen in API responses but undocumented). + // pub root: ???, + // TODO Check this prop (seen in API responses but undocumented). + // pub parent: ???, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ListModelsModelPermission { + pub id: String, + pub object: String, + /// Unix timestamp (in seconds). + pub created: u32, + pub allow_create_engine: bool, + pub allow_sampling: bool, + pub allow_logprobs: bool, + pub allow_search_indices: bool, + pub allow_view: bool, + pub allow_fine_tuning: bool, + pub organization: String, + pub is_blocking: bool, + // TODO Check this prop (seen in API responses but undocumented). + // pub group: ???, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index c7032e1..7a06729 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -3,3 +3,4 @@ pub mod client; pub mod common; pub mod constants; pub mod error; +pub mod list_models; diff --git a/tests/v1_list_models_test.rs b/tests/v1_list_models_test.rs new file mode 100644 index 0000000..13ec9fb --- /dev/null +++ b/tests/v1_list_models_test.rs @@ -0,0 +1,17 @@ +use jrest::expect; +use mistralai_client::v1::client::Client; + +#[test] +fn test_list_models() { + extern crate dotenv; + + use dotenv::dotenv; + dotenv().ok(); + + let client = Client::new(None, None, None, None); + + let response = client.list_models().unwrap(); + + expect!(response.object).to_be("list".to_string()); + expect!(response.data.len()).to_be_greater_than(0); +}