Files
mistralai-client-rs/src/v1/fine_tuning.rs

102 lines
3.2 KiB
Rust
Raw Normal View History

use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
/// Request body for creating a fine-tuning job.
///
/// Every `Option` field that is `None` is omitted from the serialized payload
/// (not sent as `null`), so only the settings the caller actually provides are
/// transmitted.
///
/// Derives `Clone` like the other model types in this module, so a prepared
/// request can be duplicated (e.g. to submit variants that differ only in
/// `suffix` or `hyperparameters`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FineTuningJobRequest {
    /// Base model to fine-tune.
    pub model: constants::Model,
    /// Training data files, each with an optional weight.
    pub training_files: Vec<TrainingFile>,
    /// Optional validation files; presumably file identifiers — confirm
    /// against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,
    /// Optional training hyperparameters; unset entries are not sent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,
    /// Optional suffix for the resulting model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Optional flag controlling whether the job starts automatically.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_start: Option<bool>,
    /// Optional job type. NOTE(review): free-form string; the accepted values
    /// are defined by the API and are not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub job_type: Option<String>,
    /// Optional third-party integrations to attach to the job.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
}
/// A single training data file referenced by a [`FineTuningJobRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingFile {
    /// Identifier of the file to train on (presumably the id returned when the
    /// file was uploaded — confirm).
    pub file_id: String,

    /// Optional weight for this file; left out of the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weight: Option<f32>,
}
/// Optional training hyperparameters for a fine-tuning job.
///
/// All fields are optional and omitted from the serialized payload when
/// `None`. `Default` yields an all-`None` value, so callers can set just the
/// parameters they care about:
///
/// ```ignore
/// let hp = Hyperparameters { learning_rate: Some(1e-4), ..Default::default() };
/// ```
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Hyperparameters {
    /// Learning rate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    /// Number of training steps.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_steps: Option<u32>,
    /// Warmup fraction.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warmup_fraction: Option<f64>,
    /// Number of epochs; stored as `f64`, so fractional values are representable.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub epochs: Option<f64>,
}
/// A third-party integration attached to a fine-tuning job.
///
/// `Debug` is implemented by hand so that `api_key` is never printed. The
/// derived implementation would echo the secret into any log line that
/// formats an `Integration`, or a request/response containing one, with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Integration {
    /// Integration kind; serialized under the JSON key `type` (serde strips
    /// the raw-identifier `r#` prefix).
    pub r#type: String,
    /// Project name within the integration.
    pub project: String,
    /// Optional display name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Optional credential for the integration. Redacted in `Debug` output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
}

impl std::fmt::Debug for Integration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is present, never its value.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("Integration")
            .field("type", &self.r#type)
            .field("project", &self.project)
            .field("name", &self.name)
            .field("api_key", &api_key)
            .finish()
    }
}
// -----------------------------------------------------------------------------
// Response
/// A fine-tuning job as returned by the API.
///
/// Optional fields that are `None` are skipped when this value is serialized
/// again.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningJobResponse {
    /// Job identifier.
    pub id: String,
    /// Object type tag reported by the API.
    pub object: String,
    /// Base model the job fine-tunes.
    pub model: constants::Model,
    /// Current lifecycle state of the job.
    pub status: FineTuningJobStatus,
    /// Creation time (presumably a Unix timestamp in seconds — confirm).
    pub created_at: u64,

    /// Last modification time, if reported (same unit as `created_at`,
    /// presumably).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_at: Option<u64>,

    /// Training files used by the job (presumably file identifiers).
    pub training_files: Vec<String>,

    /// Validation files, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,

    /// Hyperparameters reported for the job.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,

    /// Name of the resulting fine-tuned model, once available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fine_tuned_model: Option<String>,

    /// Suffix supplied when the job was created.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    /// Integrations attached to the job.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,

    /// Number of tokens trained on, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trained_tokens: Option<u64>,
}
/// Lifecycle state of a fine-tuning job.
///
/// Variants are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g.
/// `CancellationRequested` <-> `"CANCELLATION_REQUESTED"`.
///
/// Deserializing a status string that is not listed here is an error, and it
/// fails the whole [`FineTuningJobResponse`] that contains it. The enum
/// therefore includes the pre-run states (`STARTED`, `VALIDATING`,
/// `VALIDATED`) and `FAILED_VALIDATION` listed in the API reference.
/// Without them, fetching a job in one of those states could not be parsed.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FineTuningJobStatus {
    /// `"QUEUED"`
    Queued,
    /// `"STARTED"`
    Started,
    /// `"VALIDATING"`
    Validating,
    /// `"VALIDATED"`
    Validated,
    /// `"RUNNING"`
    Running,
    /// `"FAILED_VALIDATION"`
    FailedValidation,
    /// `"SUCCESS"`
    Success,
    /// `"FAILED"`
    Failed,
    /// `"TIMEOUT_EXCEEDED"`. Kept for backward compatibility.
    /// NOTE(review): not listed in the current API reference — confirm
    /// whether the API still emits it.
    TimeoutExceeded,
    /// `"CANCELLATION_REQUESTED"`
    CancellationRequested,
    /// `"CANCELLED"`
    Cancelled,
}
/// A page of fine-tuning jobs as returned by the list endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FineTuningJobListResponse {
    /// The jobs in this response.
    pub data: Vec<FineTuningJobResponse>,
    /// Object type tag reported by the API.
    pub object: String,
    /// Total job count; falls back to `0` when the field is absent.
    #[serde(default)]
    pub total: u32,
}