Files
mistralai-client-rs/src/v1/embedding.rs

87 lines
2.3 KiB
Rust
Raw Normal View History

use serde::{Deserialize, Serialize};
use crate::v1::{common, constants};
// -----------------------------------------------------------------------------
// Request
/// Optional settings that can be supplied when building an
/// [`EmbeddingRequest`] through [`EmbeddingRequest::new`].
///
/// Every field is optional. The [`Default`] value leaves all of them as
/// `None`, so struct-update syntax can be used to set only the fields of
/// interest:
/// `EmbeddingRequestOptions { output_dimension: Some(256), ..Default::default() }`.
// `Default` is derived rather than hand-written: for a struct made only of
// `Option` fields the derived impl yields `None` everywhere, which is exactly
// what the previous manual impl did.
#[derive(Debug, Default)]
pub struct EmbeddingRequestOptions {
    /// Encoding format to request for the embeddings, if any.
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
    /// Output dimension to request for the embeddings, if any.
    pub output_dimension: Option<u32>,
    /// Output data type to request for the embeddings, if any.
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
/// Serializable description of an embedding request.
///
/// The three optional fields are left out of the serialized output entirely
/// when they are `None` (see the `skip_serializing_if` attributes).
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingRequest {
    /// Model the request is addressed to.
    pub model: constants::Model,

    /// Input strings to embed.
    pub input: Vec<String>,

    /// Requested encoding format; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,

    /// Requested output dimension; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_dimension: Option<u32>,

    /// Requested output data type; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
impl EmbeddingRequest {
    /// Builds a request for `model` over the given `input` strings.
    ///
    /// `options` carries the optional request settings. Passing `None` is
    /// equivalent to passing `Some(EmbeddingRequestOptions::default())`.
    /// Each option field is moved unchanged into the field of the same name
    /// on the returned request.
    pub fn new(
        model: constants::Model,
        input: Vec<String>,
        options: Option<EmbeddingRequestOptions>,
    ) -> Self {
        // Fall back to the default option set when the caller supplied none.
        let resolved = options.unwrap_or_default();

        let encoding_format = resolved.encoding_format;
        let output_dimension = resolved.output_dimension;
        let output_dtype = resolved.output_dtype;

        Self {
            model,
            input,
            encoding_format,
            output_dimension,
            output_dtype,
        }
    }
}
/// Encoding format that can be requested for embeddings.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"float"`, `"base64"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat {
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"base64"`.
    Base64,
}
/// Output data type that can be requested for embeddings.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"float"`, `"int8"`, `"uint8"`, `"binary"`, `"ubinary"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"int8"`.
    Int8,
    /// Serialized as `"uint8"`.
    Uint8,
    /// Serialized as `"binary"`.
    Binary,
    /// Serialized as `"ubinary"`.
    Ubinary,
}
// -----------------------------------------------------------------------------
// Response
/// Deserializable representation of an embedding response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Value of the response's `object` field.
    pub object: String,

    /// Model named in the response.
    pub model: constants::Model,

    /// Embedding items carried by the response.
    pub data: Vec<EmbeddingResponseDataItem>,

    /// Usage information attached to the response.
    pub usage: common::ResponseUsage,
}
/// One entry of [`EmbeddingResponse::data`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponseDataItem {
    /// Index reported for this item.
    // Presumably the position of the matching input string; confirm against
    // the API reference.
    pub index: u32,

    /// Embedding values, one `f32` per component.
    // NOTE(review): the request types allow a `Base64` encoding format and
    // non-float output dtypes. Presumably such responses do not carry a JSON
    // array of floats and would fail to deserialize into `Vec<f32>`; confirm
    // against the API before relying on those options.
    pub embedding: Vec<f32>,

    /// Value of the item's `object` field.
    pub object: String,
}