feat: add client.list_models() method

This commit is contained in:
Ivan Gabriele
2024-03-03 19:38:34 +01:00
parent 7de2b19b98
commit 814b9918b3
6 changed files with 100 additions and 65 deletions

View File

@@ -2,12 +2,13 @@ use crate::v1::error::APIError;
use minreq::Response;
use crate::v1::{
chat_completion::{ChatCompletionRequest, ChatCompletionResponse},
chat_completion::{
ChatCompletionMessage, ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse,
},
constants::API_URL_BASE,
list_models::ListModelsResponse,
};
use super::chat_completion::{ChatCompletionMessage, ChatCompletionParams};
pub struct Client {
pub api_key: String,
pub endpoint: String,
@@ -38,7 +39,7 @@ impl Client {
pub fn build_request(&self, request: minreq::Request) -> minreq::Request {
let authorization = format!("Bearer {}", self.api_key);
let user_agent = format!(
"ivangabriele/mistral-client-rs/{}",
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
@@ -53,20 +54,24 @@ impl Client {
pub fn get(&self, path: &str) -> Result<Response, APIError> {
let url = format!("{}{}", self.endpoint, path);
let request = self.build_request(minreq::post(url));
let request = self.build_request(minreq::get(url));
let result = request.send();
match result {
Ok(res) => {
if (200..=299).contains(&res.status_code) {
Ok(res)
Ok(response) => {
if (200..=299).contains(&response.status_code) {
Ok(response)
} else {
Err(APIError {
message: format!("{}: {}", res.status_code, res.as_str().unwrap()),
message: format!(
"{}: {}",
response.status_code,
response.as_str().unwrap()
),
})
}
}
Err(e) => Err(self.new_error(e)),
Err(error) => Err(self.new_error(error)),
}
}
@@ -83,7 +88,7 @@ impl Client {
let result = request.with_json(params).unwrap().send();
match result {
Ok(response) => {
print!("{:?}", response.as_str().unwrap());
// print!("{:?}", response.as_str().unwrap());
if (200..=299).contains(&response.status_code) {
Ok(response)
@@ -124,24 +129,6 @@ impl Client {
}
}
// pub fn completion(&self, req: CompletionRequest) -> Result<CompletionResponse, APIError> {
// let res = self.post("/completions", &req)?;
// let r = res.json::<CompletionResponse>();
// match r {
// Ok(r) => Ok(r),
// Err(e) => Err(self.new_error(e)),
// }
// }
// pub fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse, APIError> {
// let res = self.post("/embeddings", &req)?;
// let r = res.json::<EmbeddingResponse>();
// match r {
// Ok(r) => Ok(r),
// Err(e) => Err(self.new_error(e)),
// }
// }
pub fn chat(
&self,
model: String,
@@ -158,35 +145,18 @@ impl Client {
}
}
pub fn list_models(&self) -> Result<ListModelsResponse, APIError> {
let response = self.get("/models")?;
let result = response.json::<ListModelsResponse>();
match result {
Ok(response) => Ok(response),
Err(error) => Err(self.new_error(error)),
}
}
fn new_error(&self, err: minreq::Error) -> APIError {
APIError {
message: err.to_string(),
}
}
// fn query_params(
// limit: Option<i64>,
// order: Option<String>,
// after: Option<String>,
// before: Option<String>,
// mut url: String,
// ) -> String {
// let mut params = vec![];
// if let Some(limit) = limit {
// params.push(format!("limit={}", limit));
// }
// if let Some(order) = order {
// params.push(format!("order={}", order));
// }
// if let Some(after) = after {
// params.push(format!("after={}", after));
// }
// if let Some(before) = before {
// params.push(format!("before={}", before));
// }
// if !params.is_empty() {
// url = format!("{}?{}", url, params.join("&"));
// }
// url
// }
}