diff --git a/src/v1/agents.rs b/src/v1/agents.rs new file mode 100644 index 0000000..4b4f1e2 --- /dev/null +++ b/src/v1/agents.rs @@ -0,0 +1,98 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::{chat, common, constants, tool}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug)] +pub struct AgentCompletionParams { + pub max_tokens: Option, + pub min_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub random_seed: Option, + pub stop: Option>, + pub response_format: Option, + pub tools: Option>, + pub tool_choice: Option, +} +impl Default for AgentCompletionParams { + fn default() -> Self { + Self { + max_tokens: None, + min_tokens: None, + temperature: None, + top_p: None, + random_seed: None, + stop: None, + response_format: None, + tools: None, + tool_choice: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AgentCompletionRequest { + pub agent_id: String, + pub messages: Vec, + pub stream: bool, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, +} +impl AgentCompletionRequest { + pub fn new( + agent_id: String, + messages: Vec, + stream: bool, + options: Option, + ) -> Self { + let opts = options.unwrap_or_default(); + + Self { + agent_id, + messages, + stream, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + temperature: opts.temperature, + top_p: opts.top_p, + random_seed: opts.random_seed, + stop: opts.stop, + response_format: opts.response_format, + tools: opts.tools, + tool_choice: opts.tool_choice, + } + } +} + +// ----------------------------------------------------------------------------- +// Response (same shape as chat completions) + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AgentCompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: constants::Model, + pub choices: Vec, + pub usage: common::ResponseUsage, +} diff --git a/src/v1/audio.rs b/src/v1/audio.rs new file mode 100644 index 0000000..98cdfb7 --- /dev/null +++ b/src/v1/audio.rs @@ -0,0 +1,78 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request (multipart form, but we define the params struct) + +#[derive(Debug)] +pub struct AudioTranscriptionParams { + pub model: constants::Model, + pub language: Option, + pub temperature: Option, + pub diarize: Option, + pub timestamp_granularities: Option>, +} +impl Default for AudioTranscriptionParams { + fn default() -> Self { + Self { + model: constants::Model::voxtral_mini_transcribe(), + language: None, + temperature: None, + diarize: None, + timestamp_granularities: None, + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum TimestampGranularity { + Segment, + Word, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AudioTranscriptionResponse { + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TranscriptionSegment { + pub id: u32, + pub start: f32, + pub end: f32, + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub speaker: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TranscriptionWord { + pub word: String, + pub start: f32, + pub end: f32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AudioUsage { + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_audio_seconds: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, +} diff --git a/src/v1/batch.rs b/src/v1/batch.rs new file mode 100644 index 0000000..e1e1be1 --- /dev/null +++ b/src/v1/batch.rs @@ -0,0 +1,53 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchJobRequest { + pub input_files: Vec, + pub model: constants::Model, + pub endpoint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct BatchJobResponse { + pub id: String, + pub object: String, + pub model: constants::Model, + pub endpoint: String, + pub input_files: Vec, + pub status: String, + pub created_at: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub succeeded_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub failed_requests: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct BatchJobListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} diff --git a/src/v1/client.rs b/src/v1/client.rs index ef396f3..7b2ecc8 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -5,10 +5,14 @@ use reqwest::Error as ReqwestError; use std::{ any::Any, collections::HashMap, + path::Path, sync::{Arc, Mutex}, }; -use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils}; +use crate::v1::{ + agents, audio, batch, chat, chat_stream, constants, embedding, error, files, fim, fine_tuning, + model_list, moderation, ocr, tool, utils, +}; #[derive(Debug)] pub struct Client { @@ -200,6 +204,48 @@ impl Client { } } + // ========================================================================= + // FIM (Fill-in-the-Middle) + // ========================================================================= + + pub fn fim( + &self, + model: constants::Model, + prompt: String, + options: Option, + ) -> Result { + let request = fim::FimRequest::new(model, prompt, false, options); + + let response = self.post_sync("/fim/completions", &request)?; + let result = response.json::(); + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn fim_async( + &self, + model: constants::Model, + prompt: String, + options: Option, + ) -> Result { + let request = fim::FimRequest::new(model, prompt, false, options); + + let response = self.post_async("/fim/completions", &request).await?; + let result = response.json::().await; + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + // ========================================================================= // Models // ========================================================================= @@ -283,6 +329,577 @@ impl Client { } } + // ========================================================================= + // Files + // ========================================================================= + + pub fn list_files(&self) -> Result { + let response = self.get_sync("/files")?; + let result = response.json::(); + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn list_files_async(&self) -> Result { + let response = self.get_async("/files").await?; + let result = response.json::().await; + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub fn upload_file( + &self, + file_path: &Path, + purpose: files::FilePurpose, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}/files", self.endpoint); + + let purpose_str = serde_json::to_value(&purpose) + .unwrap() + .as_str() + .unwrap() + .to_string(); + + let file_bytes = std::fs::read(file_path).map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::blocking::multipart::Part::bytes(file_bytes).file_name(file_name); + + let form = reqwest::blocking::multipart::Form::new() + .text("purpose", purpose_str) + .part("file", part); + + let request = self.build_request_sync_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send(); + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + pub async fn upload_file_async( + &self, + file_path: &Path, + purpose: files::FilePurpose, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}/files", self.endpoint); + + let purpose_str = serde_json::to_value(&purpose) + .unwrap() + .as_str() + .unwrap() + .to_string(); + + let file_bytes = tokio::fs::read(file_path).await.map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); + + let form = reqwest::multipart::Form::new() + .text("purpose", purpose_str) + .part("file", part); + + let request = self.build_request_async_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send().await; + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + pub fn get_file(&self, file_id: &str) -> Result { + let response = self.get_sync(&format!("/files/{}", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_file_async( + &self, + file_id: &str, + ) -> Result { + let response = self.get_async(&format!("/files/{}", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn delete_file(&self, file_id: &str) -> Result { + let response = self.delete_sync(&format!("/files/{}", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn delete_file_async( + &self, + file_id: &str, + ) -> Result { + let response = self.delete_async(&format!("/files/{}", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_file_url(&self, file_id: &str) -> Result { + let response = self.get_sync(&format!("/files/{}/url", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_file_url_async( + &self, + file_id: &str, + ) -> Result { + let response = self.get_async(&format!("/files/{}/url", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Fine-Tuning + // ========================================================================= + + pub fn create_fine_tuning_job( + &self, + request: &fine_tuning::FineTuningJobRequest, + ) -> Result { + let response = self.post_sync("/fine_tuning/jobs", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_fine_tuning_job_async( + &self, + request: &fine_tuning::FineTuningJobRequest, + ) -> Result { + let response = self.post_async("/fine_tuning/jobs", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn list_fine_tuning_jobs( + &self, + ) -> Result { + let response = self.get_sync("/fine_tuning/jobs")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_fine_tuning_jobs_async( + &self, + ) -> Result { + let response = self.get_async("/fine_tuning/jobs").await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.get_sync(&format!("/fine_tuning/jobs/{}", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/fine_tuning/jobs/{}", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn cancel_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/fine_tuning/jobs/{}/cancel", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn cancel_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/fine_tuning/jobs/{}/cancel", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn start_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/fine_tuning/jobs/{}/start", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn start_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/fine_tuning/jobs/{}/start", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Batch Jobs + // ========================================================================= + + pub fn create_batch_job( + &self, + request: &batch::BatchJobRequest, + ) -> Result { + let response = self.post_sync("/batch/jobs", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_batch_job_async( + &self, + request: &batch::BatchJobRequest, + ) -> Result { + let response = self.post_async("/batch/jobs", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn list_batch_jobs(&self) -> Result { + let response = self.get_sync("/batch/jobs")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_batch_jobs_async( + &self, + ) -> Result { + let response = self.get_async("/batch/jobs").await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_batch_job( + &self, + job_id: &str, + ) -> Result { + let response = self.get_sync(&format!("/batch/jobs/{}", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_batch_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/batch/jobs/{}", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn cancel_batch_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/batch/jobs/{}/cancel", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn cancel_batch_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/batch/jobs/{}/cancel", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // OCR + // ========================================================================= + + pub fn ocr( + &self, + request: &ocr::OcrRequest, + ) -> Result { + let response = self.post_sync("/ocr", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn ocr_async( + &self, + request: &ocr::OcrRequest, + ) -> Result { + let response = self.post_async("/ocr", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Audio Transcription + // ========================================================================= + + pub async fn transcribe_audio_async( + &self, + file_path: &Path, + params: Option, + ) -> Result { + let opts = params.unwrap_or_default(); + let reqwest_client = reqwest::Client::new(); + let url = format!("{}/audio/transcriptions", self.endpoint); + + let file_bytes = tokio::fs::read(file_path).await.map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); + + let mut form = reqwest::multipart::Form::new() + .text("model", opts.model.0) + .part("file", part); + + if let Some(lang) = opts.language { + form = form.text("language", lang); + } + if let Some(temp) = opts.temperature { + form = form.text("temperature", temp.to_string()); + } + if let Some(diarize) = opts.diarize { + form = form.text("diarize", diarize.to_string()); + } + + let request = self.build_request_async_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send().await; + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + // ========================================================================= + // Moderations & Classifications + // ========================================================================= + + pub fn moderations( + &self, + model: constants::Model, + input: Vec, + ) -> Result { + let request = moderation::ModerationRequest::new(model, input); + let response = self.post_sync("/moderations", &request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn moderations_async( + &self, + model: constants::Model, + input: Vec, + ) -> Result { + let request = moderation::ModerationRequest::new(model, input); + let response = self.post_async("/moderations", &request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn chat_moderations( + &self, + request: &moderation::ChatModerationRequest, + ) -> Result { + let response = self.post_sync("/chat/moderations", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn chat_moderations_async( + &self, + request: &moderation::ChatModerationRequest, + ) -> Result { + let response = self.post_async("/chat/moderations", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn classifications( + &self, + request: &moderation::ClassificationRequest, + ) -> Result { + let response = self.post_sync("/classifications", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn classifications_async( + &self, + request: &moderation::ClassificationRequest, + ) -> Result { + let response = self.post_async("/classifications", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Agent Completions + // ========================================================================= + + pub fn agent_completion( + &self, + agent_id: String, + messages: Vec, + options: Option, + ) -> Result { + let request = agents::AgentCompletionRequest::new(agent_id, messages, false, options); + let response = self.post_sync("/agents/completions", &request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn agent_completion_async( + &self, + agent_id: String, + messages: Vec, + options: Option, + ) -> Result { + let request = agents::AgentCompletionRequest::new(agent_id, messages, false, options); + let response = self.post_async("/agents/completions", &request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + // ========================================================================= // Function Calling // ========================================================================= diff --git a/src/v1/files.rs b/src/v1/files.rs new file mode 100644 index 0000000..e667983 --- /dev/null +++ b/src/v1/files.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum FilePurpose { + FineTune, + Batch, + Ocr, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileObject { + pub id: String, + pub object: String, + pub bytes: u64, + pub created_at: u64, + pub filename: String, + pub purpose: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_lines: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mimetype: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileDeleteResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileUrlResponse { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, +} diff --git a/src/v1/fim.rs b/src/v1/fim.rs new file mode 100644 index 0000000..5d6ff05 --- /dev/null +++ b/src/v1/fim.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::{common, constants}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug)] +pub struct FimParams { + pub suffix: Option, + pub max_tokens: Option, + pub min_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub stop: Option>, + pub random_seed: Option, +} +impl Default for FimParams { + fn default() -> Self { + Self { + suffix: None, + max_tokens: None, + min_tokens: None, + temperature: None, + top_p: None, + stop: None, + random_seed: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FimRequest { + pub model: constants::Model, + pub prompt: String, + pub stream: bool, + + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, +} +impl FimRequest { + pub fn new( + model: constants::Model, + prompt: String, + stream: bool, + options: Option, + ) -> Self { + let opts = options.unwrap_or_default(); + + Self { + model, + prompt, + stream, + suffix: opts.suffix, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + temperature: opts.temperature, + top_p: opts.top_p, + stop: opts.stop, + random_seed: opts.random_seed, + } + } +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: constants::Model, + pub choices: Vec, + pub usage: common::ResponseUsage, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponseChoice { + pub index: u32, + pub message: FimResponseMessage, + pub finish_reason: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponseMessage { + pub role: String, + pub content: String, +} diff --git a/src/v1/fine_tuning.rs b/src/v1/fine_tuning.rs new file mode 100644 index 0000000..c5cce2d --- /dev/null +++ b/src/v1/fine_tuning.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct FineTuningJobRequest { + pub model: constants::Model, + pub training_files: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_files: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub hyperparameters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_start: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub job_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub integrations: Option>, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TrainingFile { + pub file_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub weight: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Hyperparameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub learning_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub training_steps: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub warmup_fraction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub epochs: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Integration { + pub r#type: String, + pub project: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub api_key: Option, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FineTuningJobResponse { + pub id: String, + pub object: String, + pub model: constants::Model, + pub status: FineTuningJobStatus, + pub created_at: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_at: Option, + pub training_files: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_files: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub hyperparameters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub fine_tuned_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub integrations: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub trained_tokens: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum FineTuningJobStatus { + Queued, + Running, + Success, + Failed, + TimeoutExceeded, + CancellationRequested, + Cancelled, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FineTuningJobListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 72165bb..e1140b6 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,3 +1,6 @@ +pub mod agents; +pub mod audio; +pub mod batch; pub mod chat; pub mod chat_stream; pub mod client; @@ -5,6 +8,11 @@ pub mod common; pub mod constants; pub mod embedding; pub mod error; +pub mod files; +pub mod fim; +pub mod fine_tuning; pub mod model_list; +pub mod moderation; +pub mod ocr; pub mod tool; pub mod utils; diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs new file mode 100644 index 0000000..b8e199b --- /dev/null +++ b/src/v1/moderation.rs @@ -0,0 +1,70 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct ModerationRequest { + pub model: constants::Model, + pub input: Vec, +} +impl ModerationRequest { + pub fn new(model: constants::Model, input: Vec) -> Self { + Self { model, input } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatModerationRequest { + pub model: constants::Model, + pub input: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatModerationInput { + pub role: String, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClassificationRequest { + pub model: constants::Model, + pub input: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatClassificationRequest { + pub model: constants::Model, + pub input: Vec, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ModerationResponse { + pub id: String, + pub model: constants::Model, + pub results: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ModerationResult { + pub categories: serde_json::Value, + pub category_scores: serde_json::Value, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ClassificationResponse { + pub id: String, + pub model: constants::Model, + pub results: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ClassificationResult { + pub categories: serde_json::Value, + pub category_scores: serde_json::Value, +} diff --git a/src/v1/ocr.rs b/src/v1/ocr.rs new file mode 100644 index 0000000..3e4cff3 --- /dev/null +++ b/src/v1/ocr.rs @@ -0,0 +1,96 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct OcrRequest { + pub model: constants::Model, + pub document: OcrDocument, + + #[serde(skip_serializing_if = "Option::is_none")] + pub pages: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub table_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_image_base64: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_limit: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OcrDocument { + #[serde(rename = "type")] + pub type_: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub document_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_id: Option, +} +impl OcrDocument { + pub fn from_url(url: &str) -> Self { + Self { + type_: "document_url".to_string(), + document_url: Some(url.to_string()), + file_id: None, + } + } + + pub fn from_file_id(file_id: &str) -> Self { + Self { + type_: "file_id".to_string(), + document_url: None, + file_id: Some(file_id.to_string()), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum OcrTableFormat { + Markdown, + Html, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrResponse { + pub pages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage_info: Option, + pub model: constants::Model, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrPage { + pub index: u32, + pub markdown: String, + #[serde(default)] + pub images: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub dimensions: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrImage { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_base64: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrPageDimensions { + pub width: f32, + pub height: f32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrUsageInfo { + pub pages_processed: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub doc_size_bytes: Option, +}