From 7d3b438d16e9936591b6454525968c5c2cdfd6ad Mon Sep 17 00:00:00 2001 From: Ivan Gabriele Date: Sun, 3 Mar 2024 15:20:30 +0100 Subject: [PATCH] feat: add chat completion without streaming --- .editorconfig | 33 ++++ .env.example | 3 + .github/ISSUE_TEMPLATE/bug_report.md | 35 ++++ .github/ISSUE_TEMPLATE/feature_request.md | 24 +++ .github/pull_request_template.md | 8 + .github/renovate.json | 3 + .github/workflows/test.yml | 26 +++ .gitignore | 25 +++ CODE_OF_CONDUCT.md | 132 ++++++++++++++ CONTRIBUTING.md | 55 ++++++ Cargo.toml | 17 ++ LICENSE.md | 201 ++++++++++++++++++++++ Makefile | 6 + README.md | 107 ++++++++++++ SECURITY.md | 9 + src/lib.rs | 1 + src/v1/chat_completion.rs | 113 ++++++++++++ src/v1/client.rs | 175 +++++++++++++++++++ src/v1/common.rs | 8 + src/v1/constants.rs | 7 + src/v1/error.rs | 15 ++ src/v1/mod.rs | 5 + tests/v1_chat_completion_test.rs | 53 ++++++ tests/v1_client_test.rs | 52 ++++++ 24 files changed, 1113 insertions(+) create mode 100644 .editorconfig create mode 100644 .env.example create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/pull_request_template.md create mode 100644 .github/renovate.json create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 Cargo.toml create mode 100644 LICENSE.md create mode 100644 Makefile create mode 100644 README.md create mode 100644 SECURITY.md create mode 100644 src/lib.rs create mode 100644 src/v1/chat_completion.rs create mode 100644 src/v1/client.rs create mode 100644 src/v1/common.rs create mode 100644 src/v1/constants.rs create mode 100644 src/v1/error.rs create mode 100644 src/v1/mod.rs create mode 100644 tests/v1_chat_completion_test.rs create mode 100644 tests/v1_client_test.rs diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..750c5f1 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,33 @@ +# https://editorconfig.org +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 2 +indent_style = space +insert_final_newline = true +max_line_length = 120 +trim_trailing_whitespace = true + +[*.md] +max_line_length = 0 +trim_trailing_whitespace = false + +[*.py] +indent_size = 4 + +[*.rs] +indent_size = 4 +max_line_length = 80 + +[*.xml] +trim_trailing_whitespace = false + +[COMMIT_EDITMSG] +max_line_length = 0 + +[Makefile] +indent_size = 8 +indent_style = tab +max_line_length = 80 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..385fddf --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +# This key is only used for development purposes. +# You'll only need one if you want to contribute to this library. +MISTRAL_API_KEY= diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..d25542c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,35 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** + +A clear and concise description of what the bug is. + +**To Reproduce** + +Steps to reproduce the behavior: + +1. ... +2. ... + +**Expected behavior** + +A clear and concise description of what you expected to happen. + +**Screenshots** + +If applicable, add screenshots to help explain your problem. + +**Version** + +If applicable, what version did you use? + +**Environment** + +Add useful information about your configuration and environment here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..f0291e0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,24 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** + +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** + +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** + +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** + +Add any other context or screenshots about the feature request here. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..6e32d77 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,8 @@ +## Description + +A clear and concise description of what your pull request is about. + +## Checklist + +- [ ] I updated the documentation accordingly. Or I don't need to. +- [ ] I updated the tests accordingly. Or I don't need to. diff --git a/.github/renovate.json b/.github/renovate.json new file mode 100644 index 0000000..6a36a09 --- /dev/null +++ b/.github/renovate.json @@ -0,0 +1,3 @@ +{ + "extends": ["github>ivangabriele/renovate-config"] +} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..12feca2 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: Test + +on: push + +jobs: + test: + name: Test + runs-on: ubuntu-latest + container: + image: xd009642/tarpaulin + # https://github.com/xd009642/tarpaulin#github-actions + options: --security-opt seccomp=unconfined + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: 1.76.0 + - name: Run tests (with coverage) + run: make test-cover + - name: Upload tests coverage + uses: codecov/codecov-action@v3 + with: + fail_ci_if_error: true + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..be294f9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +######################################## +# Rust + +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +######################################## +# Custom + +# Tarpaulin coverage output +/cobertura.xml + +.env diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..b8a3c88 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,132 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +[INSERT CONTACT METHOD]. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available +at [https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..f5d34d9 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,55 @@ +# Contribute + +- [Getting Started](#getting-started) + - [Requirements](#requirements) + - [First setup](#first-setup) + - [Optional requirements](#optional-requirements) + - [Test](#test) +- [Code of Conduct](#code-of-conduct) +- [Commit Message Format](#commit-message-format) + +--- + +## Getting Started + +### Requirements + +- [Rust](https://www.rust-lang.org/tools/install): v1 + +### First setup + +> [!IMPORTANT] +> If you're under **Windows**, you nust run all CLI commands under a Linux shell-like terminal (i.e.: WSL or Git Bash). + +Then run: + +```sh +git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork +cd ./mistralai-client-rs +cargo build +``` + +### Optional requirements + +- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-*-watch`. + +### Test + +```sh +make test +``` + +or + +```sh +make test-watch +``` + +## Code of Conduct + +Help us keep this project open and inclusive. Please read and follow our [Code of Conduct](./CODE_OF_CONDUCT.md). + +## Commit Message Format + +This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification and +specificaly the [Angular Commit Message Guidelines](https://github.com/angular/angular/blob/main/CONTRIBUTING.md#commit). diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..3a88879 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "mistralai-client" +description = "Mistral AI API client library for Rust (unofficial)." +license = "Apache-2.0" +version = "0.0.0" +edition = "2021" + +[dependencies] +minreq = { version = "2.11.0", features = ["https-rustls", "json-using-serde"] } +serde = { version = "1.0.197", features = ["derive"] } +serde_json = "1.0.114" +thiserror = "1.0.57" +tokio = { version = "1.36.0", features = ["full"] } + +[dev-dependencies] +dotenv = "0.15.0" +jrest = "0.2.3" diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..09b0d19 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +test: + cargo test --no-fail-fast +test-cover: + cargo tarpaulin --frozen --no-fail-fast --out Xml --skip-clean +test-watch: + cargo watch -x "test -- --nocapture" diff --git a/README.md b/README.md new file mode 100644 index 0000000..d6da208 --- /dev/null +++ b/README.md @@ -0,0 +1,107 @@ +# Mistral AI Rust Client + +[![Crates.io Package](https://img.shields.io/crates/v/mistralai-client?style=for-the-badge)](https://crates.io/crates/mistralai-client-rs) +[![Docs.rs Documentation](https://img.shields.io/docsrs/mistralai-client-rs/latest?style=for-the-badge)](https://docs.rs/mistralai-client-rs/latest/mistralai-client-rs) +[![Test Workflow Status](https://img.shields.io/github/actions/workflow/status/ivangabriele/mistralai-client-rs/test.yml?label=CI&style=for-the-badge)](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++) +[![Code Coverage](https://img.shields.io/codecov/c/github/ivangabriele/mistralai-client-rs/main?label=Cov&style=for-the-badge)](https://app.codecov.io/github/ivangabriele/mistralai-client-rs) + +Rust client for the Mistral AI API. + +--- + +- [Supported APIs](#supported-apis) +- [Installation](#installation) + - [Mistral API Key](#mistral-api-key) + - [As an environment variable](#as-an-environment-variable) + - [As a client argument](#as-a-client-argument) +- [Usage](#usage) + - [Chat without streaming](#chat-without-streaming) + - [Chat with streaming](#chat-with-streaming) + - [Embeddings](#embeddings) + - [List models](#list-models) + +--- + +## Supported APIs + +- [x] Chat without streaming +- [ ] Chat with streaming +- [ ] Embedding +- [ ] List models +- [ ] Function Calling + +## Installation + +You can install the library in your project using: + +```sh +cargo add mistralai-client +``` + +### Mistral API Key + +You can get your Mistral API Key there: . + +#### As an environment variable + +Just set the `MISTRAL_API_KEY` environment variable. + +#### As a client argument + +```rs +use mistralai_client::v1::client::Client; + +fn main() { + let api_key = "your_api_key"; + + let client = Client::new(Some(api_key), None, None, None); +} +``` + +## Usage + +### Chat without streaming + +```rs +use mistralai::v1::{ + chat_completion::{ + ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequest, + ChatCompletionRequestOptions, + }, + client::Client, + constants::OPEN_MISTRAL_7B, +}; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None); + + let model = OPEN_MISTRAL_7B.to_string(); + let messages = vec![ChatCompletionMessage { + role: ChatCompletionMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionRequestOptions { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let chat_completion_request = ChatCompletionRequest::new(model, messages, Some(options)); + let result = client.chat(chat_completion_request).unwrap(); + println!("Assistant: {}", result.choices[0].message.content); + // => "Assistant: Tower. [...]" +} +``` + +### Chat with streaming + +_In progress._ + +### Embeddings + +_In progress._ + +### List models + +_In progress._ diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..eba3ab2 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,9 @@ +# Security Policy + +## Supported Versions + +We only support the latest version of this project. + +## Reporting a Vulnerability + +You can report a vulnerability by opening an issue on this repository. diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a3a6d96 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod v1; diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs new file mode 100644 index 0000000..25a2467 --- /dev/null +++ b/src/v1/chat_completion.rs @@ -0,0 +1,113 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::common; + +#[derive(Debug)] +pub struct ChatCompletionRequestOptions { + pub tools: Option, + pub temperature: Option, + pub max_tokens: Option, + pub top_p: Option, + pub random_seed: Option, + pub stream: Option, + pub safe_prompt: Option, +} +impl Default for ChatCompletionRequestOptions { + fn default() -> Self { + Self { + tools: None, + temperature: None, + max_tokens: None, + top_p: None, + random_seed: None, + stream: None, + safe_prompt: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatCompletionRequest { + pub messages: Vec, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub safe_prompt: Option, + // TODO Check that prop (seen in official Python client but not in API doc). + // pub tool_choice: Option, + // TODO Check that prop (seen in official Python client but not in API doc). + // pub response_format: Option, +} +impl ChatCompletionRequest { + pub fn new( + model: String, + messages: Vec, + options: Option, + ) -> Self { + let ChatCompletionRequestOptions { + tools, + temperature, + max_tokens, + top_p, + random_seed, + stream, + safe_prompt, + } = options.unwrap_or_default(); + + Self { + messages, + model, + tools, + temperature, + max_tokens, + top_p, + random_seed, + stream, + safe_prompt, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + /// Unix timestamp (in seconds). + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: common::ResponseUsage, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatCompletionChoice { + pub index: u32, + pub message: ChatCompletionMessage, + pub finish_reason: String, + // TODO Check that prop (seen in API responses but undocumented). + // pub logprobs: ??? +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatCompletionMessage { + pub role: ChatCompletionMessageRole, + pub content: String, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[allow(non_camel_case_types)] +pub enum ChatCompletionMessageRole { + assistant, + user, +} diff --git a/src/v1/client.rs b/src/v1/client.rs new file mode 100644 index 0000000..293cef9 --- /dev/null +++ b/src/v1/client.rs @@ -0,0 +1,175 @@ +use crate::v1::error::APIError; +use minreq::Response; + +use crate::v1::{ + chat_completion::{ChatCompletionRequest, ChatCompletionResponse}, + constants::API_URL_BASE, +}; + +pub struct Client { + pub api_key: String, + pub endpoint: String, + pub max_retries: u32, + pub timeout: u32, +} + +impl Client { + pub fn new( + api_key: Option, + endpoint: Option, + max_retries: Option, + timeout: Option, + ) -> Self { + let api_key = api_key.unwrap_or(std::env::var("MISTRAL_API_KEY").unwrap()); + let endpoint = endpoint.unwrap_or(API_URL_BASE.to_string()); + let max_retries = max_retries.unwrap_or(5); + let timeout = timeout.unwrap_or(120); + + Self { + api_key, + endpoint, + max_retries, + timeout, + } + } + + pub fn build_request(&self, request: minreq::Request) -> minreq::Request { + let authorization = format!("Bearer {}", self.api_key); + let user_agent = format!( + "ivangabriele/mistral-client-rs/{}", + env!("CARGO_PKG_VERSION") + ); + + let request = request + .with_header("Authorization", authorization) + .with_header("Accept", "application/json") + .with_header("Content-Type", "application/json") + .with_header("User-Agent", user_agent); + + request + } + + pub fn get(&self, path: &str) -> Result { + let url = format!("{}{}", self.endpoint, path); + let request = self.build_request(minreq::post(url)); + + let result = request.send(); + match result { + Ok(res) => { + if (200..=299).contains(&res.status_code) { + Ok(res) + } else { + Err(APIError { + message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + }) + } + } + Err(e) => Err(self.new_error(e)), + } + } + + pub fn post( + &self, + path: &str, + params: &T, + ) -> Result { + // print!("{:?}", params); + + let url = format!("{}{}", self.endpoint, path); + let request = self.build_request(minreq::post(url)); + + let result = request.with_json(params).unwrap().send(); + match result { + Ok(res) => { + print!("{:?}", res.as_str().unwrap()); + + if (200..=299).contains(&res.status_code) { + Ok(res) + } else { + Err(APIError { + message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + }) + } + } + Err(e) => Err(self.new_error(e)), + } + } + + pub fn delete(&self, path: &str) -> Result { + let url = format!("{}{}", self.endpoint, path); + let request = self.build_request(minreq::post(url)); + + let result = request.send(); + match result { + Ok(res) => { + if (200..=299).contains(&res.status_code) { + Ok(res) + } else { + Err(APIError { + message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + }) + } + } + Err(e) => Err(self.new_error(e)), + } + } + + // pub fn completion(&self, req: CompletionRequest) -> Result { + // let res = self.post("/completions", &req)?; + // let r = res.json::(); + // match r { + // Ok(r) => Ok(r), + // Err(e) => Err(self.new_error(e)), + // } + // } + + // pub fn embedding(&self, req: EmbeddingRequest) -> Result { + // let res = self.post("/embeddings", &req)?; + // let r = res.json::(); + // match r { + // Ok(r) => Ok(r), + // Err(e) => Err(self.new_error(e)), + // } + // } + + pub fn chat(&self, request: ChatCompletionRequest) -> Result { + let response = self.post("/chat/completions", &request)?; + let result = response.json::(); + match result { + Ok(r) => Ok(r), + Err(e) => Err(self.new_error(e)), + } + } + + fn new_error(&self, err: minreq::Error) -> APIError { + APIError { + message: err.to_string(), + } + } + + // fn query_params( + // limit: Option, + // order: Option, + // after: Option, + // before: Option, + // mut url: String, + // ) -> String { + // let mut params = vec![]; + // if let Some(limit) = limit { + // params.push(format!("limit={}", limit)); + // } + // if let Some(order) = order { + // params.push(format!("order={}", order)); + // } + // if let Some(after) = after { + // params.push(format!("after={}", after)); + // } + // if let Some(before) = before { + // params.push(format!("before={}", before)); + // } + // if !params.is_empty() { + // url = format!("{}?{}", url, params.join("&")); + // } + // url + // } +} diff --git a/src/v1/common.rs b/src/v1/common.rs new file mode 100644 index 0000000..160073c --- /dev/null +++ b/src/v1/common.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ResponseUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} diff --git a/src/v1/constants.rs b/src/v1/constants.rs new file mode 100644 index 0000000..1cd4c85 --- /dev/null +++ b/src/v1/constants.rs @@ -0,0 +1,7 @@ +pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; + +pub const OPEN_MISTRAL_7B: &str = "open-mistral-7b"; +pub const OPEN_MISTRAL_8X7B: &str = "open-mixtral-8x7b"; +pub const MISTRAL_SMALL_LATEST: &str = "mistral-small-latest"; +pub const MISTRAL_MEDIUM_LATEST: &str = "mistral-medium-latest"; +pub const MISTRAL_LARGE_LATEST: &str = "mistral-large-latest"; diff --git a/src/v1/error.rs b/src/v1/error.rs new file mode 100644 index 0000000..b2d2bf0 --- /dev/null +++ b/src/v1/error.rs @@ -0,0 +1,15 @@ +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub struct APIError { + pub message: String, +} + +impl fmt::Display for APIError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "APIError: {}", self.message) + } +} + +impl Error for APIError {} diff --git a/src/v1/mod.rs b/src/v1/mod.rs new file mode 100644 index 0000000..c7032e1 --- /dev/null +++ b/src/v1/mod.rs @@ -0,0 +1,5 @@ +pub mod chat_completion; +pub mod client; +pub mod common; +pub mod constants; +pub mod error; diff --git a/tests/v1_chat_completion_test.rs b/tests/v1_chat_completion_test.rs new file mode 100644 index 0000000..72d09fd --- /dev/null +++ b/tests/v1_chat_completion_test.rs @@ -0,0 +1,53 @@ +use jrest::expect; +use mistralai_client::v1::{ + chat_completion::{ + ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequest, + ChatCompletionRequestOptions, + }, + client::Client, + constants::OPEN_MISTRAL_7B, +}; + +#[test] +fn test_client_new() { + extern crate dotenv; + + use dotenv::dotenv; + dotenv().ok(); + + let client = Client::new(None, None, None, None); + + let model = OPEN_MISTRAL_7B.to_string(); + let messages = vec![ChatCompletionMessage { + role: ChatCompletionMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionRequestOptions { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let chat_completion_request = ChatCompletionRequest::new(model, messages, Some(options)); + let result = client.chat(chat_completion_request); + + match result { + Ok(res) => { + expect!(res.model).to_be("open-mistral-7b".to_string()); + expect!(res.object).to_be("chat.completion".to_string()); + expect!(res.choices.len()).to_be(1); + expect!(res.choices[0].index).to_be(0); + expect!(res.choices[0].message.role.clone()) + .to_be(ChatCompletionMessageRole::assistant); + expect!(res.choices[0].message.content.clone()).to_be( + "Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string(), + ); + expect!(res.usage.prompt_tokens).to_be_greater_than(0); + expect!(res.usage.completion_tokens).to_be_greater_than(0); + expect!(res.usage.total_tokens).to_be_greater_than(21); + } + Err(err) => { + panic!("Error: {}", err); + } + } +} diff --git a/tests/v1_client_test.rs b/tests/v1_client_test.rs new file mode 100644 index 0000000..91a0c39 --- /dev/null +++ b/tests/v1_client_test.rs @@ -0,0 +1,52 @@ +use jrest::expect; +use mistralai_client::v1::client::Client; + +#[test] +fn test_client_new_with_none_params() { + let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); + + let client = Client::new(None, None, None, None); + + expect!(client.api_key).to_be("test_api_key_from_env".to_string()); + expect!(client.endpoint).to_be("https://api.mistral.ai/v1".to_string()); + expect!(client.max_retries).to_be(5); + expect!(client.timeout).to_be(120); + + match maybe_original_mistral_api_key { + Some(original_mistral_api_key) => { + std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) + } + None => std::env::remove_var("MISTRAL_API_KEY"), + } +} + +#[test] +fn test_client_new_with_all_params() { + let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); + std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); + + let api_key = Some("test_api_key_from_param".to_string()); + let endpoint = Some("https://example.org".to_string()); + let max_retries = Some(10); + let timeout = Some(20); + + let client = Client::new( + api_key.clone(), + endpoint.clone(), + max_retries.clone(), + timeout.clone(), + ); + + expect!(client.api_key).to_be(api_key.unwrap()); + expect!(client.endpoint).to_be(endpoint.unwrap()); + expect!(client.max_retries).to_be(max_retries.unwrap()); + expect!(client.timeout).to_be(timeout.unwrap()); + + match maybe_original_mistral_api_key { + Some(original_mistral_api_key) => { + std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) + } + None => std::env::remove_var("MISTRAL_API_KEY"), + } +}