chore: checkpoint before Python removal

This commit is contained in:
2026-03-26 22:33:59 +00:00
parent 683cec9307
commit e568ddf82a
29972 changed files with 11269302 additions and 2 deletions

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,6 @@
{
"git": {
"sha1": "2c315aa7f9c2a6c1db87f8f51f40623a427c78fd"
},
"path_in_vcs": "quinn-proto"
}

1691
vendor/quinn-proto/Cargo.lock generated vendored Normal file

File diff suppressed because it is too large Load Diff

194
vendor/quinn-proto/Cargo.toml vendored Normal file
View File

@@ -0,0 +1,194 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2021"
rust-version = "1.74.1"
name = "quinn-proto"
version = "0.11.14"
build = false
autolib = false
autobins = false
autoexamples = false
autotests = false
autobenches = false
description = "State machine for the QUIC transport protocol"
readme = false
keywords = ["quic"]
categories = [
"network-programming",
"asynchronous",
]
license = "MIT OR Apache-2.0"
repository = "https://github.com/quinn-rs/quinn"
[package.metadata.docs.rs]
features = [
"rustls-aws-lc-rs",
"rustls-ring",
"platform-verifier",
"log",
"rustls-log",
]
[features]
__rustls-post-quantum-test = []
aws-lc-rs = [
"dep:aws-lc-rs",
"aws-lc-rs?/aws-lc-sys",
"aws-lc-rs?/prebuilt-nasm",
]
aws-lc-rs-fips = [
"aws-lc-rs",
"aws-lc-rs?/fips",
]
bloom = ["dep:fastbloom"]
default = [
"rustls-ring",
"log",
"bloom",
]
log = ["tracing/log"]
platform-verifier = ["dep:rustls-platform-verifier"]
qlog = ["dep:qlog"]
ring = ["dep:ring"]
rustls = ["rustls-ring"]
rustls-aws-lc-rs = [
"dep:rustls",
"rustls?/aws-lc-rs",
"aws-lc-rs",
]
rustls-aws-lc-rs-fips = [
"rustls-aws-lc-rs",
"aws-lc-rs-fips",
]
rustls-log = ["rustls?/logging"]
rustls-ring = [
"dep:rustls",
"rustls?/ring",
"ring",
]
[lib]
name = "quinn_proto"
path = "src/lib.rs"
[dependencies.arbitrary]
version = "1.0.1"
features = ["derive"]
optional = true
[dependencies.aws-lc-rs]
version = "1.9"
optional = true
default-features = false
[dependencies.bytes]
version = "1"
[dependencies.fastbloom]
version = "0.14"
optional = true
[dependencies.lru-slab]
version = "0.1.2"
[dependencies.qlog]
version = "0.15.2"
optional = true
[dependencies.rand]
version = "0.9"
[dependencies.ring]
version = "0.17"
optional = true
[dependencies.rustc-hash]
version = "2"
[dependencies.rustls]
version = "0.23.5"
features = ["std"]
optional = true
default-features = false
[dependencies.rustls-platform-verifier]
version = "0.6"
optional = true
[dependencies.slab]
version = "0.4.6"
[dependencies.thiserror]
version = "2.0.3"
[dependencies.tinyvec]
version = "1.1"
features = ["alloc"]
[dependencies.tracing]
version = "0.1.10"
features = ["std"]
default-features = false
[dev-dependencies.assert_matches]
version = "1.1"
[dev-dependencies.hex-literal]
version = "1"
[dev-dependencies.lazy_static]
version = "1"
[dev-dependencies.rand_pcg]
version = "0.9"
[dev-dependencies.rcgen]
version = "0.14"
[dev-dependencies.tracing-subscriber]
version = "0.3.0"
features = [
"env-filter",
"fmt",
"ansi",
"time",
"local-time",
]
default-features = false
[dev-dependencies.wasm-bindgen-test]
version = "0.3.45"
[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.getrandom]
version = "0.3"
features = ["wasm_js"]
default-features = false
[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.ring]
version = "0.17"
features = ["wasm32_unknown_unknown_js"]
[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.rustls-pki-types]
version = "1.7"
features = ["web"]
[target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.web-time]
version = "1"
[lints.rust.unexpected_cfgs]
level = "warn"
priority = 0
check-cfg = ["cfg(fuzzing)"]

201
vendor/quinn-proto/LICENSE-APACHE vendored Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

7
vendor/quinn-proto/LICENSE-MIT vendored Normal file
View File

@@ -0,0 +1,7 @@
Copyright (c) 2018 The quinn Developers
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,368 @@
use std::{
collections::HashSet,
f64::consts::LN_2,
hash::{BuildHasher, Hasher},
mem::{size_of, take},
sync::Mutex,
};
use fastbloom::BloomFilter;
use rustc_hash::FxBuildHasher;
use tracing::{trace, warn};
use crate::{Duration, SystemTime, TokenLog, TokenReuseError, UNIX_EPOCH};
/// Bloom filter-based [`TokenLog`]
///
/// Parameterizable over an approximate maximum number of bytes to allocate. Starts out by storing
/// used tokens in a hash set. Once the hash set becomes too large, converts it to a bloom filter.
/// This achieves a memory profile of linear growth with an upper bound.
///
/// Divides time into periods based on `lifetime` and stores two filters at any given moment, for
/// each of the two periods currently non-expired tokens could expire in. As such, turns over
/// filters as time goes on to avoid bloom filter false positive rate increasing infinitely over
/// time.
// All mutable state lives behind a single Mutex so `check_and_insert` can take `&self`.
pub struct BloomTokenLog(Mutex<State>);
impl BloomTokenLog {
    /// Construct with an approximate maximum memory usage and expected number of validation token
    /// usages per expiration period
    ///
    /// Calculates the optimal bloom filter k number automatically.
    pub fn new_expected_items(max_bytes: usize, expected_hits: u64) -> Self {
        let k_num = optimal_k_num(max_bytes, expected_hits);
        Self::new(max_bytes, k_num)
    }

    /// Construct with an approximate maximum memory usage and a [bloom filter k number][bloom]
    ///
    /// [bloom]: https://en.wikipedia.org/wiki/Bloom_filter
    ///
    /// If choosing a custom k number, note that `BloomTokenLog` always maintains two filters
    /// between them and divides the allocation budget of `max_bytes` evenly between them. As such,
    /// each bloom filter will contain `max_bytes * 4` bits.
    pub fn new(max_bytes: usize, k_num: u32) -> Self {
        // Each of the two period filters gets half of the overall byte budget.
        let config = FilterConfig {
            filter_max_bytes: max_bytes / 2,
            k_num,
        };
        let state = State {
            config,
            period_1_start: UNIX_EPOCH,
            filter_1: Filter::default(),
            filter_2: Filter::default(),
        };
        Self(Mutex::new(state))
    }
}
impl TokenLog for BloomTokenLog {
    /// Record a token's nonce, erroring if it may already have been used
    ///
    /// Buckets the token into one of the two period filters by its expiration time, lazily
    /// turning filters over when expirations move past the currently-tracked periods.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError> {
        trace!(%nonce, "check_and_insert");
        if lifetime.is_zero() {
            // avoid divide-by-zero if lifetime is zero
            return Err(TokenReuseError);
        }

        let mut guard = self.0.lock().unwrap();
        let state = &mut *guard;

        // calculate how many periods past period 1 the token expires
        let expires_at = issued + lifetime;
        let Ok(periods_forward) = expires_at
            .duration_since(state.period_1_start)
            .map(|duration| duration.as_nanos() / lifetime.as_nanos())
        else {
            // shouldn't happen unless time travels backwards or lifetime changes or the current
            // system time is before the Unix epoch
            warn!("BloomTokenLog presented with token too far in past");
            return Err(TokenReuseError);
        };

        // get relevant filter
        let filter = match periods_forward {
            // expires within one of the two periods already being tracked
            0 => &mut state.filter_1,
            1 => &mut state.filter_2,
            2 => {
                // turn over filter 1
                state.filter_1 = take(&mut state.filter_2);
                state.period_1_start += lifetime;
                &mut state.filter_2
            }
            _ => {
                // turn over both filters
                state.filter_1 = Filter::default();
                state.filter_2 = Filter::default();
                state.period_1_start = expires_at;
                &mut state.filter_1
            }
        };

        // insert into the filter
        //
        // the token's nonce needs to guarantee uniqueness because of the role it plays in the
        // encryption of the tokens, so it is 128 bits. but since the token log can tolerate false
        // positives, we trim it down to 64 bits, which would still only have a small collision
        // rate even at significant amounts of usage, while allowing us to store twice as many in
        // the hash set variant.
        //
        // token nonce values are uniformly randomly generated server-side and cryptographically
        // integrity-checked, so we don't need to employ secure hashing to trim it down to 64 bits,
        // we can simply truncate.
        //
        // per the Rust reference, we can truncate by simply casting:
        // https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#numeric-cast
        filter.check_and_insert(nonce as u64, &state.config)
    }
}
/// Default to 10 MiB max memory consumption and expected one million hits
///
/// (`DEFAULT_MAX_BYTES` is `10 << 20` bytes, i.e. 10 MiB.)
///
/// With the default validation token lifetime of 2 weeks, this corresponds to one token usage per
/// 1.21 seconds.
impl Default for BloomTokenLog {
    fn default() -> Self {
        Self::new_expected_items(DEFAULT_MAX_BYTES, DEFAULT_EXPECTED_HITS)
    }
}
/// Lockable state of [`BloomTokenLog`]
struct State {
    /// Fixed filter parameters shared by both period filters
    config: FilterConfig,
    // filter_1 covers tokens that expire in the period starting at period_1_start and extending
    // lifetime after. filter_2 covers tokens for the next lifetime after that.
    period_1_start: SystemTime,
    filter_1: Filter,
    filter_2: Filter,
}
/// Unchanging parameters governing [`Filter`] behavior
struct FilterConfig {
    /// Approximate byte budget for one filter (half of the log's overall `max_bytes`)
    filter_max_bytes: usize,
    /// Bloom filter k number (hash count) used once the set stage spills into a bloom filter
    k_num: u32,
}
/// Period filter within [`State`]
enum Filter {
    /// Exact stage: fingerprints stored verbatim (no false positives) until the byte budget is hit
    Set(HashSet<u64, IdentityBuildHasher>),
    /// Approximate stage: bounded-memory bloom filter, may yield false positives
    Bloom(BloomFilter<FxBuildHasher>),
}
impl Filter {
    /// Record `fingerprint`, erroring if it may already be present
    ///
    /// While in the set stage, converts to the bloom stage once the set's allocation exceeds
    /// `config.filter_max_bytes`.
    fn check_and_insert(
        &mut self,
        fingerprint: u64,
        config: &FilterConfig,
    ) -> Result<(), TokenReuseError> {
        match self {
            Self::Bloom(bloom) => {
                // fastbloom's `insert` returns whether the element may already have been present
                match bloom.insert(&fingerprint) {
                    true => Err(TokenReuseError),
                    false => Ok(()),
                }
            }
            Self::Set(hset) => {
                let newly_added = hset.insert(fingerprint);
                if !newly_added {
                    return Err(TokenReuseError);
                }
                let set_bytes = hset.capacity() * size_of::<u64>();
                if set_bytes > config.filter_max_bytes {
                    // convert to bloom
                    // avoid panicking if user passed in filter_max_bytes of 0. we document that
                    // this limit is approximate, so just fudge it up to 1.
                    let num_bits = (config.filter_max_bytes * 8).max(1);
                    let mut bloom = BloomFilter::with_num_bits(num_bits)
                        .hasher(FxBuildHasher)
                        .hashes(config.k_num);
                    for item in hset.iter() {
                        bloom.insert(item);
                    }
                    *self = Self::Bloom(bloom);
                }
                Ok(())
            }
        }
    }
}
impl Default for Filter {
fn default() -> Self {
Self::Set(HashSet::default())
}
}
/// `BuildHasher` of `IdentityHasher`
#[derive(Default)]
struct IdentityBuildHasher;
impl BuildHasher for IdentityBuildHasher {
type Hasher = IdentityHasher;
fn build_hasher(&self) -> Self::Hasher {
IdentityHasher::default()
}
}
/// Hasher that is the identity operation--it assumes that exactly 8 bytes will be hashed, and the
/// resultant hash is those bytes as a `u64`
#[derive(Default)]
struct IdentityHasher {
    /// The single 8-byte value written so far, reinterpreted as a native-endian `u64`
    hash: u64,
    #[cfg(debug_assertions)]
    wrote_8_byte_slice: bool,
}

impl Hasher for IdentityHasher {
    fn write(&mut self, bytes: &[u8]) {
        // In debug builds, enforce the contract: exactly one write, of exactly 8 bytes.
        #[cfg(debug_assertions)]
        {
            assert!(!self.wrote_8_byte_slice);
            assert_eq!(bytes.len(), 8);
            self.wrote_8_byte_slice = true;
        }
        // Panics (like the original `copy_from_slice`) if `bytes` is not exactly 8 long.
        let mut buf = [0u8; 8];
        buf.copy_from_slice(bytes);
        self.hash = u64::from_ne_bytes(buf);
    }

    fn finish(&self) -> u64 {
        #[cfg(debug_assertions)]
        assert!(self.wrote_8_byte_slice);
        self.hash
    }
}
/// Compute the optimal bloom filter k number for `num_bytes` of storage and `expected_hits`
/// inserted elements
fn optimal_k_num(num_bytes: usize, expected_hits: u64) -> u32 {
    // be more forgiving rather than panickey here. excessively high num_bits may occur if the user
    // wishes it to be unbounded, so just saturate. expected_hits of 0 would cause divide-by-zero,
    // so just fudge it up to 1 in that case.
    let num_bits = (num_bytes as u64).saturating_mul(8) as f64;
    let expected_hits = expected_hits.max(1) as f64;

    // reference for this formula: https://programming.guide/bloom-filter-calculator.html
    // optimal k = (m ln 2) / n
    // wherein m is the number of bits, and n is the number of elements in the set.
    let k = (num_bits / expected_hits * LN_2).round() as u32;

    // impose a minimum return value of 1, to avoid making the bloom filter entirely useless in
    // the case that the user provided an absurdly high ratio of hits / bytes.
    k.max(1)
}
// remember to change the doc comment for `impl Default for BloomTokenLog` if these ever change
/// Default overall byte budget: 10 MiB (split evenly between the two period filters)
const DEFAULT_MAX_BYTES: usize = 10 << 20;
/// Default expected token usages per expiration period, used to derive the bloom k number
const DEFAULT_EXPECTED_HITS: u64 = 1_000_000;
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;

    /// Deterministic RNG so the tests are reproducible across runs
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeef_u128.to_le_bytes())
    }

    #[test]
    fn identity_hash_test() {
        // The identity hasher must map every u64 to itself.
        let mut rng = new_rng();
        let builder = IdentityBuildHasher;
        for _ in 0..100 {
            let n = rng.random::<u64>();
            let hash = builder.hash_one(n);
            assert_eq!(hash, n);
        }
    }

    #[test]
    fn optimal_k_num_test() {
        assert_eq!(optimal_k_num(10 << 20, 1_000_000), 58);
        assert_eq!(optimal_k_num(10 << 20, 1_000_000_000_000_000), 1);
        // assert that these don't panic:
        optimal_k_num(10 << 20, 0);
        optimal_k_num(usize::MAX, 1_000_000);
    }

    #[test]
    fn bloom_token_log_conversion() {
        // Insert enough tokens that the exact set stage must spill into the bloom stage.
        let mut rng = new_rng();
        let mut log = BloomTokenLog::new_expected_items(800, 200);
        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);
        for i in 0..200 {
            let token = rng.random::<u128>();
            let result = log.check_and_insert(token, issued, lifetime);
            {
                let filter = &log.0.lock().unwrap().filter_1;
                if let Filter::Set(ref hset) = *filter {
                    // While still in the set stage, behavior must be exact: within budget,
                    // one new element per insert, and no false positives.
                    assert!(hset.capacity() * size_of::<u64>() <= 800);
                    assert_eq!(hset.len(), i + 1);
                    assert!(result.is_ok());
                } else {
                    assert!(i > 10, "definitely bloomed too early");
                }
            }
            // Re-inserting the same token must always be rejected.
            assert!(log.check_and_insert(token, issued, lifetime).is_err());
        }
        assert!(
            matches!(log.0.get_mut().unwrap().filter_1, Filter::Bloom { .. }),
            "didn't bloom"
        );
    }

    #[test]
    fn turn_over() {
        let mut rng = new_rng();
        let log = BloomTokenLog::new_expected_items(800, 200);
        let lifetime = Duration::from_secs(1_000);
        let mut old = Vec::default();
        let mut accepted = 0;
        for i in 0..200 {
            let token = rng.random::<u128>();
            // Advance the clock as we go so the log's period filters are forced to turn over.
            let now = UNIX_EPOCH + lifetime * 10 + lifetime * i / 10;
            let issued = now - lifetime.mul_f32(rng.random_range(0.0..3.0));
            let result = log.check_and_insert(token, issued, lifetime);
            if result.is_ok() {
                accepted += 1;
            }
            old.push((token, issued));
            // Any previously-seen token must still be rejected, even across turnovers.
            let old_idx = rng.random_range(0..old.len());
            let (old_token, old_issued) = old[old_idx];
            assert!(
                log.check_and_insert(old_token, old_issued, lifetime)
                    .is_err()
            );
        }
        assert!(accepted > 0);
    }

    /// Smoke-test a log configuration against panics under random insertions
    fn test_doesnt_panic(log: BloomTokenLog) {
        let mut rng = new_rng();
        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);
        for _ in 0..200 {
            let _ = log.check_and_insert(rng.random::<u128>(), issued, lifetime);
        }
    }

    #[test]
    fn max_bytes_zero() {
        // "max bytes" is documented to be approximate. but make sure it doesn't panic.
        test_doesnt_panic(BloomTokenLog::new_expected_items(0, 200));
    }

    #[test]
    fn expected_hits_zero() {
        test_doesnt_panic(BloomTokenLog::new_expected_items(100, 0));
    }

    #[test]
    fn k_num_zero() {
        test_doesnt_panic(BloomTokenLog::new(100, 0));
    }
}

180
vendor/quinn-proto/src/cid_generator.rs vendored Normal file
View File

@@ -0,0 +1,180 @@
use std::hash::Hasher;
use rand::{Rng, RngCore};
use crate::Duration;
use crate::MAX_CID_SIZE;
use crate::shared::ConnectionId;
/// Generates connection IDs for incoming connections
pub trait ConnectionIdGenerator: Send + Sync {
    /// Generates a new CID
    ///
    /// Connection IDs MUST NOT contain any information that can be used by
    /// an external observer (that is, one that does not cooperate with the
    /// issuer) to correlate them with other connection IDs for the same
    /// connection. They MUST have high entropy, e.g. due to encrypted data
    /// or cryptographic-grade random data.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Quickly determine whether `cid` could have been generated by this generator
    ///
    /// False positives are permitted, but increase the cost of handling invalid packets.
    // Default implementation accepts every CID, i.e. validation is a no-op.
    fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
        Ok(())
    }

    /// Returns the length of a CID for connections created by this generator
    fn cid_len(&self) -> usize;

    /// Returns the lifetime of generated Connection IDs
    ///
    /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
    fn cid_lifetime(&self) -> Option<Duration>;
}
/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
// Zero-sized marker error returned by `ConnectionIdGenerator::validate`.
#[derive(Debug, Copy, Clone)]
pub struct InvalidCid;
/// Generates purely random connection IDs of a specified length
///
/// Random CIDs can be smaller than those produced by [`HashedConnectionIdGenerator`], but cannot be
/// usefully [`validate`](ConnectionIdGenerator::validate)d.
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
    /// Length in bytes of every generated CID; must not exceed `MAX_CID_SIZE`
    cid_len: usize,
    /// Optional retirement lifetime reported via `cid_lifetime`; `None` means CIDs never expire
    lifetime: Option<Duration>,
}
impl Default for RandomConnectionIdGenerator {
    /// 8-byte CIDs with no expiry
    fn default() -> Self {
        let cid_len = 8;
        let lifetime = None;
        Self { cid_len, lifetime }
    }
}
impl RandomConnectionIdGenerator {
    /// Initialize Random CID generator with a fixed CID length
    ///
    /// The given length must be less than or equal to MAX_CID_SIZE.
    pub fn new(cid_len: usize) -> Self {
        debug_assert!(cid_len <= MAX_CID_SIZE);
        // Start from the default configuration and override only the length.
        let mut generator = Self::default();
        generator.cid_len = cid_len;
        generator
    }

    /// Set the lifetime of CIDs created by this generator
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}
impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        // Fill only the first `cid_len` bytes of a maximally-sized scratch buffer.
        let mut scratch = [0; MAX_CID_SIZE];
        let cid_bytes = &mut scratch[..self.cid_len];
        rand::rng().fill_bytes(cid_bytes);
        ConnectionId::new(cid_bytes)
    }

    /// Provide the length of dst_cid in short header packet
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}
/// Generates 8-byte connection IDs that can be efficiently
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// This generator uses a non-cryptographic hash and can therefore still be spoofed, but nonetheless
/// helps prevents Quinn from responding to non-QUIC packets at very low cost.
pub struct HashedConnectionIdGenerator {
    /// Secret key mixed into the hash so signatures are specific to this generator
    key: u64,
    /// Optional retirement lifetime reported via `cid_lifetime`; `None` means CIDs never expire
    lifetime: Option<Duration>,
}
impl HashedConnectionIdGenerator {
    /// Create a generator with a random key
    pub fn new() -> Self {
        let key = rand::rng().random();
        Self::from_key(key)
    }

    /// Create a generator with a specific key
    ///
    /// Allows [`validate`](ConnectionIdGenerator::validate) to recognize a consistent set of
    /// connection IDs across restarts
    pub fn from_key(key: u64) -> Self {
        Self { key, lifetime: None }
    }

    /// Set the lifetime of CIDs created by this generator
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}
impl Default for HashedConnectionIdGenerator {
    /// Equivalent to [`HashedConnectionIdGenerator::new`]: a freshly random key, no lifetime
    fn default() -> Self {
        Self::new()
    }
}
impl ConnectionIdGenerator for HashedConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        // CID layout: NONCE_LEN random bytes followed by SIGNATURE_LEN bytes of keyed hash.
        let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
        rand::rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
        // Sign the nonce with this generator's key using a fast non-cryptographic hash.
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(&bytes_arr[..NONCE_LEN]);
        bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
        ConnectionId::new(&bytes_arr)
    }

    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        // Recompute the signature over the nonce and compare it with the CID's suffix.
        // NOTE(review): `split_at` assumes `cid` is at least NONCE_LEN bytes long — confirm
        // callers only pass CIDs of this generator's own length.
        let (nonce, signature) = cid.split_at(NONCE_LEN);
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(nonce);
        let expected = hasher.finish().to_le_bytes();
        match expected[..SIGNATURE_LEN] == signature[..] {
            true => Ok(()),
            false => Err(InvalidCid),
        }
    }

    fn cid_len(&self) -> usize {
        NONCE_LEN + SIGNATURE_LEN
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}
/// Random-byte prefix length of a hashed CID
const NONCE_LEN: usize = 3; // Good for more than 16 million connections
/// Keyed-hash suffix length of a hashed CID
const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_keyed_cid() {
        // A CID fresh out of the generator must pass that same generator's validation.
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }
}

303
vendor/quinn-proto/src/cid_queue.rs vendored Normal file
View File

@@ -0,0 +1,303 @@
use std::ops::Range;
use crate::{ConnectionId, ResetToken, frame::NewConnectionId};
/// DataType stored in CidQueue buffer
type CidData = (ConnectionId, Option<ResetToken>);

/// Sliding window of active Connection IDs
///
/// May contain gaps due to packet loss or reordering
#[derive(Debug)]
pub(crate) struct CidQueue {
    /// Ring buffer indexed by `self.cursor`
    buffer: [Option<CidData>; Self::LEN],
    /// Index at which circular buffer addressing is based
    cursor: usize,
    /// Sequence number of `self.buffer[cursor]`
    ///
    /// The sequence number of the active CID; must be the smallest among CIDs in `buffer`.
    offset: u64,
}
impl CidQueue {
/// Create a queue whose only (and therefore active) CID is `cid`, at sequence number 0
pub(crate) fn new(cid: ConnectionId) -> Self {
    let mut queue = Self {
        buffer: [None; Self::LEN],
        cursor: 0,
        offset: 0,
    };
    // The initial CID occupies the first slot and carries no reset token.
    queue.buffer[0] = Some((cid, None));
    queue
}
/// Handle a `NEW_CONNECTION_ID` frame
///
/// Returns a non-empty range of retired sequence numbers and the reset token of the new active
/// CID iff any CIDs were retired.
pub(crate) fn insert(
    &mut self,
    cid: NewConnectionId,
) -> Result<Option<(Range<u64>, ResetToken)>, InsertError> {
    // Position of new CID wrt. the current active CID
    let index = match cid.sequence.checked_sub(self.offset) {
        // Sequence number below the window base: this CID was already retired.
        None => return Err(InsertError::Retired),
        Some(x) => x,
    };

    let retired_count = cid.retire_prior_to.saturating_sub(self.offset);
    // Reject CIDs that would land beyond the window even after retirement shifts it forward.
    if index >= Self::LEN as u64 + retired_count {
        return Err(InsertError::ExceedsLimit);
    }

    // Discard retired CIDs, if any
    for i in 0..(retired_count.min(Self::LEN as u64) as usize) {
        self.buffer[(self.cursor + i) % Self::LEN] = None;
    }
    // Record the new CID
    let index = ((self.cursor as u64 + index) % Self::LEN as u64) as usize;
    self.buffer[index] = Some((cid.id, Some(cid.reset_token)));

    if retired_count == 0 {
        return Ok(None);
    }

    // The active CID was retired. Find the first known CID with sequence number of at least
    // retire_prior_to, and inform the caller that all prior CIDs have been retired, and of
    // the new CID's reset token.
    self.cursor = ((self.cursor as u64 + retired_count) % Self::LEN as u64) as usize;
    let (i, (_, token)) = self
        .iter()
        .next()
        .expect("it is impossible to retire a CID without supplying a new one");
    self.cursor = (self.cursor + i) % Self::LEN;
    let orig_offset = self.offset;
    self.offset = cid.retire_prior_to + i as u64;
    // We don't immediately retire CIDs in the range (orig_offset +
    // Self::LEN)..self.offset. These are CIDs that we haven't yet received from a
    // NEW_CONNECTION_ID frame, since having previously received them would violate the
    // connection ID limit we specified based on Self::LEN. If we do receive a such a frame
    // in the future, e.g. due to reordering, we'll retire it then. This ensures we can't be
    // made to buffer an arbitrarily large number of RETIRE_CONNECTION_ID frames.
    Ok(Some((
        orig_offset..self.offset.min(orig_offset + Self::LEN as u64),
        token.expect("non-initial CID missing reset token"),
    )))
}
/// Switch to next active CID if possible, return
/// 1) the corresponding ResetToken and 2) a non-empty range preceding it to retire
pub(crate) fn next(&mut self) -> Option<(ResetToken, Range<u64>)> {
let (i, cid_data) = self.iter().nth(1)?;
self.buffer[self.cursor] = None;
let orig_offset = self.offset;
self.offset += i as u64;
self.cursor = (self.cursor + i) % Self::LEN;
Some((cid_data.1.unwrap(), orig_offset..self.offset))
}
/// Iterate CIDs in CidQueue that are not `None`, including the active CID
fn iter(&self) -> impl Iterator<Item = (usize, CidData)> + '_ {
(0..Self::LEN).filter_map(move |step| {
let index = (self.cursor + step) % Self::LEN;
self.buffer[index].map(|cid_data| (step, cid_data))
})
}
/// Replace the initial CID
pub(crate) fn update_initial_cid(&mut self, cid: ConnectionId) {
debug_assert_eq!(self.offset, 0);
self.buffer[self.cursor] = Some((cid, None));
}
/// Return active remote CID itself
pub(crate) fn active(&self) -> ConnectionId {
self.buffer[self.cursor].unwrap().0
}
/// Return the sequence number of active remote CID
pub(crate) fn active_seq(&self) -> u64 {
self.offset
}
pub(crate) const LEN: usize = 5;
}
/// Reasons why a `NEW_CONNECTION_ID` frame could not be stored in the queue
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum InsertError {
    /// CID was already retired
    Retired,
    /// Sequence number violates the leading edge of the window
    ExceedsLimit,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `NEW_CONNECTION_ID` frame with fixed CID and reset-token contents
    fn cid(sequence: u64, retire_prior_to: u64) -> NewConnectionId {
        NewConnectionId {
            sequence,
            id: ConnectionId::new(&[0xAB; 8]),
            reset_token: ResetToken::from([0xCD; crate::RESET_TOKEN_SIZE]),
            retire_prior_to,
        }
    }

    /// CID used to initialize the queue under test
    fn initial_cid() -> ConnectionId {
        ConnectionId::new(&[0xFF; 8])
    }

    // Advancing through a fully-populated window retires exactly one CID per step
    #[test]
    fn next_dense() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert!(q.next().is_none());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in 1..CidQueue::LEN as u64 {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire.end - retire.start, 1);
        }
        assert!(q.next().is_none());
    }

    // With gaps in the window, `next` skips unknown sequence numbers and retires
    // the whole preceding range
    #[test]
    fn next_sparse() {
        let mut q = CidQueue::new(initial_cid());
        let seqs = (1..CidQueue::LEN as u64).filter(|x| x % 2 == 0);
        for i in seqs.clone() {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in seqs {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire, (q.active_seq().saturating_sub(2))..q.active_seq());
        }
        assert!(q.next().is_none());
    }

    // Sequence numbers may exceed LEN; the ring buffer wraps around correctly
    #[test]
    fn wrap() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for _ in 1..(CidQueue::LEN as u64 - 1) {
            q.next().unwrap();
        }
        for i in CidQueue::LEN as u64..(CidQueue::LEN as u64 + 3) {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in (CidQueue::LEN as u64 - 1)..(CidQueue::LEN as u64 + 3) {
            q.next().unwrap();
            assert_eq!(q.active_seq(), i);
        }
        assert!(q.next().is_none());
    }

    // `retire_prior_to` retires every known CID below it, exactly once
    #[test]
    fn retire_dense() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        assert_eq!(q.active_seq(), 0);
        assert_eq!(q.insert(cid(4, 2)).unwrap().unwrap().0, 0..2);
        assert_eq!(q.active_seq(), 2);
        assert_eq!(q.insert(cid(4, 2)), Ok(None));
        for i in 2..(CidQueue::LEN as u64 - 1) {
            let _ = q.next().unwrap();
            assert_eq!(q.active_seq(), i + 1);
            assert_eq!(q.insert(cid(i + 1, i + 1)), Ok(None));
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn retire_sparse() {
        // Retiring CID 0 when CID 1 is not known should retire CID 1 as we move to CID 2
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(q.insert(cid(3, 1)).unwrap().unwrap().0, 0..2,);
        assert_eq!(q.active_seq(), 2);
    }

    // A huge `retire_prior_to` only reports at most LEN retirements, bounding the
    // number of RETIRE_CONNECTION_ID frames we can be made to send
    #[test]
    fn retire_many() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(
            q.insert(cid(1_000_000, 1_000_000)).unwrap().unwrap().0,
            0..CidQueue::LEN as u64,
        );
        assert_eq!(q.active_seq(), 1_000_000);
    }

    // Sequence numbers beyond the leading edge of the window are rejected
    #[test]
    fn insert_limit() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(q.insert(cid(CidQueue::LEN as u64 - 1, 0)), Ok(None));
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    // Re-inserting the same frame is harmless
    #[test]
    fn insert_duplicate() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(0, 0)).unwrap();
        q.insert(cid(0, 0)).unwrap();
    }

    #[test]
    fn insert_retired() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(
            q.insert(cid(0, 0)),
            Ok(None),
            "reinserting active CID succeeds"
        );
        assert!(q.next().is_none(), "active CID isn't requeued");
        q.insert(cid(1, 0)).unwrap();
        q.next().unwrap();
        assert_eq!(
            q.insert(cid(0, 0)),
            Err(InsertError::Retired),
            "previous active CID is already retired"
        );
    }

    // Advancing the window by retiring a CID admits exactly one more sequence number
    #[test]
    fn retire_then_insert_next() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        q.next().unwrap();
        q.insert(cid(CidQueue::LEN as u64, 0)).unwrap();
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64 + 1, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    // A freshly created queue always exposes a valid active CID
    #[test]
    fn always_valid() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert_eq!(q.active(), initial_cid());
        assert_eq!(q.active_seq(), 0);
    }
}

130
vendor/quinn-proto/src/coding.rs vendored Normal file
View File

@@ -0,0 +1,130 @@
//! Coding related traits.
use std::net::{Ipv4Addr, Ipv6Addr};
use bytes::{Buf, BufMut};
use thiserror::Error;
use crate::VarInt;
/// Error indicating that the provided buffer was too small
#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
#[error("unexpected end of buffer")]
pub struct UnexpectedEnd;
/// Coding result type
pub type Result<T> = ::std::result::Result<T, UnexpectedEnd>;
/// Infallible encoding and decoding of QUIC primitives
pub trait Codec: Sized {
    /// Decode a `Self` from the provided buffer, if the buffer is large enough
    fn decode<B: Buf>(buf: &mut B) -> Result<Self>;
    /// Append the encoding of `self` to the provided buffer
    fn encode<B: BufMut>(&self, buf: &mut B);
}
impl Codec for u8 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        // A single byte is required
        match buf.remaining() {
            0 => Err(UnexpectedEnd),
            _ => Ok(buf.get_u8()),
        }
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u8(*self);
    }
}
impl Codec for u16 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        // Big-endian, 2 bytes
        if buf.remaining() >= 2 {
            Ok(buf.get_u16())
        } else {
            Err(UnexpectedEnd)
        }
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u16(*self);
    }
}
impl Codec for u32 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        // Big-endian, 4 bytes
        if buf.remaining() >= 4 {
            Ok(buf.get_u32())
        } else {
            Err(UnexpectedEnd)
        }
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u32(*self);
    }
}
impl Codec for u64 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        // Big-endian, 8 bytes
        if buf.remaining() >= 8 {
            Ok(buf.get_u64())
        } else {
            Err(UnexpectedEnd)
        }
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u64(*self);
    }
}
impl Codec for Ipv4Addr {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 4 {
            return Err(UnexpectedEnd);
        }
        let mut bytes = [0u8; 4];
        buf.copy_to_slice(&mut bytes);
        Ok(Self::from(bytes))
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_slice(&self.octets());
    }
}
impl Codec for Ipv6Addr {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 16 {
            return Err(UnexpectedEnd);
        }
        let mut bytes = [0u8; 16];
        buf.copy_to_slice(&mut bytes);
        Ok(Self::from(bytes))
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_slice(&self.octets());
    }
}
/// Convenience extensions for decoding values from a buffer
pub(crate) trait BufExt {
    /// Decode a value implementing [`Codec`]
    fn get<T: Codec>(&mut self) -> Result<T>;
    /// Decode a QUIC variable-length integer as a `u64`
    fn get_var(&mut self) -> Result<u64>;
}
impl<T: Buf> BufExt for T {
    fn get<U: Codec>(&mut self) -> Result<U> {
        U::decode(self)
    }
    fn get_var(&mut self) -> Result<u64> {
        // Decode as a VarInt, then unwrap to the raw integer
        let value = VarInt::decode(self)?;
        Ok(value.into_inner())
    }
}
/// Convenience extensions for encoding values into a buffer
pub(crate) trait BufMutExt {
    /// Encode a value implementing [`Codec`]
    fn write<T: Codec>(&mut self, x: T);
    /// Encode `x` as a QUIC variable-length integer
    ///
    /// Panics if `x` exceeds the varint range (see the `unwrap` in the impl).
    fn write_var(&mut self, x: u64);
}
impl<T: BufMut> BufMutExt for T {
fn write<U: Codec>(&mut self, x: U) {
x.encode(self);
}
fn write_var(&mut self, x: u64) {
VarInt::from_u64(x).unwrap().encode(self);
}
}

697
vendor/quinn-proto/src/config/mod.rs vendored Normal file
View File

@@ -0,0 +1,697 @@
use std::{
fmt,
net::{SocketAddrV4, SocketAddrV6},
num::TryFromIntError,
sync::Arc,
};
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
use rustls::client::WebPkiServerVerifier;
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use thiserror::Error;
#[cfg(feature = "bloom")]
use crate::BloomTokenLog;
#[cfg(not(feature = "bloom"))]
use crate::NoneTokenLog;
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
use crate::crypto::rustls::{QuicServerConfig, configured_provider};
use crate::{
DEFAULT_SUPPORTED_VERSIONS, Duration, MAX_CID_SIZE, RandomConnectionIdGenerator, SystemTime,
TokenLog, TokenMemoryCache, TokenStore, VarInt, VarIntBoundsExceeded,
cid_generator::{ConnectionIdGenerator, HashedConnectionIdGenerator},
crypto::{self, HandshakeTokenKey, HmacKey},
shared::ConnectionId,
};
mod transport;
#[cfg(feature = "qlog")]
pub use transport::QlogConfig;
pub use transport::{AckFrequencyConfig, IdleTimeout, MtuDiscoveryConfig, TransportConfig};
/// Global configuration for the endpoint, affecting all connections
///
/// Default values should be suitable for most internet applications.
#[derive(Clone)]
pub struct EndpointConfig {
    /// Key used to authenticate stateless reset packets
    pub(crate) reset_key: Arc<dyn HmacKey>,
    /// Maximum UDP payload size accepted from peers
    pub(crate) max_udp_payload_size: VarInt,
    /// CID generator factory
    ///
    /// Create a cid generator for local cid in Endpoint struct
    pub(crate) connection_id_generator_factory:
        Arc<dyn Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync>,
    /// QUIC versions this endpoint will accept
    pub(crate) supported_versions: Vec<u32>,
    /// Whether to randomize the fixed bit in outgoing packets
    pub(crate) grease_quic_bit: bool,
    /// Minimum interval between outgoing stateless reset packets
    pub(crate) min_reset_interval: Duration,
    /// Optional seed to be used internally for random number generation
    pub(crate) rng_seed: Option<[u8; 32]>,
}
impl EndpointConfig {
    /// Create a default config with a particular `reset_key`
    pub fn new(reset_key: Arc<dyn HmacKey>) -> Self {
        let cid_factory =
            || -> Box<dyn ConnectionIdGenerator> { Box::<HashedConnectionIdGenerator>::default() };
        Self {
            reset_key,
            max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers
            connection_id_generator_factory: Arc::new(cid_factory),
            supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
            grease_quic_bit: true,
            min_reset_interval: Duration::from_millis(20),
            rng_seed: None,
        }
    }
    /// Supply a custom connection ID generator factory
    ///
    /// Called once by each `Endpoint` constructed from this configuration to obtain the CID
    /// generator which will be used to generate the CIDs used for incoming packets on all
    /// connections involving that `Endpoint`. A custom CID generator allows applications to embed
    /// information in local connection IDs, e.g. to support stateless packet-level load balancers.
    ///
    /// Defaults to [`HashedConnectionIdGenerator`].
    pub fn cid_generator<F: Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync + 'static>(
        &mut self,
        factory: F,
    ) -> &mut Self {
        self.connection_id_generator_factory = Arc::new(factory);
        self
    }
    /// Private key used to send authenticated connection resets to peers who were
    /// communicating with a previous instance of this endpoint.
    pub fn reset_key(&mut self, key: Arc<dyn HmacKey>) -> &mut Self {
        self.reset_key = key;
        self
    }
    /// Maximum UDP payload size accepted from peers (excluding UDP and IP overhead).
    ///
    /// Must be greater or equal than 1200.
    ///
    /// Defaults to 1472, which is the largest UDP payload that can be transmitted in the typical
    /// 1500 byte Ethernet MTU. Deployments on links with larger MTUs (e.g. loopback or Ethernet
    /// with jumbo frames) can raise this to improve performance at the cost of a linear increase in
    /// datagram receive buffer size.
    pub fn max_udp_payload_size(&mut self, value: u16) -> Result<&mut Self, ConfigError> {
        if !(1200..=65_527).contains(&value) {
            return Err(ConfigError::OutOfBounds);
        }
        self.max_udp_payload_size = value.into();
        Ok(self)
    }
    /// Get the current value of [`max_udp_payload_size`](Self::max_udp_payload_size)
    //
    // While most parameters don't need to be readable, this must be exposed to allow higher-level
    // layers, e.g. the `quinn` crate, to determine how large a receive buffer to allocate to
    // support an externally-defined `EndpointConfig`.
    //
    // While `get_` accessors are typically unidiomatic in Rust, we favor concision for setters,
    // which will be used far more heavily.
    pub fn get_max_udp_payload_size(&self) -> u64 {
        self.max_udp_payload_size.into()
    }
    /// Override supported QUIC versions
    pub fn supported_versions(&mut self, supported_versions: Vec<u32>) -> &mut Self {
        self.supported_versions = supported_versions;
        self
    }
    /// Whether to accept QUIC packets containing any value for the fixed bit
    ///
    /// Enabled by default. Helps protect against protocol ossification and makes traffic less
    /// identifiable to observers. Disable if helping observers identify this traffic as QUIC is
    /// desired.
    pub fn grease_quic_bit(&mut self, value: bool) -> &mut Self {
        self.grease_quic_bit = value;
        self
    }
    /// Minimum interval between outgoing stateless reset packets
    ///
    /// Defaults to 20ms. Limits the impact of attacks which flood an endpoint with garbage packets,
    /// e.g. [ISAKMP/IKE amplification]. Larger values provide a stronger defense, but may delay
    /// detection of some error conditions by clients. Using a [`ConnectionIdGenerator`] with a low
    /// rate of false positives in [`validate`](ConnectionIdGenerator::validate) reduces the risk
    /// incurred by a small minimum reset interval.
    ///
    /// [ISAKMP/IKE
    /// amplification]: https://bughunters.google.com/blog/5960150648750080/preventing-cross-service-udp-loops-in-quic#isakmp-ike-amplification-vs-quic
    pub fn min_reset_interval(&mut self, value: Duration) -> &mut Self {
        self.min_reset_interval = value;
        self
    }
    /// Optional seed to be used internally for random number generation
    ///
    /// By default, quinn will initialize an endpoint's rng using a platform entropy source.
    /// However, you can seed the rng yourself through this method (e.g. if you need to run quinn
    /// deterministically or if you are using quinn in an environment that doesn't have a source of
    /// entropy available).
    pub fn rng_seed(&mut self, seed: Option<[u8; 32]>) -> &mut Self {
        self.rng_seed = seed;
        self
    }
}
impl fmt::Debug for EndpointConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("EndpointConfig")
            // reset_key not debug
            .field("max_udp_payload_size", &self.max_udp_payload_size)
            // cid_generator_factory not debug
            .field("supported_versions", &self.supported_versions)
            .field("grease_quic_bit", &self.grease_quic_bit)
            .field("rng_seed", &self.rng_seed)
            .finish_non_exhaustive()
    }
}
#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
impl Default for EndpointConfig {
    /// Construct a config with a freshly randomized reset key
    fn default() -> Self {
        // `ring` takes precedence as the HMAC provider when both backends are enabled
        #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
        use aws_lc_rs::hmac;
        use rand::RngCore;
        #[cfg(feature = "ring")]
        use ring::hmac;
        let mut reset_key = [0; 64];
        rand::rng().fill_bytes(&mut reset_key);
        Self::new(Arc::new(hmac::Key::new(hmac::HMAC_SHA256, &reset_key)))
    }
}
/// Parameters governing incoming connections
///
/// Default values should be suitable for most internet applications.
#[derive(Clone)]
pub struct ServerConfig {
    /// Transport configuration to use for incoming connections
    pub transport: Arc<TransportConfig>,
    /// TLS configuration used for incoming connections
    ///
    /// Must be set to use TLS 1.3 only.
    pub crypto: Arc<dyn crypto::ServerConfig>,
    /// Configuration for sending and handling validation tokens
    pub validation_token: ValidationTokenConfig,
    /// Used to generate one-time AEAD keys to protect handshake tokens
    pub(crate) token_key: Arc<dyn HandshakeTokenKey>,
    /// Duration after a retry token was issued for which it's considered valid
    pub(crate) retry_token_lifetime: Duration,
    /// Whether to allow clients to migrate to new addresses
    ///
    /// Improves behavior for clients that move between different internet connections or suffer NAT
    /// rebinding. Enabled by default.
    pub(crate) migration: bool,
    /// Preferred IPv4 address communicated to clients during handshaking, if any
    pub(crate) preferred_address_v4: Option<SocketAddrV4>,
    /// Preferred IPv6 address communicated to clients during handshaking, if any
    pub(crate) preferred_address_v6: Option<SocketAddrV6>,
    /// Maximum number of concurrent `Incoming` connection attempts
    pub(crate) max_incoming: usize,
    /// Per-`Incoming` receive buffer limit, in bytes
    pub(crate) incoming_buffer_size: u64,
    /// Aggregate receive buffer limit across all `Incoming`, in bytes
    pub(crate) incoming_buffer_size_total: u64,
    /// Source of the current `SystemTime`; mockable for tests
    pub(crate) time_source: Arc<dyn TimeSource>,
}
impl ServerConfig {
    /// Create a default config with a particular handshake token key
    pub fn new(
        crypto: Arc<dyn crypto::ServerConfig>,
        token_key: Arc<dyn HandshakeTokenKey>,
    ) -> Self {
        Self {
            transport: Arc::new(TransportConfig::default()),
            crypto,
            token_key,
            retry_token_lifetime: Duration::from_secs(15),
            migration: true,
            validation_token: ValidationTokenConfig::default(),
            preferred_address_v4: None,
            preferred_address_v6: None,
            max_incoming: 1 << 16,
            incoming_buffer_size: 10 << 20,
            incoming_buffer_size_total: 100 << 20,
            time_source: Arc::new(StdSystemTime),
        }
    }
    /// Set a custom [`TransportConfig`]
    pub fn transport_config(&mut self, transport: Arc<TransportConfig>) -> &mut Self {
        self.transport = transport;
        self
    }
    /// Set a custom [`ValidationTokenConfig`]
    pub fn validation_token_config(
        &mut self,
        validation_token: ValidationTokenConfig,
    ) -> &mut Self {
        self.validation_token = validation_token;
        self
    }
    /// Private key used to authenticate data included in handshake tokens
    pub fn token_key(&mut self, value: Arc<dyn HandshakeTokenKey>) -> &mut Self {
        self.token_key = value;
        self
    }
    /// Duration after a retry token was issued for which it's considered valid
    ///
    /// Defaults to 15 seconds.
    pub fn retry_token_lifetime(&mut self, value: Duration) -> &mut Self {
        self.retry_token_lifetime = value;
        self
    }
    /// Whether to allow clients to migrate to new addresses
    ///
    /// Improves behavior for clients that move between different internet connections or suffer NAT
    /// rebinding. Enabled by default.
    pub fn migration(&mut self, value: bool) -> &mut Self {
        self.migration = value;
        self
    }
    /// The preferred IPv4 address that will be communicated to clients during handshaking
    ///
    /// If the client is able to reach this address, it will switch to it.
    pub fn preferred_address_v4(&mut self, address: Option<SocketAddrV4>) -> &mut Self {
        self.preferred_address_v4 = address;
        self
    }
    /// The preferred IPv6 address that will be communicated to clients during handshaking
    ///
    /// If the client is able to reach this address, it will switch to it.
    pub fn preferred_address_v6(&mut self, address: Option<SocketAddrV6>) -> &mut Self {
        self.preferred_address_v6 = address;
        self
    }
    /// Maximum number of [`Incoming`][crate::Incoming] to allow to exist at a time
    ///
    /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt
    /// is received and stops existing when the application either accepts it or otherwise disposes
    /// of it. While this limit is reached, new incoming connection attempts are immediately
    /// refused. Larger values have greater worst-case memory consumption, but accommodate greater
    /// application latency in handling incoming connection attempts.
    ///
    /// The default value is set to 65536. With a typical Ethernet MTU of 1500 bytes, this limits
    /// memory consumption from this to under 100 MiB--a generous amount that still prevents memory
    /// exhaustion in most contexts.
    pub fn max_incoming(&mut self, max_incoming: usize) -> &mut Self {
        self.max_incoming = max_incoming;
        self
    }
    /// Maximum number of received bytes to buffer for each [`Incoming`][crate::Incoming]
    ///
    /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt
    /// is received and stops existing when the application either accepts it or otherwise disposes
    /// of it. This limit governs only packets received within that period, and does not include
    /// the first packet. Packets received in excess of this limit are dropped, which may cause
    /// 0-RTT or handshake data to have to be retransmitted.
    ///
    /// The default value is set to 10 MiB--an amount such that in most situations a client would
    /// not transmit that much 0-RTT data faster than the server handles the corresponding
    /// [`Incoming`][crate::Incoming].
    pub fn incoming_buffer_size(&mut self, incoming_buffer_size: u64) -> &mut Self {
        self.incoming_buffer_size = incoming_buffer_size;
        self
    }
    /// Maximum number of received bytes to buffer for all [`Incoming`][crate::Incoming]
    /// collectively
    ///
    /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt
    /// is received and stops existing when the application either accepts it or otherwise disposes
    /// of it. This limit governs only packets received within that period, and does not include
    /// the first packet. Packets received in excess of this limit are dropped, which may cause
    /// 0-RTT or handshake data to have to be retransmitted.
    ///
    /// The default value is set to 100 MiB--a generous amount that still prevents memory
    /// exhaustion in most contexts.
    pub fn incoming_buffer_size_total(&mut self, incoming_buffer_size_total: u64) -> &mut Self {
        self.incoming_buffer_size_total = incoming_buffer_size_total;
        self
    }
    /// Object to get current [`SystemTime`]
    ///
    /// This exists to allow system time to be mocked in tests, or wherever else desired.
    ///
    /// Defaults to [`StdSystemTime`], which simply calls [`SystemTime::now()`](SystemTime::now).
    pub fn time_source(&mut self, time_source: Arc<dyn TimeSource>) -> &mut Self {
        self.time_source = time_source;
        self
    }
    /// Whether a preferred address is configured for either address family
    pub(crate) fn has_preferred_address(&self) -> bool {
        self.preferred_address_v4.is_some() || self.preferred_address_v6.is_some()
    }
}
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
impl ServerConfig {
    /// Create a server config with the given certificate chain to be presented to clients
    ///
    /// Uses a randomized handshake token key.
    pub fn with_single_cert(
        cert_chain: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Result<Self, rustls::Error> {
        Ok(Self::with_crypto(Arc::new(QuicServerConfig::new(
            cert_chain, key,
        )?)))
    }
}
#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
impl ServerConfig {
    /// Create a server config with the given [`crypto::ServerConfig`]
    ///
    /// Uses a randomized handshake token key.
    pub fn with_crypto(crypto: Arc<dyn crypto::ServerConfig>) -> Self {
        // `ring` takes precedence as the HKDF provider when both backends are enabled
        #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
        use aws_lc_rs::hkdf;
        use rand::RngCore;
        #[cfg(feature = "ring")]
        use ring::hkdf;
        let rng = &mut rand::rng();
        let mut master_key = [0u8; 64];
        rng.fill_bytes(&mut master_key);
        // Derive the handshake token key from fresh random material
        let master_key = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
        Self::new(crypto, Arc::new(master_key))
    }
}
impl fmt::Debug for ServerConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("ServerConfig")
            .field("transport", &self.transport)
            // crypto not debug
            // token_key not debug
            .field("retry_token_lifetime", &self.retry_token_lifetime)
            .field("validation_token", &self.validation_token)
            .field("migration", &self.migration)
            .field("preferred_address_v4", &self.preferred_address_v4)
            .field("preferred_address_v6", &self.preferred_address_v6)
            .field("max_incoming", &self.max_incoming)
            .field("incoming_buffer_size", &self.incoming_buffer_size)
            .field(
                "incoming_buffer_size_total",
                &self.incoming_buffer_size_total,
            )
            // time_source not debug
            .finish_non_exhaustive()
    }
}
/// Configuration for sending and handling validation tokens in incoming connections
///
/// Default values should be suitable for most internet applications.
///
/// ## QUIC Tokens
///
/// The QUIC protocol defines a concept of "[address validation][1]". Essentially, one side of a
/// QUIC connection may appear to be receiving QUIC packets from a particular remote UDP address,
/// but it will only consider that remote address "validated" once it has convincing evidence that
/// the address is not being [spoofed][2].
///
/// Validation is important primarily because of QUIC's "anti-amplification limit." This limit
/// prevents a QUIC server from sending a client more than three times the number of bytes it has
/// received from the client on a given address until that address is validated. This is designed
/// to mitigate the ability of attackers to use QUIC-based servers as reflectors in [amplification
/// attacks][3].
///
/// A path may become validated in several ways. The server is always considered validated by the
/// client. The client usually begins in an unvalidated state upon first connecting or migrating,
/// but then becomes validated through various mechanisms that usually take one network round trip.
/// However, in some cases, a client which has previously attempted to connect to a server may have
/// been given a one-time use cryptographically secured "token" that it can send in a subsequent
/// connection attempt to be validated immediately.
///
/// There are two ways these tokens can originate:
///
/// - If the server responds to an incoming connection with `retry`, a "retry token" is minted and
///   sent to the client, which the client immediately uses to attempt to connect again. Retry
///   tokens operate on short timescales, such as 15 seconds.
/// - If a client's path within an active connection is validated, the server may send the client
///   one or more "validation tokens," which the client may store for use in later connections to
///   the same server. Validation tokens may be valid for much longer lifetimes than retry tokens.
///
/// The usage of validation tokens is most impactful in situations where 0-RTT data is also being
/// used--in particular, in situations where the server sends the client more than three times more
/// 0.5-RTT data than it has received 0-RTT data. Since the successful completion of a connection
/// handshake implicitly causes the client's address to be validated, transmission of 0.5-RTT data
/// is the main situation where a server might be sending application data to an address that could
/// be validated by token usage earlier than it would become validated without token usage.
///
/// [1]: https://www.rfc-editor.org/rfc/rfc9000.html#section-8
/// [2]: https://en.wikipedia.org/wiki/IP_address_spoofing
/// [3]: https://en.wikipedia.org/wiki/Denial-of-service_attack#Amplification
///
/// These tokens should not be confused with "stateless reset tokens," which are similarly named
/// but entirely unrelated.
#[derive(Clone)]
pub struct ValidationTokenConfig {
    /// How long issued NEW_TOKEN-frame tokens remain valid
    pub(crate) lifetime: Duration,
    /// Log used to reject reuse of NEW_TOKEN-frame tokens
    pub(crate) log: Arc<dyn TokenLog>,
    /// Number of tokens sent when a client's path is validated
    pub(crate) sent: u32,
}
impl ValidationTokenConfig {
    /// Duration after an address validation token was issued for which it's considered valid
    ///
    /// This refers only to tokens sent in NEW_TOKEN frames, in contrast to retry tokens.
    ///
    /// Defaults to 2 weeks.
    pub fn lifetime(&mut self, value: Duration) -> &mut Self {
        self.lifetime = value;
        self
    }
    #[allow(rustdoc::redundant_explicit_links)] // which links are redundant depends on features
    /// Set a custom [`TokenLog`]
    ///
    /// If the `bloom` feature is enabled (which it is by default), defaults to a default
    /// [`BloomTokenLog`][crate::BloomTokenLog], which is suitable for most internet applications.
    ///
    /// If the `bloom` feature is disabled, defaults to [`NoneTokenLog`][crate::NoneTokenLog],
    /// which makes the server ignore all address validation tokens (that is, tokens originating
    /// from NEW_TOKEN frames--retry tokens are not affected).
    pub fn log(&mut self, log: Arc<dyn TokenLog>) -> &mut Self {
        self.log = log;
        self
    }
    /// Number of address validation tokens sent to a client when its path is validated
    ///
    /// This refers only to tokens sent in NEW_TOKEN frames, in contrast to retry tokens.
    ///
    /// If the `bloom` feature is enabled (which it is by default), defaults to 2. Otherwise,
    /// defaults to 0.
    pub fn sent(&mut self, value: u32) -> &mut Self {
        self.sent = value;
        self
    }
}
impl Default for ValidationTokenConfig {
    fn default() -> Self {
        // Token reuse detection (and thus sending tokens at all) requires the bloom filter
        #[cfg(feature = "bloom")]
        let log = Arc::new(BloomTokenLog::default());
        #[cfg(not(feature = "bloom"))]
        let log = Arc::new(NoneTokenLog);
        Self {
            lifetime: Duration::from_secs(2 * 7 * 24 * 60 * 60), // 2 weeks
            log,
            sent: if cfg!(feature = "bloom") { 2 } else { 0 },
        }
    }
}
impl fmt::Debug for ValidationTokenConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use the type's actual name; the previous "ServerValidationTokenConfig" was a
        // stale leftover inconsistent with every other Debug impl in this module.
        fmt.debug_struct("ValidationTokenConfig")
            .field("lifetime", &self.lifetime)
            // log not debug
            .field("sent", &self.sent)
            .finish_non_exhaustive()
    }
}
/// Configuration for outgoing connections
///
/// Default values should be suitable for most internet applications.
#[derive(Clone)]
#[non_exhaustive]
pub struct ClientConfig {
    /// Transport configuration to use
    pub(crate) transport: Arc<TransportConfig>,
    /// Cryptographic configuration to use
    pub(crate) crypto: Arc<dyn crypto::ClientConfig>,
    /// Validation token store to use
    pub(crate) token_store: Arc<dyn TokenStore>,
    /// Provider that populates the destination connection ID of Initial Packets
    pub(crate) initial_dst_cid_provider: Arc<dyn Fn() -> ConnectionId + Send + Sync>,
    /// QUIC protocol version to use
    pub(crate) version: u32,
}
impl ClientConfig {
    /// Create a default config with a particular cryptographic config
    pub fn new(crypto: Arc<dyn crypto::ClientConfig>) -> Self {
        Self {
            transport: Default::default(),
            crypto,
            token_store: Arc::new(TokenMemoryCache::default()),
            initial_dst_cid_provider: Arc::new(|| {
                RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid()
            }),
            version: 1,
        }
    }
    /// Configure how to populate the destination CID of the initial packet when attempting to
    /// establish a new connection
    ///
    /// By default, it's populated with random bytes with reasonable length, so unless you have
    /// a good reason, you do not need to change it.
    ///
    /// If you prefer to override the default, please note that the generated connection ID MUST be
    /// at least 8 bytes long and unpredictable, as per section 7.2 of RFC 9000.
    pub fn initial_dst_cid_provider(
        &mut self,
        initial_dst_cid_provider: Arc<dyn Fn() -> ConnectionId + Send + Sync>,
    ) -> &mut Self {
        self.initial_dst_cid_provider = initial_dst_cid_provider;
        self
    }
    /// Set a custom [`TransportConfig`]
    pub fn transport_config(&mut self, transport: Arc<TransportConfig>) -> &mut Self {
        self.transport = transport;
        self
    }
    /// Set a custom [`TokenStore`]
    ///
    /// Defaults to [`TokenMemoryCache`], which is suitable for most internet applications.
    pub fn token_store(&mut self, store: Arc<dyn TokenStore>) -> &mut Self {
        self.token_store = store;
        self
    }
    /// Set the QUIC version to use
    pub fn version(&mut self, version: u32) -> &mut Self {
        self.version = version;
        self
    }
}
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
impl ClientConfig {
    /// Create a client configuration that trusts the platform's native roots
    ///
    /// Kept for backwards compatibility; panics where
    /// [`try_with_platform_verifier()`](Self::try_with_platform_verifier) would return an error.
    #[deprecated(since = "0.11.13", note = "use `try_with_platform_verifier()` instead")]
    #[cfg(feature = "platform-verifier")]
    pub fn with_platform_verifier() -> Self {
        Self::try_with_platform_verifier().expect("use try_with_platform_verifier() instead")
    }
    /// Create a client configuration that trusts the platform's native roots
    #[cfg(feature = "platform-verifier")]
    pub fn try_with_platform_verifier() -> Result<Self, rustls::Error> {
        Ok(Self::new(Arc::new(
            crypto::rustls::QuicClientConfig::with_platform_verifier()?,
        )))
    }
    /// Create a client configuration that trusts specified trust anchors
    pub fn with_root_certificates(
        roots: Arc<rustls::RootCertStore>,
    ) -> Result<Self, rustls::client::VerifierBuilderError> {
        Ok(Self::new(Arc::new(crypto::rustls::QuicClientConfig::new(
            WebPkiServerVerifier::builder_with_provider(roots, configured_provider()).build()?,
        ))))
    }
}
impl fmt::Debug for ClientConfig {
    /// Formats the config for debugging.
    ///
    /// `crypto`, `token_store`, and `initial_dst_cid_provider` are omitted, as
    /// they do not implement `Debug`; `finish_non_exhaustive` marks the
    /// omission with `..`.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = fmt.debug_struct("ClientConfig");
        dbg.field("transport", &self.transport);
        // crypto not debug
        // token_store not debug
        dbg.field("version", &self.version);
        dbg.finish_non_exhaustive()
    }
}
/// Errors in the configuration of an endpoint
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
    /// Value exceeds supported bounds
    ///
    /// Produced e.g. when converting an out-of-range value into a
    /// configuration parameter (see the `From` impls below).
    #[error("value exceeds supported bounds")]
    OutOfBounds,
}
impl From<TryFromIntError> for ConfigError {
    /// Failed integer narrowing maps to [`ConfigError::OutOfBounds`]
    fn from(_: TryFromIntError) -> Self {
        Self::OutOfBounds
    }
}
impl From<VarIntBoundsExceeded> for ConfigError {
    /// A value too large for `VarInt` encoding maps to [`ConfigError::OutOfBounds`]
    fn from(_: VarIntBoundsExceeded) -> Self {
        Self::OutOfBounds
    }
}
/// Object to get current [`SystemTime`]
///
/// This exists to allow system time to be mocked in tests, or wherever else desired.
/// See [`StdSystemTime`] for the default implementation.
pub trait TimeSource: Send + Sync {
    /// Get [`SystemTime::now()`](SystemTime::now) or the mocked equivalent
    fn now(&self) -> SystemTime;
}
/// Default implementation of [`TimeSource`]
///
/// Implements `now` by calling [`SystemTime::now()`](SystemTime::now).
pub struct StdSystemTime;
impl TimeSource for StdSystemTime {
    /// Returns the real system clock via [`SystemTime::now()`](SystemTime::now)
    fn now(&self) -> SystemTime {
        SystemTime::now()
    }
}

View File

@@ -0,0 +1,785 @@
use std::{fmt, sync::Arc};
#[cfg(feature = "qlog")]
use std::{io, sync::Mutex, time::Instant};
#[cfg(feature = "qlog")]
use qlog::streamer::QlogStreamer;
#[cfg(feature = "qlog")]
use crate::QlogStream;
use crate::{
Duration, INITIAL_MTU, MAX_UDP_PAYLOAD, VarInt, VarIntBoundsExceeded, congestion,
connection::qlog::QlogSink,
};
/// Parameters governing the core QUIC state machine
///
/// Default values should be suitable for most internet applications. Applications protocols which
/// forbid remotely-initiated streams should set `max_concurrent_bidi_streams` and
/// `max_concurrent_uni_streams` to zero.
///
/// In some cases, performance or resource requirements can be improved by tuning these values to
/// suit a particular application and/or network connection. In particular, data window sizes can be
/// tuned for a particular expected round trip time, link capacity, and memory availability. Tuning
/// for higher bandwidths and latencies increases worst-case memory consumption, but does not impair
/// performance at lower bandwidths and latencies. The default configuration is tuned for a 100Mbps
/// link with a 100ms round trip time.
///
/// Each field is documented on its corresponding setter method below.
pub struct TransportConfig {
    /// Limit on concurrently open remotely-initiated bidirectional streams
    pub(crate) max_concurrent_bidi_streams: VarInt,
    /// Limit on concurrently open remotely-initiated unidirectional streams
    pub(crate) max_concurrent_uni_streams: VarInt,
    /// Idle timeout in `VarInt`-encoded milliseconds; `None` means infinite
    pub(crate) max_idle_timeout: Option<VarInt>,
    /// Per-stream flow control window
    pub(crate) stream_receive_window: VarInt,
    /// Connection-wide flow control window
    pub(crate) receive_window: VarInt,
    /// Cap on unacknowledged outgoing bytes
    pub(crate) send_window: u64,
    /// Whether equal-priority send streams are scheduled round-robin
    pub(crate) send_fairness: bool,
    /// Packet-reordering threshold for loss detection
    pub(crate) packet_threshold: u32,
    /// Time-reordering threshold for loss detection, as a factor of RTT
    pub(crate) time_threshold: f32,
    /// RTT estimate used before a sample is taken
    pub(crate) initial_rtt: Duration,
    /// Maximum UDP payload size assumed before MTU discovery runs
    pub(crate) initial_mtu: u16,
    /// Maximum UDP payload size guaranteed to be supported by the network
    pub(crate) min_mtu: u16,
    /// DPLPMTUD settings; `None` disables MTU discovery
    pub(crate) mtu_discovery_config: Option<MtuDiscoveryConfig>,
    /// Whether to pad application-data datagrams to the current MTU
    pub(crate) pad_to_mtu: bool,
    /// Parameters requested of the peer via the ACK frequency extension
    pub(crate) ack_frequency_config: Option<AckFrequencyConfig>,
    /// Consecutive PTOs after which congestion is considered persistent
    pub(crate) persistent_congestion_threshold: u32,
    /// Inactivity period before sending a keep-alive; `None` disables
    pub(crate) keep_alive_interval: Option<Duration>,
    /// Cap on buffered out-of-order crypto-layer data
    pub(crate) crypto_buffer_size: usize,
    /// Whether the latency spin bit may be set
    pub(crate) allow_spin: bool,
    /// Cap on buffered incoming datagram bytes; `None` disables datagrams
    pub(crate) datagram_receive_buffer_size: Option<usize>,
    /// Cap on buffered outgoing datagram bytes
    pub(crate) datagram_send_buffer_size: usize,
    // Test-only: disables packet number skipping
    #[cfg(test)]
    pub(crate) deterministic_packet_numbers: bool,
    /// Constructs per-connection congestion controllers
    pub(crate) congestion_controller_factory: Arc<dyn congestion::ControllerFactory + Send + Sync>,
    /// Whether to use Generic Segmentation Offload when available
    pub(crate) enable_segmentation_offload: bool,
    /// Destination for per-connection qlog events
    pub(crate) qlog_sink: QlogSink,
}
impl TransportConfig {
    /// Maximum number of incoming bidirectional streams that may be open concurrently
    ///
    /// Must be nonzero for the peer to open any bidirectional streams.
    ///
    /// Worst-case memory use is directly proportional to `max_concurrent_bidi_streams *
    /// stream_receive_window`, with an upper bound proportional to `receive_window`.
    pub fn max_concurrent_bidi_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_bidi_streams = value;
        self
    }
    /// Variant of `max_concurrent_bidi_streams` affecting unidirectional streams
    pub fn max_concurrent_uni_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_uni_streams = value;
        self
    }
    /// Maximum duration of inactivity to accept before timing out the connection.
    ///
    /// The true idle timeout is the minimum of this and the peer's own max idle timeout. `None`
    /// represents an infinite timeout. Defaults to 30 seconds.
    ///
    /// **WARNING**: If a peer or its network path malfunctions or acts maliciously, an infinite
    /// idle timeout can result in permanently hung futures!
    ///
    /// ```
    /// # use std::{convert::TryInto, time::Duration};
    /// # use quinn_proto::{TransportConfig, VarInt, VarIntBoundsExceeded};
    /// # fn main() -> Result<(), VarIntBoundsExceeded> {
    /// let mut config = TransportConfig::default();
    ///
    /// // Set the idle timeout as `VarInt`-encoded milliseconds
    /// config.max_idle_timeout(Some(VarInt::from_u32(10_000).into()));
    ///
    /// // Set the idle timeout as a `Duration`
    /// config.max_idle_timeout(Some(Duration::from_secs(10).try_into()?));
    /// # Ok(())
    /// # }
    /// ```
    pub fn max_idle_timeout(&mut self, value: Option<IdleTimeout>) -> &mut Self {
        // Store the underlying millisecond `VarInt` directly
        self.max_idle_timeout = value.map(|t| t.0);
        self
    }
    /// Maximum number of bytes the peer may transmit without acknowledgement on any one stream
    /// before becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Setting this smaller than `receive_window` helps ensure that a single
    /// stream doesn't monopolize receive buffers, which may otherwise occur if the application
    /// chooses not to read from a large stream for a time while still requiring data on other
    /// streams.
    pub fn stream_receive_window(&mut self, value: VarInt) -> &mut Self {
        self.stream_receive_window = value;
        self
    }
    /// Maximum number of bytes the peer may transmit across all streams of a connection before
    /// becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Larger values can be useful to allow maximum throughput within a
    /// stream while another is blocked.
    pub fn receive_window(&mut self, value: VarInt) -> &mut Self {
        self.receive_window = value;
        self
    }
    /// Maximum number of bytes to transmit to a peer without acknowledgment
    ///
    /// Provides an upper bound on memory when communicating with peers that issue large amounts of
    /// flow control credit. Endpoints that wish to handle large numbers of connections robustly
    /// should take care to set this low enough to guarantee memory exhaustion does not occur if
    /// every connection uses the entire window.
    pub fn send_window(&mut self, value: u64) -> &mut Self {
        self.send_window = value;
        self
    }
    /// Whether to implement fair queuing for send streams having the same priority.
    ///
    /// When enabled, connections schedule data from outgoing streams having the same priority in a
    /// round-robin fashion. When disabled, streams are scheduled in the order they are written to.
    ///
    /// Note that this only affects streams with the same priority. Higher priority streams always
    /// take precedence over lower priority streams.
    ///
    /// Disabling fairness can reduce fragmentation and protocol overhead for workloads that use
    /// many small streams.
    pub fn send_fairness(&mut self, value: bool) -> &mut Self {
        self.send_fairness = value;
        self
    }
    /// Maximum reordering in packet number space before FACK style loss detection considers a
    /// packet lost. Should not be less than 3, per RFC5681.
    pub fn packet_threshold(&mut self, value: u32) -> &mut Self {
        self.packet_threshold = value;
        self
    }
    /// Maximum reordering in time space before time based loss detection considers a packet lost,
    /// as a factor of RTT
    pub fn time_threshold(&mut self, value: f32) -> &mut Self {
        self.time_threshold = value;
        self
    }
    /// The RTT used before an RTT sample is taken
    pub fn initial_rtt(&mut self, value: Duration) -> &mut Self {
        self.initial_rtt = value;
        self
    }
    /// The initial value to be used as the maximum UDP payload size before running MTU discovery
    /// (see [`TransportConfig::mtu_discovery_config`]).
    ///
    /// Must be at least 1200, which is the default, and known to be safe for typical internet
    /// applications. Larger values are more efficient, but increase the risk of packet loss due to
    /// exceeding the network path's IP MTU. If the provided value is higher than what the network
    /// path actually supports, packet loss will eventually trigger black hole detection and bring
    /// it down to [`TransportConfig::min_mtu`].
    pub fn initial_mtu(&mut self, value: u16) -> &mut Self {
        // Silently clamp up to the 1200-byte protocol minimum
        self.initial_mtu = value.max(INITIAL_MTU);
        self
    }
    /// Effective initial MTU: never reports a value below the guaranteed minimum
    pub(crate) fn get_initial_mtu(&self) -> u16 {
        self.initial_mtu.max(self.min_mtu)
    }
    /// The maximum UDP payload size guaranteed to be supported by the network.
    ///
    /// Must be at least 1200, which is the default, and lower than or equal to
    /// [`TransportConfig::initial_mtu`].
    ///
    /// Real-world MTUs can vary according to ISP, VPN, and properties of intermediate network links
    /// outside of either endpoint's control. Extreme care should be used when raising this value
    /// outside of private networks where these factors are fully controlled. If the provided value
    /// is higher than what the network path actually supports, the result will be unpredictable and
    /// catastrophic packet loss, without a possibility of repair. Prefer
    /// [`TransportConfig::initial_mtu`] together with
    /// [`TransportConfig::mtu_discovery_config`] to set a maximum UDP payload size that robustly
    /// adapts to the network.
    pub fn min_mtu(&mut self, value: u16) -> &mut Self {
        // Silently clamp up to the 1200-byte protocol minimum
        self.min_mtu = value.max(INITIAL_MTU);
        self
    }
    /// Specifies the MTU discovery config (see [`MtuDiscoveryConfig`] for details).
    ///
    /// Enabled by default.
    pub fn mtu_discovery_config(&mut self, value: Option<MtuDiscoveryConfig>) -> &mut Self {
        self.mtu_discovery_config = value;
        self
    }
    /// Pad UDP datagrams carrying application data to current maximum UDP payload size
    ///
    /// Disabled by default. UDP datagrams containing loss probes are exempt from padding.
    ///
    /// Enabling this helps mitigate traffic analysis by network observers, but it increases
    /// bandwidth usage. Without this mitigation precise plain text size of application datagrams as
    /// well as the total size of stream write bursts can be inferred by observers under certain
    /// conditions. This analysis requires either an uncongested connection or application datagrams
    /// too large to be coalesced.
    pub fn pad_to_mtu(&mut self, value: bool) -> &mut Self {
        self.pad_to_mtu = value;
        self
    }
    /// Specifies the ACK frequency config (see [`AckFrequencyConfig`] for details)
    ///
    /// The provided configuration will be ignored if the peer does not support the acknowledgement
    /// frequency QUIC extension.
    ///
    /// Defaults to `None`, which disables controlling the peer's acknowledgement frequency. Even
    /// if set to `None`, the local side still supports the acknowledgement frequency QUIC
    /// extension and may use it in other ways.
    pub fn ack_frequency_config(&mut self, value: Option<AckFrequencyConfig>) -> &mut Self {
        self.ack_frequency_config = value;
        self
    }
    /// Number of consecutive PTOs after which network is considered to be experiencing persistent congestion.
    pub fn persistent_congestion_threshold(&mut self, value: u32) -> &mut Self {
        self.persistent_congestion_threshold = value;
        self
    }
    /// Period of inactivity before sending a keep-alive packet
    ///
    /// Keep-alive packets prevent an inactive but otherwise healthy connection from timing out.
    ///
    /// `None` to disable, which is the default. Only one side of any given connection needs keep-alive
    /// enabled for the connection to be preserved. Must be set lower than the idle_timeout of both
    /// peers to be effective.
    pub fn keep_alive_interval(&mut self, value: Option<Duration>) -> &mut Self {
        self.keep_alive_interval = value;
        self
    }
    /// Maximum quantity of out-of-order crypto layer data to buffer
    pub fn crypto_buffer_size(&mut self, value: usize) -> &mut Self {
        self.crypto_buffer_size = value;
        self
    }
    /// Whether the implementation is permitted to set the spin bit on this connection
    ///
    /// This allows passive observers to easily judge the round trip time of a connection, which can
    /// be useful for network administration but sacrifices a small amount of privacy.
    pub fn allow_spin(&mut self, value: bool) -> &mut Self {
        self.allow_spin = value;
        self
    }
    /// Maximum number of incoming application datagram bytes to buffer, or None to disable
    /// incoming datagrams
    ///
    /// The peer is forbidden to send single datagrams larger than this size. If the aggregate size
    /// of all datagrams that have been received from the peer but not consumed by the application
    /// exceeds this value, old datagrams are dropped until it is no longer exceeded.
    pub fn datagram_receive_buffer_size(&mut self, value: Option<usize>) -> &mut Self {
        self.datagram_receive_buffer_size = value;
        self
    }
    /// Maximum number of outgoing application datagram bytes to buffer
    ///
    /// While datagrams are sent ASAP, it is possible for an application to generate data faster
    /// than the link, or even the underlying hardware, can transmit them. This limits the amount of
    /// memory that may be consumed in that case. When the send buffer is full and a new datagram is
    /// sent, older datagrams are dropped until sufficient space is available.
    pub fn datagram_send_buffer_size(&mut self, value: usize) -> &mut Self {
        self.datagram_send_buffer_size = value;
        self
    }
    /// Whether to force every packet number to be used
    ///
    /// By default, packet numbers are occasionally skipped to ensure peers aren't ACKing packets
    /// before they see them.
    #[cfg(test)]
    pub(crate) fn deterministic_packet_numbers(&mut self, enabled: bool) -> &mut Self {
        self.deterministic_packet_numbers = enabled;
        self
    }
    /// How to construct new `congestion::Controller`s
    ///
    /// Typically the refcounted configuration of a `congestion::Controller`,
    /// e.g. a `congestion::NewRenoConfig`.
    ///
    /// # Example
    /// ```
    /// # use quinn_proto::*; use std::sync::Arc;
    /// let mut config = TransportConfig::default();
    /// config.congestion_controller_factory(Arc::new(congestion::NewRenoConfig::default()));
    /// ```
    pub fn congestion_controller_factory(
        &mut self,
        factory: Arc<dyn congestion::ControllerFactory + Send + Sync + 'static>,
    ) -> &mut Self {
        self.congestion_controller_factory = factory;
        self
    }
    /// Whether to use "Generic Segmentation Offload" to accelerate transmits, when supported by the
    /// environment
    ///
    /// Defaults to `true`.
    ///
    /// GSO dramatically reduces CPU consumption when sending large numbers of packets with the same
    /// headers, such as when transmitting bulk data on a connection. However, it is not supported
    /// by all network interface drivers or packet inspection tools. `quinn-udp` will attempt to
    /// disable GSO automatically when unavailable, but this can lead to spurious packet loss at
    /// startup, temporarily degrading performance.
    pub fn enable_segmentation_offload(&mut self, enabled: bool) -> &mut Self {
        self.enable_segmentation_offload = enabled;
        self
    }
    /// qlog capture configuration to use for a particular connection
    #[cfg(feature = "qlog")]
    pub fn qlog_stream(&mut self, stream: Option<QlogStream>) -> &mut Self {
        self.qlog_sink = stream.into();
        self
    }
}
impl Default for TransportConfig {
    /// Defaults tuned for a 100Mbps link with a 100ms round trip time (see the type docs)
    fn default() -> Self {
        const EXPECTED_RTT: u32 = 100; // ms
        const MAX_STREAM_BANDWIDTH: u32 = 12500 * 1000; // bytes/s
        // Window size needed to avoid pipeline stalls: bandwidth × expected RTT
        const STREAM_RWND: u32 = MAX_STREAM_BANDWIDTH / 1000 * EXPECTED_RTT;
        Self {
            max_concurrent_bidi_streams: 100u32.into(),
            max_concurrent_uni_streams: 100u32.into(),
            // 30 second default recommended by RFC 9308 § 3.2
            max_idle_timeout: Some(VarInt(30_000)),
            stream_receive_window: STREAM_RWND.into(),
            receive_window: VarInt::MAX,
            send_window: (8 * STREAM_RWND).into(),
            send_fairness: true,
            // Reordering thresholds per RFC 9002 § 6.1 recommendations
            packet_threshold: 3,
            time_threshold: 9.0 / 8.0,
            initial_rtt: Duration::from_millis(333), // per spec, intentionally distinct from EXPECTED_RTT
            initial_mtu: INITIAL_MTU,
            min_mtu: INITIAL_MTU,
            mtu_discovery_config: Some(MtuDiscoveryConfig::default()),
            pad_to_mtu: false,
            ack_frequency_config: None,
            persistent_congestion_threshold: 3,
            keep_alive_interval: None,
            crypto_buffer_size: 16 * 1024,
            allow_spin: true,
            datagram_receive_buffer_size: Some(STREAM_RWND as usize),
            datagram_send_buffer_size: 1024 * 1024,
            #[cfg(test)]
            deterministic_packet_numbers: false,
            congestion_controller_factory: Arc::new(congestion::CubicConfig::default()),
            enable_segmentation_offload: true,
            qlog_sink: QlogSink::default(),
        }
    }
}
impl fmt::Debug for TransportConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure exhaustively so that adding a field without updating
        // this impl becomes a compile error; fields bound to `_` are
        // intentionally excluded from the output.
        let Self {
            max_concurrent_bidi_streams,
            max_concurrent_uni_streams,
            max_idle_timeout,
            stream_receive_window,
            receive_window,
            send_window,
            send_fairness,
            packet_threshold,
            time_threshold,
            initial_rtt,
            initial_mtu,
            min_mtu,
            mtu_discovery_config,
            pad_to_mtu,
            ack_frequency_config,
            persistent_congestion_threshold,
            keep_alive_interval,
            crypto_buffer_size,
            allow_spin,
            datagram_receive_buffer_size,
            datagram_send_buffer_size,
            #[cfg(test)]
            deterministic_packet_numbers: _,
            congestion_controller_factory: _,
            enable_segmentation_offload,
            qlog_sink,
        } = self;
        let mut s = fmt.debug_struct("TransportConfig");
        s.field("max_concurrent_bidi_streams", max_concurrent_bidi_streams)
            .field("max_concurrent_uni_streams", max_concurrent_uni_streams)
            .field("max_idle_timeout", max_idle_timeout)
            .field("stream_receive_window", stream_receive_window)
            .field("receive_window", receive_window)
            .field("send_window", send_window)
            .field("send_fairness", send_fairness)
            .field("packet_threshold", packet_threshold)
            .field("time_threshold", time_threshold)
            .field("initial_rtt", initial_rtt)
            .field("initial_mtu", initial_mtu)
            .field("min_mtu", min_mtu)
            .field("mtu_discovery_config", mtu_discovery_config)
            .field("pad_to_mtu", pad_to_mtu)
            .field("ack_frequency_config", ack_frequency_config)
            .field(
                "persistent_congestion_threshold",
                persistent_congestion_threshold,
            )
            .field("keep_alive_interval", keep_alive_interval)
            .field("crypto_buffer_size", crypto_buffer_size)
            .field("allow_spin", allow_spin)
            .field("datagram_receive_buffer_size", datagram_receive_buffer_size)
            .field("datagram_send_buffer_size", datagram_send_buffer_size)
            // congestion_controller_factory not debug
            .field("enable_segmentation_offload", enable_segmentation_offload);
        // The sink itself is not printable; only report whether capture is on
        if cfg!(feature = "qlog") {
            s.field("qlog_stream", &qlog_sink.is_enabled());
        }
        s.finish_non_exhaustive()
    }
}
/// Parameters for controlling the peer's acknowledgement frequency
///
/// The parameters provided in this config will be sent to the peer at the beginning of the
/// connection, so it can take them into account when sending acknowledgements (see each parameter's
/// description for details on how it influences acknowledgement frequency).
///
/// Quinn's implementation follows the fourth draft of the
/// [QUIC Acknowledgement Frequency extension](https://datatracker.ietf.org/doc/html/draft-ietf-quic-ack-frequency-04).
/// The defaults produce behavior slightly different than the behavior without this extension,
/// because they change the way reordered packets are handled (see
/// [`AckFrequencyConfig::reordering_threshold`] for details).
#[derive(Clone, Debug)]
pub struct AckFrequencyConfig {
    /// Requested count of ack-eliciting packets the peer may receive without ACKing
    pub(crate) ack_eliciting_threshold: VarInt,
    /// Requested maximum ACK delay; `None` keeps the peer's own transport parameter
    pub(crate) max_ack_delay: Option<Duration>,
    /// Requested count of out-of-order packets that triggers an immediate ACK
    pub(crate) reordering_threshold: VarInt,
}
impl AckFrequencyConfig {
    /// The ack-eliciting threshold we will request the peer to use
    ///
    /// This threshold represents the number of ack-eliciting packets an endpoint may receive
    /// without immediately sending an ACK.
    ///
    /// The remote peer should send at least one ACK frame when more than this number of
    /// ack-eliciting packets have been received. A value of 0 results in a receiver immediately
    /// acknowledging every ack-eliciting packet.
    ///
    /// Defaults to 1, which sends ACK frames for every other ack-eliciting packet. (Reordered
    /// packets may still trigger ACKs sooner; see
    /// [`AckFrequencyConfig::reordering_threshold`].)
    pub fn ack_eliciting_threshold(&mut self, value: VarInt) -> &mut Self {
        self.ack_eliciting_threshold = value;
        self
    }
    /// The `max_ack_delay` we will request the peer to use
    ///
    /// This parameter represents the maximum amount of time that an endpoint waits before sending
    /// an ACK when the ack-eliciting threshold hasn't been reached.
    ///
    /// The effective `max_ack_delay` will be clamped to be at least the peer's `min_ack_delay`
    /// transport parameter, and at most the greater of the current path RTT or 25ms.
    ///
    /// Defaults to `None`, in which case the peer's original `max_ack_delay` will be used, as
    /// obtained from its transport parameters.
    pub fn max_ack_delay(&mut self, value: Option<Duration>) -> &mut Self {
        self.max_ack_delay = value;
        self
    }
    /// The reordering threshold we will request the peer to use
    ///
    /// This threshold represents the amount of out-of-order packets that will trigger an endpoint
    /// to send an ACK, without waiting for `ack_eliciting_threshold` to be exceeded or for
    /// `max_ack_delay` to be elapsed.
    ///
    /// A value of 0 indicates out-of-order packets do not elicit an immediate ACK. A value of 1
    /// immediately acknowledges any packets that are received out of order (this is also the
    /// behavior when the extension is disabled).
    ///
    /// It is recommended to set this value to [`TransportConfig::packet_threshold`] minus one.
    /// Since the default value for [`TransportConfig::packet_threshold`] is 3, this value defaults
    /// to 2.
    pub fn reordering_threshold(&mut self, value: VarInt) -> &mut Self {
        self.reordering_threshold = value;
        self
    }
}
impl Default for AckFrequencyConfig {
    fn default() -> Self {
        Self {
            // ACK every other ack-eliciting packet
            ack_eliciting_threshold: VarInt(1),
            // Keep the peer's own `max_ack_delay` transport parameter
            max_ack_delay: None,
            // Default `packet_threshold` (3) minus one, per the
            // `reordering_threshold` setter's recommendation
            reordering_threshold: VarInt(2),
        }
    }
}
/// Configuration for qlog trace logging
///
/// Each field is documented on its corresponding setter; build one with
/// `QlogConfig::default()` and convert it via [`QlogConfig::into_stream`].
#[cfg(feature = "qlog")]
pub struct QlogConfig {
    /// Destination for the qlog `TraceSeq`; `None` disables capture
    writer: Option<Box<dyn io::Write + Send + Sync>>,
    /// Optional title recorded in the capture
    title: Option<String>,
    /// Optional description recorded in the capture
    description: Option<String>,
    /// Epoch that event timestamps are recorded relative to
    start_time: Instant,
}
#[cfg(feature = "qlog")]
impl QlogConfig {
    /// Where to write a qlog `TraceSeq`
    pub fn writer(&mut self, writer: Box<dyn io::Write + Send + Sync>) -> &mut Self {
        self.writer = Some(writer);
        self
    }
    /// Title to record in the qlog capture
    pub fn title(&mut self, title: Option<String>) -> &mut Self {
        self.title = title;
        self
    }
    /// Description to record in the qlog capture
    pub fn description(&mut self, description: Option<String>) -> &mut Self {
        self.description = description;
        self
    }
    /// Epoch qlog event times are recorded relative to
    pub fn start_time(&mut self, start_time: Instant) -> &mut Self {
        self.start_time = start_time;
        self
    }
    /// Construct the [`QlogStream`] described by this configuration
    ///
    /// Returns `None` when no writer was configured, or when starting the log
    /// stream fails (in which case a warning is logged).
    pub fn into_stream(self) -> Option<QlogStream> {
        use tracing::warn;
        // Without a writer there is nowhere to stream events to
        let writer = self.writer?;
        let trace = qlog::TraceSeq::new(
            qlog::VantagePoint {
                name: None,
                ty: qlog::VantagePointType::Unknown,
                flow: None,
            },
            self.title.clone(),
            self.description.clone(),
            Some(qlog::Configuration {
                time_offset: Some(0.0),
                original_uris: None,
            }),
            None,
        );
        let mut streamer = QlogStreamer::new(
            qlog::QLOG_VERSION.into(),
            self.title,
            self.description,
            None,
            self.start_time,
            trace,
            qlog::events::EventImportance::Core,
            writer,
        );
        match streamer.start_log() {
            Ok(()) => Some(QlogStream(Arc::new(Mutex::new(streamer)))),
            Err(e) => {
                // Capture is best-effort: failure disables qlog rather than
                // propagating an error
                warn!("could not initialize endpoint qlog streamer: {e}");
                None
            }
        }
    }
}
#[cfg(feature = "qlog")]
impl Default for QlogConfig {
    /// Capture disabled (no writer); event times relative to construction time
    fn default() -> Self {
        Self {
            writer: None,
            title: None,
            description: None,
            start_time: Instant::now(),
        }
    }
}
/// Parameters governing MTU discovery.
///
/// # The why of MTU discovery
///
/// By design, QUIC ensures during the handshake that the network path between the client and the
/// server is able to transmit unfragmented UDP packets with a body of 1200 bytes. In other words,
/// once the connection is established, we know that the network path's maximum transmission unit
/// (MTU) is of at least 1200 bytes (plus IP and UDP headers). Because of this, a QUIC endpoint can
/// split outgoing data in packets of 1200 bytes, with confidence that the network will be able to
/// deliver them (if the endpoint were to send bigger packets, they could prove too big and end up
/// being dropped).
///
/// There is, however, a significant overhead associated to sending a packet. If the same
/// information can be sent in fewer packets, that results in higher throughput. The amount of
/// packets that need to be sent is inversely proportional to the MTU: the higher the MTU, the
/// bigger the packets that can be sent, and the fewer packets that are needed to transmit a given
/// amount of bytes.
///
/// Most networks have an MTU higher than 1200. Through MTU discovery, endpoints can detect the
/// path's MTU and, if it turns out to be higher, start sending bigger packets.
///
/// # MTU discovery internals
///
/// Quinn implements MTU discovery through DPLPMTUD (Datagram Packetization Layer Path MTU
/// Discovery), described in [section 14.3 of RFC
/// 9000](https://www.rfc-editor.org/rfc/rfc9000.html#section-14.3). This method consists of sending
/// QUIC packets padded to a particular size (called PMTU probes), and waiting to see if the remote
/// peer responds with an ACK. If an ACK is received, that means the probe arrived at the remote
/// peer, which in turn means that the network path's MTU is of at least the packet's size. If the
/// probe is lost, it is sent another 2 times before concluding that the MTU is lower than the
/// packet's size.
///
/// MTU discovery runs on a schedule (e.g. every 600 seconds) specified through
/// [`MtuDiscoveryConfig::interval`]. The first run happens right after the handshake, and
/// subsequent discoveries are scheduled to run when the interval has elapsed, starting from the
/// last time when MTU discovery completed.
///
/// Since the search space for MTUs is quite big (the smallest possible MTU is 1200, and the highest
/// is 65527), Quinn performs a binary search to keep the number of probes as low as possible. The
/// lower bound of the search is equal to [`TransportConfig::initial_mtu`] in the
/// initial MTU discovery run, and is equal to the currently discovered MTU in subsequent runs. The
/// upper bound is determined by the minimum of [`MtuDiscoveryConfig::upper_bound`] and the
/// `max_udp_payload_size` transport parameter received from the peer during the handshake.
///
/// # Black hole detection
///
/// If, at some point, the network path no longer accepts packets of the detected size, packet loss
/// will eventually trigger black hole detection and reset the detected MTU to 1200. In that case,
/// MTU discovery will be triggered after [`MtuDiscoveryConfig::black_hole_cooldown`] (ignoring the
/// timer that was set based on [`MtuDiscoveryConfig::interval`]).
///
/// # Interaction between peers
///
/// There is no guarantee that the MTU on the path between A and B is the same as the MTU of the
/// path between B and A. Therefore, each peer in the connection needs to run MTU discovery
/// independently in order to discover the path's MTU.
#[derive(Clone, Debug)]
pub struct MtuDiscoveryConfig {
    /// Wait between the completion of one discovery run and the start of the next
    pub(crate) interval: Duration,
    /// Largest UDP payload size the search will probe for
    pub(crate) upper_bound: u16,
    /// Smallest MTU change worth continuing the search for
    pub(crate) minimum_change: u16,
    /// Wait after black hole detection before re-running discovery
    pub(crate) black_hole_cooldown: Duration,
}
impl MtuDiscoveryConfig {
    /// Specifies the time to wait after completing MTU discovery before starting a new MTU
    /// discovery run.
    ///
    /// Defaults to 600 seconds, as recommended by [RFC
    /// 8899](https://www.rfc-editor.org/rfc/rfc8899).
    pub fn interval(&mut self, value: Duration) -> &mut Self {
        self.interval = value;
        self
    }
    /// Specifies the upper bound to the max UDP payload size that MTU discovery will search for.
    ///
    /// Defaults to 1452, to stay within Ethernet's MTU when using IPv4 and IPv6. The highest
    /// allowed value is 65527, which corresponds to the maximum permitted UDP payload on IPv6.
    ///
    /// It is safe to use an arbitrarily high upper bound, regardless of the network path's MTU. The
    /// only drawback is that MTU discovery might take more time to finish.
    pub fn upper_bound(&mut self, value: u16) -> &mut Self {
        // Silently clamp down to the largest permitted UDP payload
        self.upper_bound = value.min(MAX_UDP_PAYLOAD);
        self
    }
    /// Specifies the amount of time that MTU discovery should wait after a black hole was detected
    /// before running again. Defaults to one minute.
    ///
    /// Black hole detection can be spuriously triggered in case of congestion, so it makes sense to
    /// try MTU discovery again after a short period of time.
    pub fn black_hole_cooldown(&mut self, value: Duration) -> &mut Self {
        self.black_hole_cooldown = value;
        self
    }
    /// Specifies the minimum MTU change to stop the MTU discovery phase.
    /// Defaults to 20.
    pub fn minimum_change(&mut self, value: u16) -> &mut Self {
        self.minimum_change = value;
        self
    }
}
impl Default for MtuDiscoveryConfig {
    fn default() -> Self {
        Self {
            // RFC 8899 recommended re-discovery interval (see `interval`)
            interval: Duration::from_secs(600),
            // Ethernet-safe upper bound for IPv4 and IPv6 (see `upper_bound`)
            upper_bound: 1452,
            black_hole_cooldown: Duration::from_secs(60),
            minimum_change: 20,
        }
    }
}
/// Maximum duration of inactivity to accept before timing out the connection
///
/// This wraps an underlying [`VarInt`], representing the duration in milliseconds. Values can be
/// constructed by converting directly from `VarInt`, or using `TryFrom<Duration>`.
///
/// ```
/// # use std::{convert::TryFrom, time::Duration};
/// # use quinn_proto::{IdleTimeout, VarIntBoundsExceeded, VarInt};
/// # fn main() -> Result<(), VarIntBoundsExceeded> {
/// // A `VarInt`-encoded value in milliseconds
/// let timeout = IdleTimeout::from(VarInt::from_u32(10_000));
///
/// // Try to convert a `Duration` into a `VarInt`-encoded timeout
/// let timeout = IdleTimeout::try_from(Duration::from_secs(10))?;
/// # Ok(())
/// # }
/// ```
#[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct IdleTimeout(VarInt); // milliseconds
impl From<VarInt> for IdleTimeout {
    /// Interprets the value as a duration in milliseconds
    fn from(inner: VarInt) -> Self {
        Self(inner)
    }
}
impl std::convert::TryFrom<Duration> for IdleTimeout {
    type Error = VarIntBoundsExceeded;

    /// Converts to a millisecond-precision timeout, failing with
    /// [`VarIntBoundsExceeded`] when the millisecond count does not fit in a
    /// `VarInt`.
    fn try_from(timeout: Duration) -> Result<Self, Self::Error> {
        VarInt::try_from(timeout.as_millis()).map(Self)
    }
}
impl fmt::Debug for IdleTimeout {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the wrapped millisecond `VarInt`
        fmt::Debug::fmt(&self.0, f)
    }
}

105
vendor/quinn-proto/src/congestion.rs vendored Normal file
View File

@@ -0,0 +1,105 @@
//! Logic for controlling the rate at which data is sent
use crate::Instant;
use crate::connection::RttEstimator;
use std::any::Any;
use std::sync::Arc;
mod bbr;
mod cubic;
mod new_reno;
pub use bbr::{Bbr, BbrConfig};
pub use cubic::{Cubic, CubicConfig};
pub use new_reno::{NewReno, NewRenoConfig};
/// Common interface for different congestion controllers
pub trait Controller: Send + Sync {
    /// One or more packets were just sent
    ///
    /// `bytes` is the total size of the transmission; `last_packet_number` is
    /// the highest packet number it contained.
    #[allow(unused_variables)]
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {}
    /// Packet deliveries were confirmed
    ///
    /// `app_limited` indicates whether the connection was blocked on outgoing
    /// application data prior to receiving these acknowledgements.
    ///
    /// `sent` is when the acknowledged data was originally transmitted and
    /// `bytes` how much was acknowledged.
    #[allow(unused_variables)]
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
    }
    /// Packets are acked in batches, all with the same `now` argument. This indicates one of those batches has completed.
    #[allow(unused_variables)]
    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
    }
    /// Packets were deemed lost or marked congested
    ///
    /// `in_persistent_congestion` indicates whether all packets sent within the persistent
    /// congestion threshold period ending when the most recent packet in this batch was sent were
    /// lost.
    /// `lost_bytes` indicates how many bytes were lost. This value will be 0 for ECN triggers.
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        lost_bytes: u64,
    );
    /// The known MTU for the current network path has been updated
    fn on_mtu_update(&mut self, new_mtu: u16);
    /// Number of ack-eliciting bytes that may be in flight
    fn window(&self) -> u64;
    /// Retrieve implementation-specific metrics used to populate `qlog` traces when they are enabled
    ///
    /// The default reports only the congestion window; controllers that track a
    /// slow start threshold or pacing rate should override this.
    fn metrics(&self) -> ControllerMetrics {
        ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: None,
            pacing_rate: None,
        }
    }
    /// Duplicate the controller's state
    fn clone_box(&self) -> Box<dyn Controller>;
    /// Initial congestion window
    fn initial_window(&self) -> u64;
    /// Returns Self for use in down-casting to extract implementation details
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
/// Common congestion controller metrics
///
/// Optional fields are left `None` by controllers that do not track the
/// corresponding value (see the default [`Controller::metrics`]).
#[derive(Default)]
#[non_exhaustive]
pub struct ControllerMetrics {
    /// Congestion window (bytes)
    pub congestion_window: u64,
    /// Slow start threshold (bytes)
    pub ssthresh: Option<u64>,
    /// Pacing rate (bits/s)
    pub pacing_rate: Option<u64>,
}
/// Constructs controllers on demand
pub trait ControllerFactory {
    /// Construct a fresh `Controller`
    ///
    /// `now` is the current time and `current_mtu` the path MTU the controller
    /// should size its windows against.
    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller>;
}
const BASE_DATAGRAM_SIZE: u64 = 1200;

View File

@@ -0,0 +1,101 @@
use std::fmt::{Debug, Display, Formatter};
use super::min_max::MinMax;
use crate::{Duration, Instant};
/// Estimates delivery bandwidth as the min of recent send and ack rates,
/// tracked through a windowed max filter.
#[derive(Clone, Debug, Default)]
pub(crate) struct BandwidthEstimation {
    /// Total bytes acknowledged over the connection's lifetime
    total_acked: u64,
    /// `total_acked` as of the previous ack, for computing the ack rate
    prev_total_acked: u64,
    /// Time of the most recent ack
    acked_time: Option<Instant>,
    /// Time of the ack before that
    prev_acked_time: Option<Instant>,
    /// Total bytes sent over the connection's lifetime
    total_sent: u64,
    /// `total_sent` as of the previous send, for computing the send rate
    prev_total_sent: u64,
    /// Time of the most recent send
    sent_time: Option<Instant>,
    /// Time of the send before that
    prev_sent_time: Option<Instant>,
    /// Windowed maximum over bandwidth samples (bytes/s)
    max_filter: MinMax,
    /// `total_acked` when `end_acks` was last called
    acked_at_last_window: u64,
}
impl BandwidthEstimation {
    /// Record an outgoing transmission of `bytes` at `now`
    ///
    /// Retains the previous totals/timestamps so `on_ack` can compute a send
    /// rate over the most recent send interval.
    pub(crate) fn on_sent(&mut self, now: Instant, bytes: u64) {
        self.prev_total_sent = self.total_sent;
        self.total_sent += bytes;
        self.prev_sent_time = self.sent_time;
        self.sent_time = Some(now);
    }
    /// Record an acknowledgement of `bytes` at `now` during `round`
    ///
    /// Feeds `min(send rate, ack rate)` into the max filter, unless the sender
    /// was application-limited (such samples would underestimate capacity).
    pub(crate) fn on_ack(
        &mut self,
        now: Instant,
        _sent: Instant,
        bytes: u64,
        round: u64,
        app_limited: bool,
    ) {
        self.prev_total_acked = self.total_acked;
        self.total_acked += bytes;
        self.prev_acked_time = self.acked_time;
        self.acked_time = Some(now);
        let prev_sent_time = match self.prev_sent_time {
            Some(prev_sent_time) => prev_sent_time,
            // Need at least two sends before a send rate exists
            None => return,
        };
        let send_rate = match self.sent_time {
            Some(sent_time) if sent_time > prev_sent_time => Self::bw_from_delta(
                self.total_sent - self.prev_total_sent,
                sent_time - prev_sent_time,
            )
            .unwrap_or(0),
            _ => u64::MAX, // will take the min of send and ack, so this is just a skip
        };
        let ack_rate = match self.prev_acked_time {
            Some(prev_acked_time) => Self::bw_from_delta(
                self.total_acked - self.prev_total_acked,
                now - prev_acked_time,
            )
            .unwrap_or(0),
            None => 0,
        };
        let bandwidth = send_rate.min(ack_rate);
        if !app_limited && self.max_filter.get() < bandwidth {
            self.max_filter.update_max(round, bandwidth);
        }
    }
    /// Bytes acknowledged since `end_acks` was last called
    pub(crate) fn bytes_acked_this_window(&self) -> u64 {
        self.total_acked - self.acked_at_last_window
    }
    /// Close out the current ack window, starting a new accounting period
    pub(crate) fn end_acks(&mut self, _current_round: u64, _app_limited: bool) {
        self.acked_at_last_window = self.total_acked;
    }
    /// Current bandwidth estimate in bytes/second (0 until samples arrive)
    pub(crate) fn get_estimate(&self) -> u64 {
        self.max_filter.get()
    }
    /// Convert `bytes` delivered over `delta` into bytes/second
    ///
    /// Returns `None` when `delta` is zero; saturates at `u64::MAX` for
    /// extreme inputs. The multiplication is widened to `u128` because
    /// `bytes * 1_000_000_000` in `u64` wraps (or panics in debug builds)
    /// once `bytes` exceeds ~18.4 GB.
    pub(crate) const fn bw_from_delta(bytes: u64, delta: Duration) -> Option<u64> {
        let window_duration_ns = delta.as_nanos();
        if window_duration_ns == 0 {
            return None;
        }
        let bytes_per_second = (bytes as u128 * 1_000_000_000) / window_duration_ns;
        if bytes_per_second > u64::MAX as u128 {
            return Some(u64::MAX);
        }
        Some(bytes_per_second as u64)
    }
}
impl Display for BandwidthEstimation {
    /// Formats the current estimate as mebibytes per second, e.g. `1.234 MB/s`
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let rate = self.get_estimate() as f32 / (1024.0 * 1024.0);
        write!(f, "{rate:.3} MB/s")
    }
}

View File

@@ -0,0 +1,152 @@
/*
* Based on Google code released under BSD license here:
* https://groups.google.com/forum/#!topic/bbr-dev/3RTgkzi5ZD8
*/
/*
* Kathleen Nichols' algorithm for tracking the minimum (or maximum)
* value of a data stream over some fixed time interval. (E.g.,
* the minimum RTT over the past five minutes.) It uses constant
* space and constant time per update yet almost always delivers
* the same minimum as an implementation that has to keep all the
* data in the window.
*
* The algorithm keeps track of the best, 2nd best & 3rd best min
* values, maintaining an invariant that the measurement time of
* the n'th best >= n-1'th best. It also makes sure that the three
* values are widely separated in the time window since that bounds
* the worse case error when that data is monotonically increasing
* over the window.
*
* Upon getting a new min, we can forget everything earlier because
* it has no value - the new min is <= everything else in the window
* by definition and it samples the most recent. So we restart fresh on
* every new min and overwrites 2nd & 3rd choices. The same property
* holds for 2nd & 3rd best.
*/
use std::fmt::Debug;
/// Windowed maximum tracker keeping best, 2nd-best, and 3rd-best samples,
/// with the invariant that the n'th best's measurement round is >= the
/// (n-1)'th best's (see module header comment).
#[derive(Copy, Clone, Debug)]
pub(super) struct MinMax {
    /// round count, not a timestamp
    window: u64,
    /// samples[0] is the current best (maximum); samples[1..] are fallbacks
    samples: [MinMaxSample; 3],
}
impl MinMax {
    /// Current windowed maximum (0 until the first sample is recorded)
    pub(super) fn get(&self) -> u64 {
        self.samples[0].value
    }
    /// Overwrite best, 2nd-best, and 3rd-best with the same sample
    fn fill(&mut self, sample: MinMaxSample) {
        self.samples.fill(sample);
    }
    /// Forget all samples, as if no measurements had been made
    pub(super) fn reset(&mut self) {
        self.fill(Default::default())
    }
    /// update_min is also defined in the original source, but removed here since it is not used.
    ///
    /// Record `measurement` taken at `current_round`, updating the tracked maxima.
    pub(super) fn update_max(&mut self, current_round: u64, measurement: u64) {
        let sample = MinMaxSample {
            time: current_round,
            value: measurement,
        };
        if self.samples[0].value == 0 /* uninitialised */
            || /* found new max? */ sample.value >= self.samples[0].value
            || /* nothing left in window? */ sample.time - self.samples[2].time > self.window
        {
            self.fill(sample); /* forget earlier samples */
            return;
        }
        if sample.value >= self.samples[1].value {
            self.samples[2] = sample;
            self.samples[1] = sample;
        } else if sample.value >= self.samples[2].value {
            self.samples[2] = sample;
        }
        self.subwin_update(sample);
    }
    /* As time advances, update the 1st, 2nd, and 3rd choices. */
    fn subwin_update(&mut self, sample: MinMaxSample) {
        let dt = sample.time - self.samples[0].time;
        if dt > self.window {
            /*
             * Passed entire window without a new sample so make 2nd
             * choice the new sample & 3rd choice the new 2nd choice.
             * we may have to iterate this since our 2nd choice
             * may also be outside the window (we checked on entry
             * that the third choice was in the window).
             */
            self.samples[0] = self.samples[1];
            self.samples[1] = self.samples[2];
            self.samples[2] = sample;
            if sample.time - self.samples[0].time > self.window {
                self.samples[0] = self.samples[1];
                self.samples[1] = self.samples[2];
                self.samples[2] = sample;
            }
        } else if self.samples[1].time == self.samples[0].time && dt > self.window / 4 {
            /*
             * We've passed a quarter of the window without a new sample
             * so take a 2nd choice from the 2nd quarter of the window.
             */
            self.samples[2] = sample;
            self.samples[1] = sample;
        } else if self.samples[2].time == self.samples[1].time && dt > self.window / 2 {
            /*
             * We've passed half the window without finding a new sample
             * so take a 3rd choice from the last half of the window
             */
            self.samples[2] = sample;
        }
    }
}
impl Default for MinMax {
fn default() -> Self {
Self {
window: 10,
samples: [Default::default(); 3],
}
}
}
/// A single measurement, tagged with the round in which it was taken
#[derive(Debug, Copy, Clone, Default)]
struct MinMaxSample {
    /// round number, not a timestamp
    time: u64,
    /// the measured value (bytes/s in this crate's usage)
    value: u64,
}
#[cfg(test)]
mod test {
    use super::*;
    #[test]
    fn test() {
        let round = 25;
        let mut min_max = MinMax::default();
        // (rounds after `round`, measurement, expected filter output)
        let cases = [
            (1, 100, 100),
            (3, 120, 120),
            (5, 160, 160),
            (7, 100, 160),
            (10, 100, 160),
            (14, 100, 160),
            // 160 was recorded 11 rounds ago: it has aged out of the window
            (16, 100, 100),
            (18, 130, 130),
        ];
        for (offset, measurement, expected) in cases {
            min_max.update_max(round + offset, measurement);
            assert_eq!(expected, min_max.get());
        }
    }
}

View File

@@ -0,0 +1,650 @@
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
use rand::{Rng, SeedableRng};
use crate::congestion::ControllerMetrics;
use crate::congestion::bbr::bw_estimation::BandwidthEstimation;
use crate::congestion::bbr::min_max::MinMax;
use crate::connection::RttEstimator;
use crate::{Duration, Instant};
use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
mod bw_estimation;
mod min_max;
/// Experimental! Use at your own risk.
///
/// Aims for reduced buffer bloat and improved performance over high bandwidth-delay product networks.
/// Based on google's quiche implementation <https://source.chromium.org/chromium/chromium/src/+/master:net/third_party/quiche/src/quic/core/congestion_control/bbr_sender.cc>
/// of BBR <https://datatracker.ietf.org/doc/html/draft-cardwell-iccrg-bbr-congestion-control>.
/// More discussion and links at <https://groups.google.com/g/bbr-dev>.
#[derive(Debug, Clone)]
pub struct Bbr {
    config: Arc<BbrConfig>,
    /// Current path MTU in bytes
    current_mtu: u64,
    /// Windowed max filter over delivery-rate samples
    max_bandwidth: BandwidthEstimation,
    /// Total bytes acknowledged since this controller was created
    acked_bytes: u64,
    /// Current phase of the BBR state machine
    mode: Mode,
    /// Losses accumulated during the current ack batch; reset in `on_end_acks`
    loss_state: LossState,
    recovery_state: RecoveryState,
    /// Window cap applied while in recovery
    recovery_window: u64,
    /// Set once startup stops observing bandwidth growth
    is_at_full_bandwidth: bool,
    /// Multiplier applied to the bandwidth estimate to derive the pacing rate
    pacing_gain: f32,
    high_gain: f32,
    drain_gain: f32,
    /// Multiplier applied to the estimated BDP to derive the congestion window
    cwnd_gain: f32,
    high_cwnd_gain: f32,
    /// When the current ProbeBw gain-cycle phase began
    last_cycle_start: Option<Instant>,
    /// Index into `K_PACING_GAIN` for the current ProbeBw phase
    current_cycle_offset: u8,
    init_cwnd: u64,
    min_cwnd: u64,
    /// Bytes in flight as of the previous `on_end_acks` call
    prev_in_flight_count: u64,
    /// Scheduled end of ProbeRtt, set once in-flight data is small enough
    exit_probe_rtt_at: Option<Instant>,
    probe_rtt_last_started_at: Option<Instant>,
    min_rtt: Duration,
    exiting_quiescence: bool,
    /// Bytes per second; 0 until the first bandwidth sample
    pacing_rate: u64,
    max_acked_packet_number: u64,
    max_sent_packet_number: u64,
    /// Recovery ends once a packet sent after this number is acked without loss
    end_recovery_at_packet_number: u64,
    /// Congestion window in bytes
    cwnd: u64,
    /// Acking beyond this packet number starts a new round trip
    current_round_trip_end_packet_number: u64,
    round_count: u64,
    /// Bandwidth estimate at the last startup growth check
    bw_at_last_round: u64,
    /// Consecutive rounds without meaningful bandwidth growth
    round_wo_bw_gain: u64,
    ack_aggregation: AckAggregationState,
    random_number_generator: rand::rngs::StdRng,
}
impl Bbr {
    /// Construct a state using the given `config` and current time `now`
    pub fn new(config: Arc<BbrConfig>, current_mtu: u16) -> Self {
        let initial_window = config.initial_window;
        Self {
            config,
            current_mtu: current_mtu as u64,
            max_bandwidth: BandwidthEstimation::default(),
            acked_bytes: 0,
            mode: Mode::Startup,
            loss_state: Default::default(),
            recovery_state: RecoveryState::NotInRecovery,
            recovery_window: 0,
            is_at_full_bandwidth: false,
            pacing_gain: K_DEFAULT_HIGH_GAIN,
            high_gain: K_DEFAULT_HIGH_GAIN,
            drain_gain: 1.0 / K_DEFAULT_HIGH_GAIN,
            cwnd_gain: K_DEFAULT_HIGH_GAIN,
            high_cwnd_gain: K_DEFAULT_HIGH_GAIN,
            last_cycle_start: None,
            current_cycle_offset: 0,
            init_cwnd: initial_window,
            min_cwnd: calculate_min_window(current_mtu as u64),
            prev_in_flight_count: 0,
            exit_probe_rtt_at: None,
            probe_rtt_last_started_at: None,
            min_rtt: Default::default(),
            exiting_quiescence: false,
            pacing_rate: 0,
            max_acked_packet_number: 0,
            max_sent_packet_number: 0,
            end_recovery_at_packet_number: 0,
            cwnd: initial_window,
            current_round_trip_end_packet_number: 0,
            round_count: 0,
            bw_at_last_round: 0,
            round_wo_bw_gain: 0,
            ack_aggregation: AckAggregationState::default(),
            random_number_generator: rand::rngs::StdRng::from_os_rng(),
        }
    }
    /// Switch to Startup: probe aggressively with the high pacing/cwnd gains
    fn enter_startup_mode(&mut self) {
        self.mode = Mode::Startup;
        self.pacing_gain = self.high_gain;
        self.cwnd_gain = self.high_cwnd_gain;
    }
    /// Switch to ProbeBw and start its gain cycle at a randomized phase
    fn enter_probe_bandwidth_mode(&mut self, now: Instant) {
        self.mode = Mode::ProbeBw;
        self.cwnd_gain = K_DERIVED_HIGH_CWNDGAIN;
        self.last_cycle_start = Some(now);
        // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
        // excluded because in that case increased gain and decreased gain would not
        // follow each other.
        let mut rand_index = self
            .random_number_generator
            .random_range(0..K_PACING_GAIN.len() as u8 - 1);
        if rand_index >= 1 {
            rand_index += 1;
        }
        self.current_cycle_offset = rand_index;
        self.pacing_gain = K_PACING_GAIN[rand_index as usize];
    }
    /// Advance the recovery state machine based on this batch's losses
    fn update_recovery_state(&mut self, is_round_start: bool) {
        // Exit recovery when there are no losses for a round.
        if self.loss_state.has_losses() {
            self.end_recovery_at_packet_number = self.max_sent_packet_number;
        }
        match self.recovery_state {
            // Enter conservation on the first loss.
            RecoveryState::NotInRecovery if self.loss_state.has_losses() => {
                self.recovery_state = RecoveryState::Conservation;
                // This will cause the |recovery_window| to be set to the
                // correct value in CalculateRecoveryWindow().
                self.recovery_window = 0;
                // Since the conservation phase is meant to be lasting for a whole
                // round, extend the current round as if it were started right now.
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
            }
            RecoveryState::Growth | RecoveryState::Conservation => {
                if self.recovery_state == RecoveryState::Conservation && is_round_start {
                    self.recovery_state = RecoveryState::Growth;
                }
                // Exit recovery if appropriate.
                if !self.loss_state.has_losses()
                    && self.max_acked_packet_number > self.end_recovery_at_packet_number
                {
                    self.recovery_state = RecoveryState::NotInRecovery;
                }
            }
            _ => {}
        }
    }
    /// Advance the ProbeBw gain cycle when the current phase has run its course
    fn update_gain_cycle_phase(&mut self, now: Instant, in_flight: u64) {
        // In most cases, the cycle is advanced after an RTT passes.
        let mut should_advance_gain_cycling = self
            .last_cycle_start
            .map(|last_cycle_start| now.duration_since(last_cycle_start) > self.min_rtt)
            .unwrap_or(false);
        // If the pacing gain is above 1.0, the connection is trying to probe the
        // bandwidth by increasing the number of bytes in flight to at least
        // pacing_gain * BDP. Make sure that it actually reaches the target, as
        // long as there are no losses suggesting that the buffers are not able to
        // hold that much.
        if self.pacing_gain > 1.0
            && !self.loss_state.has_losses()
            && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain)
        {
            should_advance_gain_cycling = false;
        }
        // If pacing gain is below 1.0, the connection is trying to drain the extra
        // queue which could have been incurred by probing prior to it. If the
        // number of bytes in flight falls down to the estimated BDP value earlier,
        // conclude that the queue has been successfully drained and exit this cycle
        // early.
        if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) {
            should_advance_gain_cycling = true;
        }
        if should_advance_gain_cycling {
            self.current_cycle_offset = (self.current_cycle_offset + 1) % K_PACING_GAIN.len() as u8;
            self.last_cycle_start = Some(now);
            // Stay in low gain mode until the target BDP is hit. Low gain mode
            // will be exited immediately when the target BDP is achieved.
            if DRAIN_TO_TARGET
                && self.pacing_gain < 1.0
                && (K_PACING_GAIN[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON
                && in_flight > self.get_target_cwnd(1.0)
            {
                return;
            }
            self.pacing_gain = K_PACING_GAIN[self.current_cycle_offset as usize];
        }
    }
    /// Startup -> Drain once full bandwidth is reached; Drain -> ProbeBw once
    /// in-flight data falls to the estimated BDP
    fn maybe_exit_startup_or_drain(&mut self, now: Instant, in_flight: u64) {
        if self.mode == Mode::Startup && self.is_at_full_bandwidth {
            self.mode = Mode::Drain;
            self.pacing_gain = self.drain_gain;
            self.cwnd_gain = self.high_cwnd_gain;
        }
        if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) {
            self.enter_probe_bandwidth_mode(now);
        }
    }
    /// Whether the min RTT sample is stale (>10s old) and worth re-probing;
    /// never expires while application-limited
    fn is_min_rtt_expired(&self, now: Instant, app_limited: bool) -> bool {
        !app_limited
            && self
                .probe_rtt_last_started_at
                .map(|last| now.saturating_duration_since(last) > Duration::from_secs(10))
                .unwrap_or(true)
    }
    /// Enter ProbeRtt when the min RTT sample expires, and schedule/perform the
    /// exit once in-flight data has drained and ~200ms have elapsed
    fn maybe_enter_or_exit_probe_rtt(
        &mut self,
        now: Instant,
        is_round_start: bool,
        bytes_in_flight: u64,
        app_limited: bool,
    ) {
        let min_rtt_expired = self.is_min_rtt_expired(now, app_limited);
        if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt {
            self.mode = Mode::ProbeRtt;
            self.pacing_gain = 1.0;
            // Do not decide on the time to exit ProbeRtt until the
            // |bytes_in_flight| is at the target small value.
            self.exit_probe_rtt_at = None;
            self.probe_rtt_last_started_at = Some(now);
        }
        if self.mode == Mode::ProbeRtt {
            match self.exit_probe_rtt_at {
                None => {
                    // If the window has reached the appropriate size, schedule exiting
                    // ProbeRtt. The CWND during ProbeRtt is
                    // kMinimumCongestionWindow, but we allow an extra packet since QUIC
                    // checks CWND before sending a packet.
                    if bytes_in_flight < self.get_probe_rtt_cwnd() + self.current_mtu {
                        const K_PROBE_RTT_TIME: Duration = Duration::from_millis(200);
                        self.exit_probe_rtt_at = Some(now + K_PROBE_RTT_TIME);
                    }
                }
                Some(exit_time) if is_round_start && now >= exit_time => {
                    if !self.is_at_full_bandwidth {
                        self.enter_startup_mode();
                    } else {
                        self.enter_probe_bandwidth_mode(now);
                    }
                }
                Some(_) => {}
            }
        }
        self.exiting_quiescence = false;
    }
    /// gain * BDP in bytes, clamped to at least `min_cwnd`; `init_cwnd` before
    /// any bandwidth samples exist
    fn get_target_cwnd(&self, gain: f32) -> u64 {
        let bw = self.max_bandwidth.get_estimate();
        let bdp = self.min_rtt.as_micros() as u64 * bw;
        let bdpf = bdp as f64;
        let cwnd = ((gain as f64 * bdpf) / 1_000_000f64) as u64;
        // BDP estimate will be zero if no bandwidth samples are available yet.
        if cwnd == 0 {
            return self.init_cwnd;
        }
        cwnd.max(self.min_cwnd)
    }
    /// Window used while probing RTT: a fraction of BDP, or the absolute
    /// minimum depending on `PROBE_RTT_BASED_ON_BDP`
    fn get_probe_rtt_cwnd(&self) -> u64 {
        const K_MODERATE_PROBE_RTT_MULTIPLIER: f32 = 0.75;
        if PROBE_RTT_BASED_ON_BDP {
            return self.get_target_cwnd(K_MODERATE_PROBE_RTT_MULTIPLIER);
        }
        self.min_cwnd
    }
    /// Derive the pacing rate from the bandwidth estimate and current gain
    fn calculate_pacing_rate(&mut self) {
        let bw = self.max_bandwidth.get_estimate();
        if bw == 0 {
            return;
        }
        let target_rate = (bw as f64 * self.pacing_gain as f64) as u64;
        if self.is_at_full_bandwidth {
            self.pacing_rate = target_rate;
            return;
        }
        // Pace at the rate of initial_window / RTT as soon as RTT measurements are
        // available.
        if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 {
            self.pacing_rate =
                BandwidthEstimation::bw_from_delta(self.init_cwnd, self.min_rtt).unwrap();
            return;
        }
        // Do not decrease the pacing rate during startup.
        if self.pacing_rate < target_rate {
            self.pacing_rate = target_rate;
        }
    }
    /// Grow the congestion window towards gain * BDP, at most `bytes_acked`
    /// per call once at full bandwidth
    fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) {
        if self.mode == Mode::ProbeRtt {
            return;
        }
        let mut target_window = self.get_target_cwnd(self.cwnd_gain);
        if self.is_at_full_bandwidth {
            // Add the max recently measured ack aggregation to CWND.
            target_window += self.ack_aggregation.max_ack_height.get();
        } else {
            // Add the most recent excess acked. Because CWND never decreases in
            // STARTUP, this will automatically create a very localized max filter.
            target_window += excess_acked;
        }
        // Instead of immediately setting the target CWND as the new one, BBR grows
        // the CWND towards |target_window| by only increasing it |bytes_acked| at a
        // time.
        if self.is_at_full_bandwidth {
            self.cwnd = target_window.min(self.cwnd + bytes_acked);
        // NOTE(review): comparing `cwnd_gain` (a unitless gain, ~2-3) against
        // `target_window` (bytes) looks suspect; quiche's reference code compares
        // the current congestion window against the target here. Since
        // `target_window` is almost always larger than the gain, this branch
        // effectively always grows the window before full bandwidth is reached —
        // confirm against upstream before changing.
        } else if (self.cwnd_gain < target_window as f32) || (self.acked_bytes < self.init_cwnd) {
            // If the connection is not yet out of startup phase, do not decrease
            // the window.
            self.cwnd += bytes_acked;
        }
        // Enforce the limits on the congestion window.
        if self.cwnd < self.min_cwnd {
            self.cwnd = self.min_cwnd;
        }
    }
    /// Maintain the separate, smaller window used while in loss recovery
    fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) {
        if !self.recovery_state.in_recovery() {
            return;
        }
        // Set up the initial recovery window.
        if self.recovery_window == 0 {
            self.recovery_window = self.min_cwnd.max(in_flight + bytes_acked);
            return;
        }
        // Remove losses from the recovery window, while accounting for a potential
        // integer underflow.
        if self.recovery_window >= bytes_lost {
            self.recovery_window -= bytes_lost;
        } else {
            // k_max_segment_size = current_mtu
            self.recovery_window = self.current_mtu;
        }
        // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH,
        // release additional |bytes_acked| to achieve a slow-start-like behavior.
        if self.recovery_state == RecoveryState::Growth {
            self.recovery_window += bytes_acked;
        }
        // Sanity checks. Ensure that we always allow to send at least an MSS or
        // |bytes_acked| in response, whichever is larger.
        self.recovery_window = self
            .recovery_window
            .max(in_flight + bytes_acked)
            .max(self.min_cwnd);
    }
    /// <https://datatracker.ietf.org/doc/html/draft-cardwell-iccrg-bbr-congestion-control#section-4.3.2.2>
    fn check_if_full_bw_reached(&mut self, app_limited: bool) {
        if app_limited {
            return;
        }
        let target = (self.bw_at_last_round as f64 * K_STARTUP_GROWTH_TARGET as f64) as u64;
        let bw = self.max_bandwidth.get_estimate();
        if bw >= target {
            self.bw_at_last_round = bw;
            self.round_wo_bw_gain = 0;
            self.ack_aggregation.max_ack_height.reset();
            return;
        }
        self.round_wo_bw_gain += 1;
        if self.round_wo_bw_gain >= K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP as u64
            || (self.recovery_state.in_recovery())
        {
            self.is_at_full_bandwidth = true;
        }
    }
}
impl Controller for Bbr {
    /// Track the highest sent packet number and feed the send into the
    /// bandwidth estimator
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {
        self.max_sent_packet_number = last_packet_number;
        self.max_bandwidth.on_sent(now, bytes);
    }
    /// Record an ack sample and refresh the min RTT estimate
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        self.max_bandwidth
            .on_ack(now, sent, bytes, self.round_count, app_limited);
        self.acked_bytes += bytes;
        // Adopt a new min RTT sample when ours has expired or a lower one appears
        if self.is_min_rtt_expired(now, app_limited) || self.min_rtt > rtt.min() {
            self.min_rtt = rtt.min();
        }
    }
    /// End-of-batch processing: update the model (round trips, recovery,
    /// gain cycle, mode transitions), then recompute pacing rate and windows.
    /// The order of these steps mirrors quiche's BbrSender and is significant.
    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
        let bytes_acked = self.max_bandwidth.bytes_acked_this_window();
        let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes(
            bytes_acked,
            now,
            self.round_count,
            self.max_bandwidth.get_estimate(),
        );
        self.max_bandwidth.end_acks(self.round_count, app_limited);
        if let Some(largest_acked_packet) = largest_packet_num_acked {
            self.max_acked_packet_number = largest_acked_packet;
        }
        let mut is_round_start = false;
        if bytes_acked > 0 {
            is_round_start =
                self.max_acked_packet_number > self.current_round_trip_end_packet_number;
            if is_round_start {
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
                self.round_count += 1;
            }
        }
        self.update_recovery_state(is_round_start);
        if self.mode == Mode::ProbeBw {
            self.update_gain_cycle_phase(now, in_flight);
        }
        if is_round_start && !self.is_at_full_bandwidth {
            self.check_if_full_bw_reached(app_limited);
        }
        self.maybe_exit_startup_or_drain(now, in_flight);
        self.maybe_enter_or_exit_probe_rtt(now, is_round_start, in_flight, app_limited);
        // After the model is updated, recalculate the pacing rate and congestion window.
        self.calculate_pacing_rate();
        self.calculate_cwnd(bytes_acked, excess_acked);
        self.calculate_recovery_window(bytes_acked, self.loss_state.lost_bytes, in_flight);
        self.prev_in_flight_count = in_flight;
        self.loss_state.reset();
    }
    /// Accumulate losses; they are consumed by the next `on_end_acks`
    fn on_congestion_event(
        &mut self,
        _now: Instant,
        _sent: Instant,
        _is_persistent_congestion: bool,
        lost_bytes: u64,
    ) {
        self.loss_state.lost_bytes += lost_bytes;
    }
    /// Rescale MTU-derived window floors when the path MTU changes
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.min_cwnd = calculate_min_window(self.current_mtu);
        self.init_cwnd = self.config.initial_window.max(self.min_cwnd);
        self.cwnd = self.cwnd.max(self.min_cwnd);
    }
    /// Effective congestion window for the current mode
    fn window(&self) -> u64 {
        if self.mode == Mode::ProbeRtt {
            return self.get_probe_rtt_cwnd();
        } else if self.recovery_state.in_recovery() && self.mode != Mode::Startup {
            return self.cwnd.min(self.recovery_window);
        }
        self.cwnd
    }
    fn metrics(&self) -> ControllerMetrics {
        ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: None,
            // pacing_rate is tracked in bytes/s; metrics report bits/s
            pacing_rate: Some(self.pacing_rate * 8),
        }
    }
    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Configuration for the [`Bbr`] congestion controller
#[derive(Debug, Clone)]
pub struct BbrConfig {
    /// Initial congestion window in bytes
    initial_window: u64,
}
impl BbrConfig {
    /// Set the limit on the amount of outstanding data in bytes before any
    /// bandwidth samples are available.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }
}
impl Default for BbrConfig {
    /// Starts with an initial window of 200 base-sized (1200-byte) datagrams
    fn default() -> Self {
        let initial_window = BASE_DATAGRAM_SIZE * K_MAX_INITIAL_CONGESTION_WINDOW;
        Self { initial_window }
    }
}
impl ControllerFactory for BbrConfig {
    /// Instantiates a [`Bbr`] controller driven by this configuration
    fn build(self: Arc<Self>, _now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let controller = Bbr::new(self, current_mtu);
        Box::new(controller)
    }
}
/// Tracks how far ack arrivals run ahead of the expected delivery rate,
/// so the window can absorb stretch/aggregated acks
#[derive(Debug, Default, Copy, Clone)]
struct AckAggregationState {
    /// Windowed max of observed aggregation excess (bytes)
    max_ack_height: MinMax,
    /// Start of the current aggregation measurement epoch
    aggregation_epoch_start_time: Option<Instant>,
    /// Bytes acked since the current epoch began
    aggregation_epoch_bytes: u64,
}
impl AckAggregationState {
    /// Record `newly_acked_bytes` at `now` and return how many bytes exceed
    /// what `max_bandwidth` (bytes/s) predicts should have arrived this epoch;
    /// returns 0 when a new epoch is started instead.
    fn update_ack_aggregation_bytes(
        &mut self,
        newly_acked_bytes: u64,
        now: Instant,
        round: u64,
        max_bandwidth: u64,
    ) -> u64 {
        // Compute how many bytes are expected to be delivered, assuming max
        // bandwidth is correct.
        let expected_bytes_acked = max_bandwidth
            * now
                .saturating_duration_since(self.aggregation_epoch_start_time.unwrap_or(now))
                .as_micros() as u64
            / 1_000_000;
        // Reset the current aggregation epoch as soon as the ack arrival rate is
        // less than or equal to the max bandwidth.
        if self.aggregation_epoch_bytes <= expected_bytes_acked {
            // Reset to start measuring a new aggregation epoch.
            self.aggregation_epoch_bytes = newly_acked_bytes;
            self.aggregation_epoch_start_time = Some(now);
            return 0;
        }
        // Compute how many extra bytes were delivered vs max bandwidth.
        // Include the bytes most recently acknowledged to account for stretch acks.
        self.aggregation_epoch_bytes += newly_acked_bytes;
        let diff = self.aggregation_epoch_bytes - expected_bytes_acked;
        self.max_ack_height.update_max(round, diff);
        diff
    }
}
/// Phases of the BBR state machine
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Mode {
    /// Startup phase of the connection.
    Startup,
    /// After achieving the highest possible bandwidth during the startup, lower
    /// the pacing rate in order to drain the queue.
    Drain,
    /// Cruising mode.
    ProbeBw,
    /// Temporarily slow down sending in order to empty the buffer and measure
    /// the real minimum RTT.
    ProbeRtt,
}
/// Indicates how the congestion control limits the amount of bytes in flight.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum RecoveryState {
    /// Do not limit.
    NotInRecovery,
    /// Allow an extra outstanding byte for each byte acknowledged.
    Conservation,
    /// Allow two extra outstanding bytes for each byte acknowledged (slow
    /// start).
    Growth,
}
impl RecoveryState {
    /// Whether any recovery mode (conservation or growth) is active
    pub(super) fn in_recovery(&self) -> bool {
        match self {
            Self::NotInRecovery => false,
            Self::Conservation | Self::Growth => true,
        }
    }
}
/// Byte losses accumulated over one ack batch
#[derive(Debug, Clone, Default)]
struct LossState {
    /// Bytes declared lost or congestion-marked since the last `reset`
    lost_bytes: u64,
}
impl LossState {
    /// Clear the accumulated loss count ahead of the next batch
    pub(super) fn reset(&mut self) {
        self.lost_bytes = 0;
    }
    /// Whether any losses have been recorded since the last reset
    pub(super) fn has_losses(&self) -> bool {
        self.lost_bytes > 0
    }
}
/// Minimum congestion window: four full datagrams at the current MTU
fn calculate_min_window(current_mtu: u64) -> u64 {
    current_mtu * 4
}
// The gain used for the STARTUP, equal to 2/ln(2).
const K_DEFAULT_HIGH_GAIN: f32 = 2.885;
// The newly derived CWND gain for STARTUP, 2.
const K_DERIVED_HIGH_CWNDGAIN: f32 = 2.0;
// The cycle of gains used during the ProbeBw stage: one probing phase (1.25),
// one draining phase (0.75), then six cruise phases at 1.0.
const K_PACING_GAIN: [f32; 8] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
// Bandwidth must grow by at least this factor per round for startup to continue.
const K_STARTUP_GROWTH_TARGET: f32 = 1.25;
const K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP: u8 = 3;
// Do not allow initial congestion window to be greater than 200 packets.
const K_MAX_INITIAL_CONGESTION_WINDOW: u64 = 200;
// Whether ProbeRtt shrinks to a fraction of BDP rather than the absolute minimum window.
const PROBE_RTT_BASED_ON_BDP: bool = true;
// Whether low-gain ProbeBw phases persist until in-flight data drains to the target BDP.
const DRAIN_TO_TARGET: bool = true;

View File

@@ -0,0 +1,272 @@
use std::any::Any;
use std::cmp;
use std::sync::Arc;
use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
use crate::connection::RttEstimator;
use crate::{Duration, Instant};
/// CUBIC Constants.
///
/// These are recommended values in RFC8312.
///
/// Multiplicative decrease factor applied to the window on a congestion event.
const BETA_CUBIC: f64 = 0.7;
/// Scaling constant of the cubic growth function.
const C: f64 = 0.4;
/// CUBIC State Variables.
///
/// We need to keep those variables across the connection.
/// k, w_max are described in the RFC.
#[derive(Debug, Default, Clone)]
pub(super) struct State {
    // Time period (seconds) for the cubic function to grow back to w_max (Eq. 2)
    k: f64,
    // Window size (bytes) just before the last congestion event
    w_max: f64,
    // Store cwnd increment during congestion avoidance.
    cwnd_inc: u64,
}
/// CUBIC Functions.
///
/// Note that these calculations are based on a count of cwnd as bytes,
/// not packets.
/// Unit of t (duration) and RTT are based on seconds (f64).
impl State {
    // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2)
    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
        let mds = max_datagram_size as f64;
        let w_max = self.w_max / mds;
        (w_max * (1.0 - BETA_CUBIC) / C).cbrt()
    }
    // W_cubic(t) = C * (t - K)^3 - w_max (Eq. 1)
    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
        let mds = max_datagram_size as f64;
        let w_max = self.w_max / mds;
        let elapsed = t.as_secs_f64() - self.k;
        (C * elapsed.powi(3) + w_max) * mds
    }
    // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) *
    // (t / RTT) (Eq. 4)
    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
        let mds = max_datagram_size as f64;
        let w_max = self.w_max / mds;
        let reno_gain = 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC);
        (w_max * BETA_CUBIC + reno_gain * t.as_secs_f64() / rtt.as_secs_f64()) * mds
    }
}
/// The RFC8312 congestion controller, as widely used for TCP
#[derive(Debug, Clone)]
pub struct Cubic {
    config: Arc<CubicConfig>,
    /// Maximum number of bytes in flight that may be sent.
    window: u64,
    /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is
    /// slow start and the window grows by the number of bytes acknowledged.
    ssthresh: u64,
    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
    /// after this time is acknowledged, QUIC exits recovery.
    recovery_start_time: Option<Instant>,
    /// CUBIC-specific variables (k, w_max, pending window increment)
    cubic_state: State,
    /// Current path MTU in bytes, used as the MSS in window calculations
    current_mtu: u64,
}
impl Cubic {
/// Construct a state using the given `config` and current time `now`
pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
Self {
window: config.initial_window,
ssthresh: u64::MAX,
recovery_start_time: None,
config,
cubic_state: Default::default(),
current_mtu: current_mtu as u64,
}
}
fn minimum_window(&self) -> u64 {
2 * self.current_mtu
}
}
impl Controller for Cubic {
    /// Grow the window per RFC 8312: byte-counted slow start below ssthresh,
    /// cubic/TCP-friendly growth above it. Acks for packets sent before the
    /// current recovery period, and app-limited acks, are ignored.
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        if app_limited
            || self
                .recovery_start_time
                .map(|recovery_start_time| sent <= recovery_start_time)
                .unwrap_or(false)
        {
            return;
        }
        if self.window < self.ssthresh {
            // Slow start
            self.window += bytes;
        } else {
            // Congestion avoidance.
            let ca_start_time;
            match self.recovery_start_time {
                Some(t) => ca_start_time = t,
                None => {
                    // When we come here without congestion_event() triggered,
                    // initialize congestion_recovery_start_time, w_max and k.
                    ca_start_time = now;
                    self.recovery_start_time = Some(now);
                    self.cubic_state.w_max = self.window as f64;
                    self.cubic_state.k = 0.0;
                }
            }
            let t = now - ca_start_time;
            // w_cubic(t + rtt)
            let w_cubic = self.cubic_state.w_cubic(t + rtt.get(), self.current_mtu);
            // w_est(t)
            let w_est = self.cubic_state.w_est(t, rtt.get(), self.current_mtu);
            let mut cubic_cwnd = self.window;
            if w_cubic < w_est {
                // TCP friendly region.
                cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64);
            } else if cubic_cwnd < w_cubic as u64 {
                // Concave region or convex region use same increment.
                let cubic_inc =
                    (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64;
                cubic_cwnd += cubic_inc as u64;
            }
            // Update the increment and increase cwnd by MSS.
            self.cubic_state.cwnd_inc += cubic_cwnd - self.window;
            // cwnd_inc can be more than 1 MSS in the late stage of max probing.
            // however RFC9002 §7.3.3 (Congestion Avoidance) limits
            // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
            if self.cubic_state.cwnd_inc >= self.current_mtu {
                self.window += self.current_mtu;
                self.cubic_state.cwnd_inc = 0;
            }
        }
    }
    /// Multiplicatively decrease the window on loss (RFC 8312 §4.5/§4.6),
    /// with fast convergence; collapse to the minimum window under persistent
    /// congestion. Losses of packets sent before recovery began are ignored.
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        _lost_bytes: u64,
    ) {
        if self
            .recovery_start_time
            .map(|recovery_start_time| sent <= recovery_start_time)
            .unwrap_or(false)
        {
            return;
        }
        self.recovery_start_time = Some(now);
        // Fast convergence
        if (self.window as f64) < self.cubic_state.w_max {
            self.cubic_state.w_max = self.window as f64 * (1.0 + BETA_CUBIC) / 2.0;
        } else {
            self.cubic_state.w_max = self.window as f64;
        }
        self.ssthresh = cmp::max(
            (self.cubic_state.w_max * BETA_CUBIC) as u64,
            self.minimum_window(),
        );
        self.window = self.ssthresh;
        self.cubic_state.k = self.cubic_state.cubic_k(self.current_mtu);
        self.cubic_state.cwnd_inc = (self.cubic_state.cwnd_inc as f64 * BETA_CUBIC) as u64;
        if is_persistent_congestion {
            self.recovery_start_time = None;
            self.cubic_state.w_max = self.window as f64;
            // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC
            self.ssthresh = cmp::max(
                (self.window as f64 * BETA_CUBIC) as u64,
                self.minimum_window(),
            );
            self.cubic_state.cwnd_inc = 0;
            self.window = self.minimum_window();
        }
    }
    /// Keep the window at or above the MTU-derived minimum when the MTU changes
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.window = self.window.max(self.minimum_window());
    }
    fn window(&self) -> u64 {
        self.window
    }
    fn metrics(&self) -> super::ControllerMetrics {
        super::ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: Some(self.ssthresh),
            pacing_rate: None,
        }
    }
    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Configuration for the `Cubic` congestion controller
#[derive(Debug, Clone)]
pub struct CubicConfig {
    /// Initial congestion window in bytes; see [`CubicConfig::initial_window`]
    initial_window: u64,
}
impl CubicConfig {
    /// Set the default limit on the amount of outstanding data, in bytes.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }
}
impl Default for CubicConfig {
    fn default() -> Self {
        // 14720 bytes, bounded to [2, 10] base datagrams — the recommended
        // value documented on `initial_window`.
        let initial_window = 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE);
        Self { initial_window }
    }
}
impl ControllerFactory for CubicConfig {
    /// Build a boxed `Cubic` controller driven by this configuration
    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let controller = Cubic::new(self, now, current_mtu);
        Box::new(controller)
    }
}

View File

@@ -0,0 +1,172 @@
use std::any::Any;
use std::sync::Arc;
use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
use crate::Instant;
use crate::connection::RttEstimator;
/// A simple, standard congestion controller
#[derive(Debug, Clone)]
pub struct NewReno {
    /// Immutable tuning parameters shared with the factory
    config: Arc<NewRenoConfig>,
    /// Current datagram size in bytes, kept in sync via `on_mtu_update`
    current_mtu: u64,
    /// Maximum number of bytes in flight that may be sent.
    window: u64,
    /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is
    /// slow start and the window grows by the number of bytes acknowledged.
    ssthresh: u64,
    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
    /// after this time is acknowledged, QUIC exits recovery.
    recovery_start_time: Instant,
    /// Bytes which had been acked by the peer since leaving slow start
    bytes_acked: u64,
}
impl NewReno {
    /// Build the initial controller state from `config`, anchored at time `now`
    pub fn new(config: Arc<NewRenoConfig>, now: Instant, current_mtu: u16) -> Self {
        let window = config.initial_window;
        Self {
            config,
            current_mtu: u64::from(current_mtu),
            window,
            ssthresh: u64::MAX,
            recovery_start_time: now,
            bytes_acked: 0,
        }
    }

    /// Floor for the congestion window: two full datagrams
    fn minimum_window(&self) -> u64 {
        2 * self.current_mtu
    }
}
impl Controller for NewReno {
    /// Grow the window on an acknowledgement (slow start or congestion avoidance)
    fn on_ack(
        &mut self,
        _now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        _rtt: &RttEstimator,
    ) {
        // Ignore ACKs while application-limited, and ACKs for packets sent
        // before (or at) the start of the current recovery period.
        let sent_before_recovery = sent <= self.recovery_start_time;
        if app_limited || sent_before_recovery {
            return;
        }
        if self.window >= self.ssthresh {
            // Congestion avoidance, via Appropriate Byte Counting
            // (https://tools.ietf.org/html/rfc3465): no floating point math,
            // and the window grows by one datagram per round trip.
            self.bytes_acked += bytes;
            if self.bytes_acked >= self.window {
                self.bytes_acked -= self.window;
                self.window += self.current_mtu;
            }
            return;
        }
        // Slow start: the window grows by the number of acknowledged bytes
        self.window += bytes;
        if self.window >= self.ssthresh {
            // Exiting slow start. Seed `bytes_acked` with the overshoot past
            // `ssthresh` so those bytes already count towards the congestion
            // avoidance phase — independent of how close to `ssthresh` the
            // window was when switching states, and of datagram sizes.
            self.bytes_acked = self.window - self.ssthresh;
        }
    }

    /// Multiplicatively decrease the window on a loss event
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        _lost_bytes: u64,
    ) {
        // At most one reduction per recovery period
        if sent <= self.recovery_start_time {
            return;
        }
        self.recovery_start_time = now;
        let reduced = (self.window as f32 * self.config.loss_reduction_factor) as u64;
        self.window = reduced.max(self.minimum_window());
        self.ssthresh = self.window;
        if is_persistent_congestion {
            // Persistent congestion collapses the window entirely
            self.window = self.minimum_window();
        }
    }

    /// Adjust to a path MTU change; the window floor scales with the MTU
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = u64::from(new_mtu);
        self.window = self.window.max(self.minimum_window());
    }

    /// Current congestion window in bytes
    fn window(&self) -> u64 {
        self.window
    }

    /// Snapshot for instrumentation; NewReno reports no pacing rate
    fn metrics(&self) -> super::ControllerMetrics {
        super::ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: Some(self.ssthresh),
            pacing_rate: None,
        }
    }

    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }

    /// Initial window taken from this controller's configuration
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Configuration for the `NewReno` congestion controller
#[derive(Debug, Clone)]
pub struct NewRenoConfig {
    /// Initial congestion window in bytes; see [`NewRenoConfig::initial_window`]
    initial_window: u64,
    /// Multiplicative window reduction applied on loss; see
    /// [`NewRenoConfig::loss_reduction_factor`]
    loss_reduction_factor: f32,
}
impl NewRenoConfig {
    /// Set the default limit on the amount of outstanding data, in bytes.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }

    /// Set the reduction in congestion window when a new loss event is detected.
    pub fn loss_reduction_factor(&mut self, value: f32) -> &mut Self {
        self.loss_reduction_factor = value;
        self
    }
}
impl Default for NewRenoConfig {
    fn default() -> Self {
        // 14720 bytes, bounded to [2, 10] base datagrams — the recommended
        // value documented on `initial_window`; halve the window on loss.
        let initial_window = 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE);
        Self {
            initial_window,
            loss_reduction_factor: 0.5,
        }
    }
}
impl ControllerFactory for NewRenoConfig {
    /// Build a boxed `NewReno` controller driven by this configuration
    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let controller = NewReno::new(self, now, current_mtu);
        Box::new(controller)
    }
}

View File

@@ -0,0 +1,155 @@
use crate::Duration;
use crate::connection::spaces::PendingAcks;
use crate::frame::AckFrequency;
use crate::transport_parameters::TransportParameters;
use crate::{AckFrequencyConfig, TIMER_GRANULARITY, TransportError, VarInt};
/// State associated to ACK frequency
pub(super) struct AckFrequencyState {
    //
    // Sending ACK_FREQUENCY frames
    //
    /// The most recently sent, not-yet-acknowledged ACK_FREQUENCY frame, as
    /// (packet number, requested max ACK delay); at most one is in flight at a time
    in_flight_ack_frequency_frame: Option<(u64, Duration)>,
    /// Sequence number to use for the next outgoing ACK_FREQUENCY frame
    next_outgoing_sequence_number: VarInt,
    /// The max ACK delay we believe is currently in effect at the peer
    pub(super) peer_max_ack_delay: Duration,
    //
    // Receiving ACK_FREQUENCY frames
    //
    /// Highest sequence number seen among received ACK_FREQUENCY frames, used to
    /// ignore stale (reordered) frames
    last_ack_frequency_frame: Option<u64>,
    /// The max ACK delay we apply locally, as most recently requested by the peer
    pub(super) max_ack_delay: Duration,
}
impl AckFrequencyState {
    /// Initialize with `default_max_ack_delay` in force in both directions until
    /// ACK_FREQUENCY frames override it
    pub(super) fn new(default_max_ack_delay: Duration) -> Self {
        Self {
            in_flight_ack_frequency_frame: None,
            next_outgoing_sequence_number: VarInt(0),
            peer_max_ack_delay: default_max_ack_delay,
            last_ack_frequency_frame: None,
            max_ack_delay: default_max_ack_delay,
        }
    }

    /// Returns the `max_ack_delay` that should be requested of the peer when sending an
    /// ACK_FREQUENCY frame
    pub(super) fn candidate_max_ack_delay(
        &self,
        rtt: Duration,
        config: &AckFrequencyConfig,
        peer_params: &TransportParameters,
    ) -> Duration {
        // Use the peer's max_ack_delay if no custom max_ack_delay was provided in the config
        let min_ack_delay =
            Duration::from_micros(peer_params.min_ack_delay.map_or(0, |x| x.into()));
        // Clamp between the peer's advertised minimum and (roughly) one RTT
        config
            .max_ack_delay
            .unwrap_or(self.peer_max_ack_delay)
            .clamp(min_ack_delay, rtt.max(MIN_AUTOMATIC_ACK_DELAY))
    }

    /// Returns the `max_ack_delay` for the purposes of calculating the PTO
    ///
    /// This `max_ack_delay` is defined as the maximum of the peer's current `max_ack_delay` and all
    /// in-flight `max_ack_delay`s (i.e. proposed values that haven't been acknowledged yet, but
    /// might be already in use by the peer).
    pub(super) fn max_ack_delay_for_pto(&self) -> Duration {
        // Note: we have at most one in-flight ACK_FREQUENCY frame
        if let Some((_, max_ack_delay)) = self.in_flight_ack_frequency_frame {
            self.peer_max_ack_delay.max(max_ack_delay)
        } else {
            self.peer_max_ack_delay
        }
    }

    /// Returns the next sequence number for an ACK_FREQUENCY frame
    pub(super) fn next_sequence_number(&mut self) -> VarInt {
        // Guard against the counter having overflowed past the varint range
        assert!(self.next_outgoing_sequence_number <= VarInt::MAX);
        let seq = self.next_outgoing_sequence_number;
        self.next_outgoing_sequence_number.0 += 1;
        seq
    }

    /// Returns true if we should send an ACK_FREQUENCY frame
    pub(super) fn should_send_ack_frequency(
        &self,
        rtt: Duration,
        config: &AckFrequencyConfig,
        peer_params: &TransportParameters,
    ) -> bool {
        if self.next_outgoing_sequence_number.0 == 0 {
            // Always send at startup
            return true;
        }
        // Send only when the desired max ACK delay deviates from the pending
        // (or last acknowledged) one by more than MAX_RTT_ERROR, relatively
        let current = self
            .in_flight_ack_frequency_frame
            .map_or(self.peer_max_ack_delay, |(_, pending)| pending);
        let desired = self.candidate_max_ack_delay(rtt, config, peer_params);
        let error = (desired.as_secs_f32() / current.as_secs_f32()) - 1.0;
        error.abs() > MAX_RTT_ERROR
    }

    /// Notifies the [`AckFrequencyState`] that a packet containing an ACK_FREQUENCY frame was sent
    pub(super) fn ack_frequency_sent(&mut self, pn: u64, requested_max_ack_delay: Duration) {
        self.in_flight_ack_frequency_frame = Some((pn, requested_max_ack_delay));
    }

    /// Notifies the [`AckFrequencyState`] that a packet has been ACKed
    pub(super) fn on_acked(&mut self, pn: u64) {
        match self.in_flight_ack_frequency_frame {
            Some((number, requested_max_ack_delay)) if number == pn => {
                // Our proposal is now known to have reached the peer
                self.in_flight_ack_frequency_frame = None;
                self.peer_max_ack_delay = requested_max_ack_delay;
            }
            _ => {}
        }
    }

    /// Notifies the [`AckFrequencyState`] that an ACK_FREQUENCY frame was received
    ///
    /// Updates the endpoint's params according to the payload of the ACK_FREQUENCY frame, or
    /// returns an error in case the requested `max_ack_delay` is invalid.
    ///
    /// Returns `true` if the frame was processed and `false` if it was ignored because of being
    /// stale.
    pub(super) fn ack_frequency_received(
        &mut self,
        frame: &AckFrequency,
        pending_acks: &mut PendingAcks,
    ) -> Result<bool, TransportError> {
        // Drop frames that are older than (or equal to) the newest one seen so far
        if self
            .last_ack_frequency_frame
            .is_some_and(|highest_sequence_nr| frame.sequence.into_inner() <= highest_sequence_nr)
        {
            return Ok(false);
        }
        self.last_ack_frequency_frame = Some(frame.sequence.into_inner());
        // Update max_ack_delay
        let max_ack_delay = Duration::from_micros(frame.request_max_ack_delay.into_inner());
        if max_ack_delay < TIMER_GRANULARITY {
            return Err(TransportError::PROTOCOL_VIOLATION(
                "Requested Max Ack Delay in ACK_FREQUENCY frame is less than min_ack_delay",
            ));
        }
        self.max_ack_delay = max_ack_delay;
        // Update the rest of the params
        pending_acks.set_ack_frequency_params(frame);
        Ok(true)
    }
}
/// Maximum relative error between the most recently requested max ACK delay and the
/// currently desired one before a new request is sent, when the peer supports the ACK frequency
/// extension and an explicit max ACK delay is not configured.
const MAX_RTT_ERROR: f32 = 0.2;
/// Minimum value to request the peer set max ACK delay to when the peer supports the ACK frequency
/// extension and an explicit max ACK delay is not configured.
// Keep in sync with `AckFrequencyConfig::max_ack_delay` documentation
const MIN_AUTOMATIC_ACK_DELAY: Duration = Duration::from_millis(25);

View File

@@ -0,0 +1,658 @@
use std::{
cmp::Ordering,
collections::{BinaryHeap, binary_heap::PeekMut},
mem,
};
use bytes::{Buf, Bytes, BytesMut};
use crate::range_set::RangeSet;
/// Helper to assemble unordered stream frames into an ordered stream
#[derive(Debug, Default)]
pub(super) struct Assembler {
    /// Whether the stream is being consumed with ordered or unordered reads
    state: State,
    /// Buffered chunks; the `Buffer` ordering makes this a min-heap by offset
    /// (longest chunk first on ties)
    data: BinaryHeap<Buffer>,
    /// Total number of buffered bytes, including duplicates in ordered mode.
    buffered: usize,
    /// Estimated number of allocated bytes, will never be less than `buffered`.
    allocated: usize,
    /// Number of bytes read by the application. When only ordered reads have been used, this is the
    /// length of the contiguous prefix of the stream which has been consumed by the application,
    /// aka the stream offset.
    bytes_read: u64,
    /// Highest stream offset received so far (exclusive end of received data)
    end: u64,
}
impl Assembler {
    /// Create an empty assembler in ordered mode
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Reset to the initial state
    ///
    /// Retains the (cleared) heap so its allocation can be reused.
    pub(super) fn reinit(&mut self) {
        let old_data = mem::take(&mut self.data);
        *self = Self::default();
        self.data = old_data;
        self.data.clear();
    }

    /// Switch to unordered mode when `ordered` is false; fail if an ordered read
    /// is requested after unordered reads have already been used
    pub(super) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> {
        if ordered && !self.state.is_ordered() {
            return Err(IllegalOrderedRead);
        } else if !ordered && self.state.is_ordered() {
            // Enter unordered mode
            if !self.data.is_empty() {
                // Get rid of possible duplicates
                self.defragment();
            }
            // Seed the received-range set with everything consumed or buffered so far
            let mut recvd = RangeSet::new();
            recvd.insert(0..self.bytes_read);
            for chunk in &self.data {
                recvd.insert(chunk.offset..chunk.offset + chunk.bytes.len() as u64);
            }
            self.state = State::Unordered { recvd };
        }
        Ok(())
    }

    /// Get the next chunk
    ///
    /// Returns up to `max_length` bytes, or `None` if no data is currently
    /// readable (in ordered mode, when the next contiguous byte has not arrived).
    pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> Option<Chunk> {
        loop {
            let mut chunk = self.data.peek_mut()?;
            if ordered {
                if chunk.offset > self.bytes_read {
                    // Next chunk is after current read index
                    return None;
                } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read {
                    // Next chunk is useless as the read index is beyond its end
                    self.buffered -= chunk.bytes.len();
                    self.allocated -= chunk.allocation_size;
                    PeekMut::pop(chunk);
                    continue;
                }
                // Determine `start` and `len` of the slice of useful data in chunk
                let start = (self.bytes_read - chunk.offset) as usize;
                if start > 0 {
                    chunk.bytes.advance(start);
                    chunk.offset += start as u64;
                    self.buffered -= start;
                }
            }
            return Some(if max_length < chunk.bytes.len() {
                // Deliver the first `max_length` bytes; the rest stays buffered
                self.bytes_read += max_length as u64;
                let offset = chunk.offset;
                chunk.offset += max_length as u64;
                self.buffered -= max_length;
                Chunk::new(offset, chunk.bytes.split_to(max_length))
            } else {
                // Deliver the whole chunk and remove it from the heap
                self.bytes_read += chunk.bytes.len() as u64;
                self.buffered -= chunk.bytes.len();
                self.allocated -= chunk.allocation_size;
                let chunk = PeekMut::pop(chunk);
                Chunk::new(chunk.offset, chunk.bytes)
            });
        }
    }

    /// Copy fragmented chunk data to new chunks backed by a single buffer
    ///
    /// This makes sure we're not unnecessarily holding on to many larger allocations.
    /// We merge contiguous chunks in the process of doing so.
    fn defragment(&mut self) {
        let new = BinaryHeap::with_capacity(self.data.len());
        let old = mem::replace(&mut self.data, new);
        let mut buffers = old.into_sorted_vec();
        self.buffered = 0;
        let mut fragmented_buffered = 0;
        let mut offset = 0;
        // First pass (ascending offsets via `.rev()`): trim overlap, recount
        // buffered bytes, and measure how much data is still fragmented
        for chunk in buffers.iter_mut().rev() {
            chunk.try_mark_defragment(offset);
            let size = chunk.bytes.len();
            offset = chunk.offset + size as u64;
            self.buffered += size;
            if !chunk.defragmented {
                fragmented_buffered += size;
            }
        }
        self.allocated = self.buffered;
        // Second pass: copy each contiguous fragmented run into one allocation
        let mut buffer = BytesMut::with_capacity(fragmented_buffered);
        let mut offset = 0;
        for chunk in buffers.into_iter().rev() {
            if chunk.defragmented {
                // bytes might be empty after try_mark_defragment
                if !chunk.bytes.is_empty() {
                    self.data.push(chunk);
                }
                continue;
            }
            // Overlap is resolved by try_mark_defragment
            if chunk.offset != offset + (buffer.len() as u64) {
                // Discontinuity: flush the accumulated run as one defragmented buffer
                if !buffer.is_empty() {
                    self.data
                        .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
                }
                offset = chunk.offset;
            }
            buffer.extend_from_slice(&chunk.bytes);
        }
        if !buffer.is_empty() {
            self.data
                .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
        }
    }

    /// Buffer a received stream frame starting at `offset`
    ///
    /// `allocation_size` estimates the allocation backing `bytes`, used to bound
    /// over-allocation from shared, reference-counted packet buffers.
    // Note: If a packet contains many frames from the same stream, the estimated over-allocation
    // will be much higher because we are counting the same allocation multiple times.
    pub(super) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) {
        debug_assert!(
            bytes.len() <= allocation_size,
            "allocation_size less than bytes.len(): {:?} < {:?}",
            allocation_size,
            bytes.len()
        );
        self.end = self.end.max(offset + bytes.len() as u64);
        if let State::Unordered { ref mut recvd } = self.state {
            // Discard duplicate data
            for duplicate in recvd.replace(offset..offset + bytes.len() as u64) {
                if duplicate.start > offset {
                    // Buffer the novel prefix before this duplicate range
                    let buffer = Buffer::new(
                        offset,
                        bytes.split_to((duplicate.start - offset) as usize),
                        allocation_size,
                    );
                    self.buffered += buffer.bytes.len();
                    self.allocated += buffer.allocation_size;
                    self.data.push(buffer);
                    offset = duplicate.start;
                }
                bytes.advance((duplicate.end - offset) as usize);
                offset = duplicate.end;
            }
        } else if offset < self.bytes_read {
            if (offset + bytes.len() as u64) <= self.bytes_read {
                // Entirely behind the read cursor; nothing new here
                return;
            } else {
                // Trim the already-consumed prefix
                let diff = self.bytes_read - offset;
                offset += diff;
                bytes.advance(diff as usize);
            }
        }
        if bytes.is_empty() {
            return;
        }
        let buffer = Buffer::new(offset, bytes, allocation_size);
        self.buffered += buffer.bytes.len();
        self.allocated += buffer.allocation_size;
        self.data.push(buffer);
        // `self.buffered` also counts duplicate bytes, therefore we use
        // `self.end - self.bytes_read` as an upper bound of buffered unique
        // bytes. This will cause a defragmentation if the amount of duplicate
        // bytes exceeds a proportion of the receive window size.
        let buffered = self.buffered.min((self.end - self.bytes_read) as usize);
        let over_allocation = self.allocated - buffered;
        // Rationale: on the one hand, we want to defragment rarely, ideally never
        // in non-pathological scenarios. However, a pathological or malicious
        // peer could send us one-byte frames, and since we use reference-counted
        // buffers in order to prevent copying, this could result in keeping a lot
        // of memory allocated. This limits over-allocation in proportion to the
        // buffered data. The constants are chosen somewhat arbitrarily and try to
        // balance between defragmentation overhead and over-allocation.
        let threshold = 32768.max(buffered * 3 / 2);
        if over_allocation > threshold {
            self.defragment()
        }
    }

    /// Number of bytes consumed by the application
    pub(super) fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Discard all buffered data
    pub(super) fn clear(&mut self) {
        self.data.clear();
        self.buffered = 0;
        self.allocated = 0;
    }
}
/// A chunk of data from the receive stream
#[derive(Debug, PartialEq, Eq)]
pub struct Chunk {
    /// The offset in the stream
    pub offset: u64,
    /// The contents of the chunk
    pub bytes: Bytes,
}
impl Chunk {
    /// Bundle a stream offset with its payload
    fn new(offset: u64, bytes: Bytes) -> Self {
        Self { bytes, offset }
    }
}
#[derive(Debug, Eq)]
struct Buffer {
    /// Stream offset of the first byte in `bytes`
    offset: u64,
    /// The buffered payload
    bytes: Bytes,
    /// Size of the allocation behind `bytes`, if `defragmented == false`.
    /// Otherwise this will be set to `bytes.len()` by `try_mark_defragment`.
    /// Will never be less than `bytes.len()`.
    allocation_size: usize,
    /// Whether this buffer no longer counts towards over-allocation
    defragmented: bool,
}
impl Buffer {
    /// Construct a fragmented `Buffer` that still pins `allocation_size` bytes
    fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self {
        Self {
            offset,
            bytes,
            allocation_size,
            defragmented: false,
        }
    }

    /// Construct a defragmented `Buffer`; its allocation is exactly `bytes.len()`
    fn new_defragmented(offset: u64, bytes: Bytes) -> Self {
        Self {
            offset,
            allocation_size: bytes.len(),
            bytes,
            defragmented: true,
        }
    }

    /// Discard data before `offset` and flag `self` as defragmented if it has
    /// good utilization
    fn try_mark_defragment(&mut self, offset: u64) {
        let duplicate = offset.saturating_sub(self.offset) as usize;
        self.offset = self.offset.max(offset);
        if duplicate >= self.bytes.len() {
            // Everything here is duplicate data; release the allocation
            self.bytes = Bytes::new();
            self.defragmented = true;
            self.allocation_size = 0;
            return;
        }
        self.bytes.advance(duplicate);
        // A buffer whose payload covers at least ~5/6 of its allocation counts
        // as defragmented; once defragmented, a buffer stays defragmented.
        if !self.defragmented && self.bytes.len() * 6 / 5 >= self.allocation_size {
            self.defragmented = true;
        }
        if self.defragmented {
            // Defragmented buffers must not contribute to over-allocation
            self.allocation_size = self.bytes.len();
        }
    }
}
impl Ord for Buffer {
    /// Heap order: smallest offset pops first (the heap is a max-heap, so the
    /// offset comparison is flipped); ties prefer the longer chunk.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .offset
            .cmp(&self.offset)
            .then_with(|| self.bytes.len().cmp(&other.bytes.len()))
    }
}
impl PartialOrd for Buffer {
    /// Delegate to the total order defined by `Ord`
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl PartialEq for Buffer {
    /// Equality mirrors `Ord`: same offset and same payload length
    fn eq(&self, other: &Self) -> bool {
        self.offset == other.offset && self.bytes.len() == other.bytes.len()
    }
}
/// Read-mode of an `Assembler`; switching to unordered is one-way
#[derive(Debug, Default)]
enum State {
    /// Only ordered reads have been issued so far
    #[default]
    Ordered,
    Unordered {
        /// The set of offsets that have been received from the peer, including portions not yet
        /// read by the application.
        recvd: RangeSet,
    },
}
impl State {
    /// Whether the stream is still being consumed with ordered reads
    fn is_ordered(&self) -> bool {
        match self {
            Self::Ordered => true,
            Self::Unordered { .. } => false,
        }
    }
}
/// Error indicating that an ordered read was performed on a stream after an unordered read
///
/// Returned by `Assembler::ensure_ordering`; once unordered reads begin, the
/// mode switch is permanent.
#[derive(Debug)]
pub struct IllegalOrderedRead;
#[cfg(test)]
mod test {
    //! Exercises `Assembler` in ordered and unordered modes, including duplicate
    //! handling, overlap resolution, and explicit `defragment` calls.
    use super::*;
    use assert_matches::assert_matches;

    #[test]
    fn assemble_ordered() {
        let mut x = Assembler::new();
        assert_matches!(next(&mut x, 32), None);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1");
        assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23");
        x.insert(3, Bytes::from_static(b"456"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
        x.insert(6, Bytes::from_static(b"789"), 3);
        x.insert(9, Bytes::from_static(b"10"), 2);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_unordered() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(3, Bytes::from_static(b"456"), 3);
        assert_matches!(next(&mut x, 32), None);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_duplicate() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_duplicate_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_contained() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.insert(1, Bytes::from_static(b"234"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_contained_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_contains() {
        let mut x = Assembler::new();
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.insert(0, Bytes::from_static(b"12345"), 5);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_contains_compact() {
        let mut x = Assembler::new();
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_overlapping() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(1, Bytes::from_static(b"234"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_overlapping_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 4);
        x.insert(1, Bytes::from_static(b"234"), 4);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_complex() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1"), 1);
        x.insert(2, Bytes::from_static(b"3"), 1);
        x.insert(4, Bytes::from_static(b"5"), 1);
        x.insert(0, Bytes::from_static(b"123456"), 6);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_complex_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1"), 1);
        x.insert(2, Bytes::from_static(b"3"), 1);
        x.insert(4, Bytes::from_static(b"5"), 1);
        x.insert(0, Bytes::from_static(b"123456"), 6);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn assemble_old() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1234"), 4);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
        x.insert(0, Bytes::from_static(b"1234"), 4);
        assert_matches!(next(&mut x, 32), None);
    }

    #[test]
    fn compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 4);
        x.insert(3, Bytes::from_static(b"def"), 4);
        x.insert(9, Bytes::from_static(b"jkl"), 4);
        x.insert(12, Bytes::from_static(b"mno"), 4);
        x.defragment();
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"abcdef"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(9, Bytes::from_static(b"jklmno"))
        );
    }

    #[test]
    fn defrag_with_missing_prefix() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 3);
        x.defragment();
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
    }

    #[test]
    fn defrag_read_chunk() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 4);
        x.insert(0, Bytes::from_static(b"abc"), 4);
        x.insert(7, Bytes::from_static(b"hij"), 4);
        x.insert(11, Bytes::from_static(b"lmn"), 4);
        x.defragment();
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef");
        x.insert(5, Bytes::from_static(b"fghijklmn"), 9);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn");
        x.insert(13, Bytes::from_static(b"nopq"), 4);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq");
        x.insert(15, Bytes::from_static(b"pqrs"), 4);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs");
        assert_matches!(x.read(usize::MAX, true), None);
    }

    #[test]
    fn unordered_happy_path() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"abc"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
    }

    #[test]
    fn unordered_dedup() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(0, Bytes::from_static(b"a"), 1);
        x.insert(0, Bytes::from_static(b"abcdefghi"), 9);
        x.insert(0, Bytes::from_static(b"abcd"), 4);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"a"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(1, Bytes::from_static(b"bc"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(6, Bytes::from_static(b"ghi"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(8, Bytes::from_static(b"ijkl"), 4);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(9, Bytes::from_static(b"jkl"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(12, Bytes::from_static(b"mno"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(12, Bytes::from_static(b"mno"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(2, Bytes::from_static(b"cde"), 3);
        assert_eq!(x.read(usize::MAX, false), None);
    }

    #[test]
    fn chunks_dedup() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(0, Bytes::from_static(b"a"), 1);
        x.insert(1, Bytes::from_static(b"bcdefghi"), 9);
        x.insert(0, Bytes::from_static(b"abcd"), 4);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(0, Bytes::from_static(b"abcd")))
        );
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(4, Bytes::from_static(b"efghi")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(8, Bytes::from_static(b"ijkl"), 4);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(9, Bytes::from_static(b"jkl")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(12, Bytes::from_static(b"mno"), 3);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(12, Bytes::from_static(b"mno")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(2, Bytes::from_static(b"cde"), 3);
        assert_eq!(x.read(usize::MAX, true), None);
    }

    #[test]
    fn ordered_eager_discard() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        assert_eq!(x.data.len(), 1);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(0, Bytes::from_static(b"abc")))
        );
        x.insert(0, Bytes::from_static(b"ab"), 2);
        assert_eq!(x.data.len(), 0);
        x.insert(2, Bytes::from_static(b"cd"), 2);
        assert_eq!(
            x.data.peek(),
            Some(&Buffer::new(3, Bytes::from_static(b"d"), 2))
        );
    }

    #[test]
    fn ordered_insert_unordered_read() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        x.insert(0, Bytes::from_static(b"abc"), 3);
        x.ensure_ordering(false).unwrap();
        assert_eq!(
            x.read(3, false),
            Some(Chunk::new(0, Bytes::from_static(b"abc")))
        );
        assert_eq!(x.read(3, false), None);
    }

    /// Read the next unordered chunk, panicking if none is available
    fn next_unordered(x: &mut Assembler) -> Chunk {
        x.read(usize::MAX, false).unwrap()
    }

    /// Read up to `size` bytes of the next ordered chunk
    fn next(x: &mut Assembler, size: usize) -> Option<Bytes> {
        x.read(size, true).map(|chunk| chunk.bytes)
    }
}

View File

@@ -0,0 +1,223 @@
//! Maintain the state of local connection IDs
use std::collections::VecDeque;
use rustc_hash::FxHashSet;
use tracing::{debug, trace};
use crate::{Duration, Instant, TransportError, shared::IssuedCid};
/// Local connection ID management
pub(super) struct CidState {
    /// Timestamp when issued cids should be retired
    retire_timestamp: VecDeque<CidTimestamp>,
    /// Number of local connection IDs that have been issued in NEW_CONNECTION_ID frames.
    issued: u64,
    /// Sequence numbers of local connection IDs not yet retired by the peer
    active_seq: FxHashSet<u64>,
    /// Sequence number the peer has already retired all CIDs below at our request via `retire_prior_to`
    prev_retire_seq: u64,
    /// Sequence number to set in retire_prior_to field in NEW_CONNECTION_ID frame
    retire_seq: u64,
    /// Length in bytes of locally issued cids, used to decode short packets
    cid_len: usize,
    /// Lifetime after which an issued cid should be retired, if any
    cid_lifetime: Option<Duration>,
}
impl CidState {
/// Create a tracker for locally issued connection IDs
///
/// The `issued` CIDs with sequence numbers `0..issued` are assumed to have
/// already been handed to the peer during the handshake; they are entered
/// into the active set and their lifetimes are tracked starting at `now`.
pub(crate) fn new(
cid_len: usize,
cid_lifetime: Option<Duration>,
now: Instant,
issued: u64,
) -> Self {
let mut active_seq = FxHashSet::default();
// Add sequence number of CIDs used in handshaking into tracking set
for seq in 0..issued {
active_seq.insert(seq);
}
let mut this = Self {
retire_timestamp: VecDeque::new(),
issued,
active_seq,
prev_retire_seq: 0,
retire_seq: 0,
cid_len,
cid_lifetime,
};
// Track lifetime of CIDs used in handshaking
for seq in 0..issued {
this.track_lifetime(seq, now);
}
this
}
/// Find the next timestamp when previously issued CID should be retired
pub(crate) fn next_timeout(&mut self) -> Option<Instant> {
self.retire_timestamp.front().map(|nc| {
trace!("CID {} will expire at {:?}", nc.sequence, nc.timestamp);
nc.timestamp
})
}
/// Track the lifetime of issued cids in `retire_timestamp`
///
/// No-op when `cid_lifetime` is unset or the expiry time would overflow
/// `Instant`. Consecutive CIDs expiring at the same instant are merged into
/// a single record holding the batch's highest sequence number.
fn track_lifetime(&mut self, new_cid_seq: u64, now: Instant) {
let lifetime = match self.cid_lifetime {
Some(lifetime) => lifetime,
None => return,
};
let expire_timestamp = now.checked_add(lifetime);
let expire_at = match expire_timestamp {
Some(expire_at) => expire_at,
None => return,
};
let last_record = self.retire_timestamp.back_mut();
if let Some(last) = last_record {
// Compare the timestamp with the last inserted record
// Combine into a single batch if timestamp of current cid is same as the last record
if expire_at == last.timestamp {
debug_assert!(new_cid_seq > last.sequence);
last.sequence = new_cid_seq;
return;
}
}
self.retire_timestamp.push_back(CidTimestamp {
sequence: new_cid_seq,
timestamp: expire_at,
});
}
/// Update local CID state when previously issued CID is retired
///
/// Return whether a new CID needs to be pushed that notifies remote peer to respond `RETIRE_CONNECTION_ID`
pub(crate) fn on_cid_timeout(&mut self) -> bool {
// Whether the peer hasn't retired all the CIDs we asked it to yet
let unretired_ids_found =
(self.prev_retire_seq..self.retire_seq).any(|seq| self.active_seq.contains(&seq));
let current_retire_prior_to = self.retire_seq;
// One past the highest sequence number in the expired batch, if any
let next_retire_sequence = self
.retire_timestamp
.pop_front()
.map(|seq| seq.sequence + 1);
// According to RFC:
// Endpoints SHOULD NOT issue updates of the Retire Prior To field
// before receiving RETIRE_CONNECTION_ID frames that retire all
// connection IDs indicated by the previous Retire Prior To value.
// https://tools.ietf.org/html/draft-ietf-quic-transport-29#section-5.1.2
if !unretired_ids_found {
// All Cids are retired, `prev_retire_cid_seq` can be assigned to `retire_cid_seq`
self.prev_retire_seq = self.retire_seq;
// Advance `retire_seq` if next cid that needs to be retired exists
if let Some(next_retire_prior_to) = next_retire_sequence {
self.retire_seq = next_retire_prior_to;
}
}
// Check if retirement of all CIDs that reach their lifetime is still needed
// According to RFC:
// An endpoint MUST NOT
// provide more connection IDs than the peer's limit. An endpoint MAY
// send connection IDs that temporarily exceed a peer's limit if the
// NEW_CONNECTION_ID frame also requires the retirement of any excess,
// by including a sufficiently large value in the Retire Prior To field.
//
// If yes (return true), a new CID must be pushed with updated `retire_prior_to` field to remote peer.
// If no (return false), it means CIDs that reach the end of lifetime have been retired already. Do not push a new CID in order to avoid violating above RFC.
(current_retire_prior_to..self.retire_seq).any(|seq| self.active_seq.contains(&seq))
}
/// Update cid state when `NewIdentifiers` event is received
pub(crate) fn new_cids(&mut self, ids: &[IssuedCid], now: Instant) {
// `ids` may be empty, e.g. once active_connection_id_limit is set to 1 by the peer
let last_cid = match ids.last() {
Some(cid) => cid,
None => return,
};
self.issued += ids.len() as u64;
// Record the timestamp of CID with the largest seq number
let sequence = last_cid.sequence;
ids.iter().for_each(|frame| {
self.active_seq.insert(frame.sequence);
});
self.track_lifetime(sequence, now);
}
/// Update CidState for receipt of a `RETIRE_CONNECTION_ID` frame
///
/// Returns whether a new CID can be issued, or an error if the frame was illegal.
pub(crate) fn on_cid_retirement(
&mut self,
sequence: u64,
limit: u64,
) -> Result<bool, TransportError> {
if self.cid_len == 0 {
return Err(TransportError::PROTOCOL_VIOLATION(
"RETIRE_CONNECTION_ID when CIDs aren't in use",
));
}
// NOTE(review): sequences issued so far are `0..issued`, so `sequence == issued` has
// never been issued either, yet it passes this check — confirm whether the slack is
// intentional
if sequence > self.issued {
debug!(
sequence,
"got RETIRE_CONNECTION_ID for unissued sequence number"
);
return Err(TransportError::PROTOCOL_VIOLATION(
"RETIRE_CONNECTION_ID for unissued sequence number",
));
}
self.active_seq.remove(&sequence);
// Consider a scenario where peer A has active remote cid 0,1,2.
// Peer B first send a NEW_CONNECTION_ID with cid 3 and retire_prior_to set to 1.
// Peer A processes this NEW_CONNECTION_ID frame; update remote cid to 1,2,3
// and meanwhile send a RETIRE_CONNECTION_ID to retire cid 0 to peer B.
// If peer B doesn't check the cid limit here and send a new cid again, peer A will then face CONNECTION_ID_LIMIT_ERROR
Ok(limit > self.active_seq.len() as u64)
}
/// Length of local Connection IDs
pub(crate) fn cid_len(&self) -> usize {
self.cid_len
}
/// The value for `retire_prior_to` field in `NEW_CONNECTION_ID` frame
pub(crate) fn retire_prior_to(&self) -> u64 {
self.retire_seq
}
/// (min, max) of the active sequence numbers; yields (u64::MAX, u64::MIN)
/// when the active set is empty
#[cfg(test)]
pub(crate) fn active_seq(&self) -> (u64, u64) {
let mut min = u64::MAX;
let mut max = u64::MIN;
for n in self.active_seq.iter() {
if n < &min {
min = *n;
}
if n > &max {
max = *n;
}
}
(min, max)
}
/// Force `retire_seq` to `v`, returning how many additional CIDs that retires
#[cfg(test)]
pub(crate) fn assign_retire_seq(&mut self, v: u64) -> u64 {
// Cannot retire more CIDs than what have been issued
debug_assert!(v <= *self.active_seq.iter().max().unwrap() + 1);
let n = v.checked_sub(self.retire_seq).unwrap();
self.retire_seq = v;
n
}
}
/// Data structure that records when issued cids should be retired
#[derive(Copy, Clone, Eq, PartialEq)]
struct CidTimestamp {
/// Highest cid sequence number created in a batch (consecutive CIDs that
/// expire at the same instant share a single record)
sequence: u64,
/// Timestamp when cid needs to be retired
timestamp: Instant,
}

View File

@@ -0,0 +1,211 @@
use std::collections::VecDeque;
use bytes::Bytes;
use thiserror::Error;
use tracing::{debug, trace};
use super::Connection;
use crate::{
TransportError,
frame::{Datagram, FrameStruct},
};
/// API to control datagram traffic
pub struct Datagrams<'a> {
/// The connection whose datagram queues these operations act on
pub(super) conn: &'a mut Connection,
}
impl Datagrams<'_> {
/// Queue an unreliable, unordered datagram for immediate transmission
///
/// If `drop` is true, previously queued datagrams which are still unsent may be discarded to
/// make space for this datagram, in order of oldest to newest. If `drop` is false, and there
/// isn't enough space due to previously queued datagrams, this function will return
/// `SendDatagramError::Blocked`. `Event::DatagramsUnblocked` will be emitted once datagrams
/// have been sent.
///
/// Returns `Err` iff a `len`-byte datagram cannot currently be sent.
pub fn send(&mut self, data: Bytes, drop: bool) -> Result<(), SendDatagramError> {
// Local datagram support is enabled iff a receive buffer size is configured
if self.conn.config.datagram_receive_buffer_size.is_none() {
return Err(SendDatagramError::Disabled);
}
let max = self
.max_size()
.ok_or(SendDatagramError::UnsupportedByPeer)?;
if data.len() > max {
return Err(SendDatagramError::TooLarge);
}
if drop {
// Evict the oldest unsent datagrams until the buffered total is back within the
// configured limit (the new datagram itself may still push the total over it)
while self.conn.datagrams.outgoing_total > self.conn.config.datagram_send_buffer_size {
let prev = self
.conn
.datagrams
.outgoing
.pop_front()
.expect("datagrams.outgoing_total desynchronized");
trace!(len = prev.data.len(), "dropping outgoing datagram");
self.conn.datagrams.outgoing_total -= prev.data.len();
}
} else if self.conn.datagrams.outgoing_total + data.len()
> self.conn.config.datagram_send_buffer_size
{
// Caller asked us not to drop: flag the blockage (see method doc regarding
// `Event::DatagramsUnblocked`) and hand the datagram back
self.conn.datagrams.send_blocked = true;
return Err(SendDatagramError::Blocked(data));
}
self.conn.datagrams.outgoing_total += data.len();
self.conn.datagrams.outgoing.push_back(Datagram { data });
Ok(())
}
/// Compute the maximum size of datagrams that may passed to `send_datagram`
///
/// Returns `None` if datagrams are unsupported by the peer or disabled locally.
///
/// This may change over the lifetime of a connection according to variation in the path MTU
/// estimate. The peer can also enforce an arbitrarily small fixed limit, but if the peer's
/// limit is large this is guaranteed to be a little over a kilobyte at minimum.
///
/// Not necessarily the maximum size of received datagrams.
pub fn max_size(&self) -> Option<usize> {
// We use the conservative overhead bound for any packet number, reducing the budget by at
// most 3 bytes, so that PN size fluctuations don't cause users sending maximum-size
// datagrams to suffer avoidable packet loss.
// NOTE(review): assumes the current MTU always exceeds the 1-RTT overhead plus the frame
// size bound; otherwise this unchecked subtraction would underflow — confirm invariant
let max_size = self.conn.path.current_mtu() as usize
- self.conn.predict_1rtt_overhead(None)
- Datagram::SIZE_BOUND;
// `max_datagram_frame_size` is absent when the peer advertised no datagram support
let limit = self
.conn
.peer_params
.max_datagram_frame_size?
.into_inner()
.saturating_sub(Datagram::SIZE_BOUND as u64);
Some(limit.min(max_size as u64) as usize)
}
/// Receive an unreliable, unordered datagram
pub fn recv(&mut self) -> Option<Bytes> {
self.conn.datagrams.recv()
}
/// Bytes available in the outgoing datagram buffer
///
/// When greater than zero, [`send`](Self::send)ing a datagram of at most this size is
/// guaranteed not to cause older datagrams to be dropped.
pub fn send_buffer_space(&self) -> usize {
self.conn
.config
.datagram_send_buffer_size
.saturating_sub(self.conn.datagrams.outgoing_total)
}
}
#[derive(Default)]
pub(super) struct DatagramState {
/// Number of bytes of datagrams that have been received by the local transport but not
/// delivered to the application
pub(super) recv_buffered: usize,
/// Datagrams received from the peer, awaiting delivery via `recv`
pub(super) incoming: VecDeque<Datagram>,
/// Locally queued datagrams, awaiting transmission via `write`
pub(super) outgoing: VecDeque<Datagram>,
/// Total payload bytes currently queued in `outgoing`
pub(super) outgoing_total: usize,
/// Set when a non-dropping `send` failed with `Blocked`; presumably cleared
/// once queued datagrams drain — TODO confirm (reset not visible here)
pub(super) send_blocked: bool,
}
impl DatagramState {
/// Buffer a datagram received from the peer
///
/// `window` is the local receive buffer size; `None` means datagram support is disabled
/// locally, making the frame a protocol violation. The oldest buffered datagrams are
/// discarded as needed to keep the buffered total within `window`.
///
/// Returns whether the incoming queue was empty before this datagram was buffered, or an
/// error for illegal frames.
pub(super) fn received(
&mut self,
datagram: Datagram,
window: &Option<usize>,
) -> Result<bool, TransportError> {
let window = match window {
None => {
return Err(TransportError::PROTOCOL_VIOLATION(
"unexpected DATAGRAM frame",
));
}
Some(x) => *x,
};
if datagram.data.len() > window {
return Err(TransportError::PROTOCOL_VIOLATION("oversized datagram"));
}
let was_empty = self.recv_buffered == 0;
// Make room by discarding the oldest buffered datagrams
while datagram.data.len() + self.recv_buffered > window {
debug!("dropping stale datagram");
self.recv();
}
self.recv_buffered += datagram.data.len();
self.incoming.push_back(datagram);
Ok(was_empty)
}
/// Discard outgoing datagrams with a payload larger than `max_payload` bytes
///
/// Used to ensure that reductions in MTU don't get us stuck in a state where we have a datagram
/// queued but can't send it.
pub(super) fn drop_oversized(&mut self, max_payload: usize) {
self.outgoing.retain(|datagram| {
// NOTE(review): strict `<` also drops a datagram of exactly `max_payload` bytes,
// stricter than the "larger than" in the doc above — confirm the intended boundary
let result = datagram.data.len() < max_payload;
if !result {
trace!(
"dropping {} byte datagram violating {} byte limit",
datagram.data.len(),
max_payload
);
self.outgoing_total -= datagram.data.len();
}
result
});
}
/// Attempt to write a datagram frame into `buf`, consuming it from `self.outgoing`
///
/// Returns whether a frame was written. At most `max_size` bytes will be written, including
/// framing.
pub(super) fn write(&mut self, buf: &mut Vec<u8>, max_size: usize) -> bool {
let datagram = match self.outgoing.pop_front() {
Some(x) => x,
None => return false,
};
if buf.len() + datagram.size(true) > max_size {
// Future work: we could be more clever about cramming small datagrams into
// mostly-full packets when a larger one is queued first
self.outgoing.push_front(datagram);
return false;
}
trace!(len = datagram.data.len(), "DATAGRAM");
self.outgoing_total -= datagram.data.len();
datagram.encode(true, buf);
true
}
/// Dequeue the oldest received datagram, updating the buffered-byte count
pub(super) fn recv(&mut self) -> Option<Bytes> {
let x = self.incoming.pop_front()?.data;
self.recv_buffered -= x.len();
Some(x)
}
}
/// Errors that can arise when sending a datagram
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum SendDatagramError {
/// The peer does not support receiving datagram frames
#[error("datagrams not supported by peer")]
UnsupportedByPeer,
/// Datagram support is disabled locally
#[error("datagram support disabled")]
Disabled,
/// The datagram is larger than the connection can currently accommodate
///
/// Indicates that the path MTU minus overhead or the limit advertised by the peer has been
/// exceeded.
#[error("datagram too large")]
TooLarge,
/// Send would block
///
/// The unsent datagram is handed back to the caller so it can be retried later
#[error("datagram send blocked")]
Blocked(Bytes),
}

4102
vendor/quinn-proto/src/connection/mod.rs vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,970 @@
use crate::{Instant, MAX_UDP_PAYLOAD, MtuDiscoveryConfig, packet::SpaceId};
use std::cmp;
use tracing::trace;
/// Implements Datagram Packetization Layer Path Maximum Transmission Unit Discovery
///
/// See [`MtuDiscoveryConfig`] for details
#[derive(Clone)]
pub(crate) struct MtuDiscovery {
/// Detected MTU for the path (clamped to the peer's `max_udp_payload_size`
/// once that transport parameter is received)
current_mtu: u16,
/// The state of the MTU discovery, if enabled
state: Option<EnabledMtuDiscovery>,
/// The state of the black hole detector
black_hole_detector: BlackHoleDetector,
}
impl MtuDiscovery {
/// Creates an `MtuDiscovery` with discovery enabled
///
/// `initial_plpmtu` seeds the current MTU; `peer_max_udp_payload_size` may be supplied when
/// the peer's transport parameters are already known (see comment below)
pub(crate) fn new(
initial_plpmtu: u16,
min_mtu: u16,
peer_max_udp_payload_size: Option<u16>,
config: MtuDiscoveryConfig,
) -> Self {
debug_assert!(
initial_plpmtu >= min_mtu,
"initial_max_udp_payload_size must be at least {min_mtu}"
);
let mut mtud = Self::with_state(
initial_plpmtu,
min_mtu,
Some(EnabledMtuDiscovery::new(config)),
);
// We might be migrating an existing connection to a new path, in which case the transport
// parameters have already been transmitted, and we already know the value of
// `peer_max_udp_payload_size`
if let Some(peer_max_udp_payload_size) = peer_max_udp_payload_size {
mtud.on_peer_max_udp_payload_size_received(peer_max_udp_payload_size);
}
mtud
}
/// MTU discovery will be disabled and the current MTU will be fixed to the provided value
pub(crate) fn disabled(plpmtu: u16, min_mtu: u16) -> Self {
Self::with_state(plpmtu, min_mtu, None)
}
/// Common constructor for the enabled and disabled variants
fn with_state(current_mtu: u16, min_mtu: u16, state: Option<EnabledMtuDiscovery>) -> Self {
Self {
current_mtu,
state,
black_hole_detector: BlackHoleDetector::new(min_mtu),
}
}
/// Resets the MTU to `current_mtu` and restarts discovery (reusing the same config and the
/// previously learned peer limit) if it was enabled
pub(super) fn reset(&mut self, current_mtu: u16, min_mtu: u16) {
self.current_mtu = current_mtu;
if let Some(state) = self.state.take() {
// Rebuild from the same config, then re-apply the peer's payload size limit
self.state = Some(EnabledMtuDiscovery::new(state.config));
self.on_peer_max_udp_payload_size_received(state.peer_max_udp_payload_size);
}
self.black_hole_detector = BlackHoleDetector::new(min_mtu);
}
/// Returns the current MTU
pub(crate) fn current_mtu(&self) -> u16 {
self.current_mtu
}
/// Returns the amount of bytes that should be sent as an MTU probe, if any
pub(crate) fn poll_transmit(&mut self, now: Instant, next_pn: u64) -> Option<u16> {
self.state
.as_mut()
.and_then(|state| state.poll_transmit(now, self.current_mtu, next_pn))
}
/// Notifies the [`MtuDiscovery`] that the peer's `max_udp_payload_size` transport parameter has
/// been received
pub(crate) fn on_peer_max_udp_payload_size_received(&mut self, peer_max_udp_payload_size: u16) {
self.current_mtu = self.current_mtu.min(peer_max_udp_payload_size);
if let Some(state) = self.state.as_mut() {
// MTUD is only active after the connection has been fully established, so it is
// guaranteed we will receive the peer's transport parameters before we start probing
debug_assert!(matches!(state.phase, Phase::Initial));
state.peer_max_udp_payload_size = peer_max_udp_payload_size;
}
}
/// Notifies the [`MtuDiscovery`] that a packet has been ACKed
///
/// Returns true if the packet was an MTU probe
pub(crate) fn on_acked(&mut self, space: SpaceId, pn: u64, len: u16) -> bool {
// MTU probes are only sent in application data space
if space != SpaceId::Data {
return false;
}
// Update the state of the MTU search
if let Some(new_mtu) = self
.state
.as_mut()
.and_then(|state| state.on_probe_acked(pn))
{
self.current_mtu = new_mtu;
trace!(current_mtu = self.current_mtu, "new MTU detected");
self.black_hole_detector.on_probe_acked(pn, len);
true
} else {
self.black_hole_detector.on_non_probe_acked(pn, len);
false
}
}
/// Returns the packet number of the in-flight MTU probe, if any
pub(crate) fn in_flight_mtu_probe(&self) -> Option<u64> {
match &self.state {
Some(EnabledMtuDiscovery {
phase: Phase::Searching(search_state),
..
}) => search_state.in_flight_probe,
_ => None,
}
}
/// Notifies the [`MtuDiscovery`] that the in-flight MTU probe was lost
pub(crate) fn on_probe_lost(&mut self) {
if let Some(state) = &mut self.state {
state.on_probe_lost();
}
}
/// Notifies the [`MtuDiscovery`] that a non-probe packet was lost
///
/// When done notifying of lost packets, [`MtuDiscovery::black_hole_detected`] must be called, to
/// ensure the last loss burst is properly processed and to trigger black hole recovery logic if
/// necessary.
pub(crate) fn on_non_probe_lost(&mut self, pn: u64, len: u16) {
self.black_hole_detector.on_non_probe_lost(pn, len);
}
/// Returns true if a black hole was detected
///
/// Calling this function will close the previous loss burst. If a black hole is detected, the
/// current MTU will be reset to `min_mtu`.
pub(crate) fn black_hole_detected(&mut self, now: Instant) -> bool {
if !self.black_hole_detector.black_hole_detected() {
return false;
}
self.current_mtu = self.black_hole_detector.min_mtu;
if let Some(state) = &mut self.state {
state.on_black_hole_detected(now);
}
true
}
}
/// Additional state for enabled MTU discovery
#[derive(Debug, Clone)]
struct EnabledMtuDiscovery {
/// Where we currently are in the search lifecycle
phase: Phase,
/// The peer's `max_udp_payload_size` transport parameter (defaults to
/// `MAX_UDP_PAYLOAD` until the real value is received)
peer_max_udp_payload_size: u16,
/// Configuration controlling probe bounds, step size, and timing
config: MtuDiscoveryConfig,
}
impl EnabledMtuDiscovery {
/// Creates state for active MTU discovery; the peer's limit is assumed to be
/// `MAX_UDP_PAYLOAD` until its transport parameters arrive
fn new(config: MtuDiscoveryConfig) -> Self {
Self {
phase: Phase::Initial,
peer_max_udp_payload_size: MAX_UDP_PAYLOAD,
config,
}
}
/// Returns the amount of bytes that should be sent as an MTU probe, if any
fn poll_transmit(&mut self, now: Instant, current_mtu: u16, next_pn: u64) -> Option<u16> {
if let Phase::Initial = &self.phase {
// Start the first search
self.phase = Phase::Searching(SearchState::new(
current_mtu,
self.peer_max_udp_payload_size,
&self.config,
));
} else if let Phase::Complete(next_mtud_activation) = &self.phase {
if now < *next_mtud_activation {
return None;
}
// Start a new search (we have reached the next activation time)
self.phase = Phase::Searching(SearchState::new(
current_mtu,
self.peer_max_udp_payload_size,
&self.config,
));
}
if let Phase::Searching(state) = &mut self.phase {
// Nothing to do while there is a probe in flight
if state.in_flight_probe.is_some() {
return None;
}
// Retransmit lost probes, if any
if 0 < state.lost_probe_count && state.lost_probe_count < MAX_PROBE_RETRANSMITS {
state.in_flight_probe = Some(next_pn);
return Some(state.last_probed_mtu);
}
let last_probe_succeeded = state.lost_probe_count == 0;
// The probe is definitely lost (we reached the MAX_PROBE_RETRANSMITS threshold)
if !last_probe_succeeded {
state.lost_probe_count = 0;
state.in_flight_probe = None;
}
if let Some(probe_udp_payload_size) = state.next_mtu_to_probe(last_probe_succeeded) {
state.in_flight_probe = Some(next_pn);
state.last_probed_mtu = probe_udp_payload_size;
return Some(probe_udp_payload_size);
} else {
// Search exhausted: go dormant until the next scheduled activation
let next_mtud_activation = now + self.config.interval;
self.phase = Phase::Complete(next_mtud_activation);
return None;
}
}
None
}
/// Called when a packet is acknowledged in [`SpaceId::Data`]
///
/// Returns the new `current_mtu` if the packet number corresponds to the in-flight MTU probe
fn on_probe_acked(&mut self, pn: u64) -> Option<u16> {
match &mut self.phase {
Phase::Searching(state) if state.in_flight_probe == Some(pn) => {
state.in_flight_probe = None;
state.lost_probe_count = 0;
Some(state.last_probed_mtu)
}
_ => None,
}
}
/// Called when the in-flight MTU probe was lost
fn on_probe_lost(&mut self) {
// We might no longer be searching, e.g. if a black hole was detected
if let Phase::Searching(state) = &mut self.phase {
state.in_flight_probe = None;
state.lost_probe_count += 1;
}
}
/// Called when a black hole is detected
fn on_black_hole_detected(&mut self, now: Instant) {
// Stop searching, if applicable, and reset the timer
let next_mtud_activation = now + self.config.black_hole_cooldown;
self.phase = Phase::Complete(next_mtud_activation);
}
}
#[derive(Debug, Clone, Copy)]
enum Phase {
/// We haven't started polling yet
///
/// Transitions to `Searching` on the first call to `poll_transmit`
Initial,
/// We are currently searching for a higher PMTU
Searching(SearchState),
/// Searching has completed and will be triggered again at the provided instant
///
/// Also entered (after a cooldown) when a black hole is detected
Complete(Instant),
}
/// State of a binary search over candidate UDP payload sizes
#[derive(Debug, Clone, Copy)]
struct SearchState {
/// The lower bound for the current binary search
lower_bound: u16,
/// The upper bound for the current binary search
upper_bound: u16,
/// The minimum change to stop the current binary search
minimum_change: u16,
/// The UDP payload size we last sent a probe for
last_probed_mtu: u16,
/// Packet number of an in-flight probe (if any)
in_flight_probe: Option<u64>,
/// Lost probes at the current probe size
lost_probe_count: usize,
}
impl SearchState {
/// Creates a new search state, with the specified lower bound (the upper bound is derived from
/// the config and the peer's `max_udp_payload_size` transport parameter)
fn new(
mut lower_bound: u16,
peer_max_udp_payload_size: u16,
config: &MtuDiscoveryConfig,
) -> Self {
lower_bound = lower_bound.min(peer_max_udp_payload_size);
let upper_bound = config
.upper_bound
.clamp(lower_bound, peer_max_udp_payload_size);
Self {
in_flight_probe: None,
lost_probe_count: 0,
lower_bound,
upper_bound,
minimum_change: config.minimum_change,
// During initialization, we consider the lower bound to have already been
// successfully probed
last_probed_mtu: lower_bound,
}
}
/// Determines the next MTU to probe using binary search
fn next_mtu_to_probe(&mut self, last_probe_succeeded: bool) -> Option<u16> {
debug_assert_eq!(self.in_flight_probe, None);
// Narrow the search interval based on the last probe's outcome
if last_probe_succeeded {
self.lower_bound = self.last_probed_mtu;
} else {
self.upper_bound = self.last_probed_mtu - 1;
}
// Midpoint, computed in i32 so the sum cannot overflow u16
let next_mtu = (self.lower_bound as i32 + self.upper_bound as i32) / 2;
// Binary search stopping condition
if ((next_mtu - self.last_probed_mtu as i32).unsigned_abs() as u16) < self.minimum_change {
// Special case: if the upper bound is far enough, we want to probe it as a last
// step (otherwise we will never achieve the upper bound)
if self.upper_bound.saturating_sub(self.last_probed_mtu) >= self.minimum_change {
return Some(self.upper_bound);
}
return None;
}
Some(next_mtu as u16)
}
}
/// Judges whether packet loss might indicate a drop in MTU
///
/// Our MTU black hole detection scheme is a heuristic based on the order in which packets were sent
/// (the packet number order), their sizes, and which are deemed lost.
///
/// First, contiguous groups of lost packets ("loss bursts") are aggregated, because a group of
/// packets all lost together were probably lost for the same reason.
///
/// A loss burst is deemed "suspicious" if it contains no packets that are (a) smaller than the
/// minimum MTU or (b) smaller than a more recent acknowledged packet, because such a burst could be
/// fully explained by a reduction in MTU.
///
/// When the number of suspicious loss bursts exceeds [`BLACK_HOLE_THRESHOLD`], we judge the
/// evidence for an MTU black hole to be sufficient. On detection the owning [`MtuDiscovery`]
/// falls back to `min_mtu`.
#[derive(Clone)]
struct BlackHoleDetector {
/// Packet loss bursts currently considered suspicious
suspicious_loss_bursts: Vec<LossBurst>,
/// Loss burst currently being aggregated, if any
current_loss_burst: Option<CurrentLossBurst>,
/// Packet number of the biggest packet larger than `min_mtu` which we've received
/// acknowledgment of more recently than any suspicious loss burst, if any
largest_post_loss_packet: u64,
/// The maximum of `min_mtu` and the size of `largest_post_loss_packet`, or exactly `min_mtu` if
/// no larger packets have been received since the most recent loss burst.
acked_mtu: u16,
/// The UDP payload size guaranteed to be supported by the network
min_mtu: u16,
}
impl BlackHoleDetector {
/// Creates a detector that treats `min_mtu` as always deliverable
fn new(min_mtu: u16) -> Self {
Self {
suspicious_loss_bursts: Vec::with_capacity(BLACK_HOLE_THRESHOLD + 1),
current_loss_burst: None,
largest_post_loss_packet: 0,
acked_mtu: min_mtu,
min_mtu,
}
}
/// Records acknowledgment of an MTU probe of `len` bytes with packet number `pn`
fn on_probe_acked(&mut self, pn: u64, len: u16) {
// MTU probes are always larger than the previous MTU, so no previous loss bursts are
// suspicious. At most one MTU probe is in flight at a time, so we don't need to worry about
// reordering between them.
self.suspicious_loss_bursts.clear();
self.acked_mtu = len;
// This might go backwards, but that's okay: a successful ACK means we haven't yet judged a
// more recently sent packet lost, and we just want to track the largest packet that's been
// successfully delivered more recently than a loss.
self.largest_post_loss_packet = pn;
}
/// Records acknowledgment of an ordinary (non-probe) packet of `len` bytes
fn on_non_probe_acked(&mut self, pn: u64, len: u16) {
if len <= self.acked_mtu {
// We've already seen a larger packet since the most recent suspicious loss burst;
// nothing to do.
return;
}
self.acked_mtu = len;
// This might go backwards, but that's okay as described in `on_probe_acked`.
self.largest_post_loss_packet = pn;
// Loss bursts packets smaller than this are retroactively deemed non-suspicious.
self.suspicious_loss_bursts
.retain(|burst| burst.smallest_packet_size > len);
}
/// Records loss of an ordinary packet, aggregating contiguous losses into bursts
fn on_non_probe_lost(&mut self, pn: u64, len: u16) {
// A loss burst is a group of consecutive packets that are declared lost, so a distance
// greater than 1 indicates a new burst
let end_last_burst = self
.current_loss_burst
.as_ref()
.is_some_and(|current| pn - current.latest_non_probe != 1);
if end_last_burst {
self.finish_loss_burst();
}
// Extend the current burst (or start one) and keep its smallest packet size
self.current_loss_burst = Some(CurrentLossBurst {
latest_non_probe: pn,
smallest_packet_size: self
.current_loss_burst
.map_or(len, |prev| cmp::min(prev.smallest_packet_size, len)),
});
}
/// Closes the current loss burst and reports whether the suspicious-burst count exceeds
/// [`BLACK_HOLE_THRESHOLD`]; the count is reset when detection fires
fn black_hole_detected(&mut self) -> bool {
self.finish_loss_burst();
if self.suspicious_loss_bursts.len() <= BLACK_HOLE_THRESHOLD {
return false;
}
self.suspicious_loss_bursts.clear();
true
}
/// Marks the end of the current loss burst, checking whether it was suspicious
fn finish_loss_burst(&mut self) {
let Some(burst) = self.current_loss_burst.take() else {
return;
};
// If a loss burst contains a packet smaller than the minimum MTU or a more recently
// transmitted packet, it is not suspicious.
if burst.smallest_packet_size < self.min_mtu
|| (burst.latest_non_probe < self.largest_post_loss_packet
&& burst.smallest_packet_size < self.acked_mtu)
{
return;
}
// The loss burst is now deemed suspicious.
// A suspicious loss burst more recent than `largest_post_loss_packet` invalidates it. This
// makes `acked_mtu` a conservative approximation. Ideally we'd update `safe_mtu` and
// `largest_post_loss_packet` to describe the largest acknowledged packet sent later than
// this burst, but that would require tracking the size of an unpredictable number of
// recently acknowledged packets, and erring on the side of false positives is safe.
if burst.latest_non_probe > self.largest_post_loss_packet {
self.acked_mtu = self.min_mtu;
}
let burst = LossBurst {
smallest_packet_size: burst.smallest_packet_size,
};
if self.suspicious_loss_bursts.len() <= BLACK_HOLE_THRESHOLD {
self.suspicious_loss_bursts.push(burst);
return;
}
// To limit memory use, only track the most suspicious loss bursts.
let smallest = self
.suspicious_loss_bursts
.iter_mut()
.min_by_key(|prev| prev.smallest_packet_size)
.filter(|prev| prev.smallest_packet_size < burst.smallest_packet_size);
if let Some(smallest) = smallest {
*smallest = burst;
}
}
/// Number of loss bursts currently considered suspicious (test introspection)
#[cfg(test)]
fn suspicious_loss_burst_count(&self) -> usize {
self.suspicious_loss_bursts.len()
}
/// Packet number of the most recent loss in the open burst, if any (test introspection)
#[cfg(test)]
fn largest_non_probe_lost(&self) -> Option<u64> {
self.current_loss_burst.as_ref().map(|x| x.latest_non_probe)
}
}
/// A closed loss burst that was judged suspicious
#[derive(Copy, Clone)]
struct LossBurst {
/// Size of the smallest packet lost in the burst
smallest_packet_size: u16,
}
/// The loss burst currently being aggregated
#[derive(Copy, Clone)]
struct CurrentLossBurst {
/// Size of the smallest packet lost so far in the burst
smallest_packet_size: u16,
/// Packet number of the most recent loss in the burst
latest_non_probe: u64,
}
// Corresponds to the RFC's `MAX_PROBES` constant (see
// https://www.rfc-editor.org/rfc/rfc8899#section-5.1.2)
const MAX_PROBE_RETRANSMITS: usize = 3;
/// Maximum number of suspicious loss bursts that will not trigger black hole detection
/// (detection fires once the count strictly exceeds this value)
const BLACK_HOLE_THRESHOLD: usize = 3;
#[cfg(test)]
mod tests {
use super::*;
use crate::Duration;
use crate::MAX_UDP_PAYLOAD;
use crate::packet::SpaceId;
use assert_matches::assert_matches;
/// An enabled discovery instance starting at the 1200-byte floor with default config
fn default_mtud() -> MtuDiscovery {
    MtuDiscovery::new(1_200, 1_200, None, MtuDiscoveryConfig::default())
}
/// Whether the discovery state machine has finished its current search
fn completed(mtud: &MtuDiscovery) -> bool {
    let state = mtud.state.as_ref().unwrap();
    matches!(state.phase, Phase::Complete(_))
}
/// Drives mtud until it reaches `Phase::Completed`
fn drive_to_completion(
    mtud: &mut MtuDiscovery,
    now: Instant,
    link_payload_size_limit: u16,
) -> Vec<u16> {
    let mut probed_sizes = Vec::new();
    for probe_pn in 1..100 {
        let result = mtud.poll_transmit(now, probe_pn);
        if completed(mtud) {
            break;
        }
        // "Send" next probe
        assert!(result.is_some());
        let probe_size = result.unwrap();
        probed_sizes.push(probe_size);
        // Probes within the link's limit get delivered; larger ones vanish
        if probe_size > link_payload_size_limit {
            mtud.on_probe_lost();
        } else {
            mtud.on_acked(SpaceId::Data, probe_pn, probe_size);
        }
    }
    probed_sizes
}
#[test]
fn black_hole_detector_ignores_burst_containing_non_suspicious_packet() {
    // A burst containing a packet smaller than min_mtu (800 < 1200) must not
    // be counted as suspicious
    let mut mtud = default_mtud();
    for pn in 2..=3 {
        mtud.on_non_probe_lost(pn, 1300);
    }
    assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), Some(3));
    assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
    // pn 4 immediately follows pn 3, so the small packet joins the same burst
    mtud.on_non_probe_lost(4, 800);
    assert!(!mtud.black_hole_detected(Instant::now()));
    assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), None);
    assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
}
#[test]
fn black_hole_detector_counts_burst_containing_only_suspicious_packets() {
    // Both lost packets exceed min_mtu (1300 > 1200), so the contiguous burst
    // they form is counted as suspicious once it is closed
    let mut mtud = default_mtud();
    for pn in 2..=3 {
        mtud.on_non_probe_lost(pn, 1300);
    }
    assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), Some(3));
    assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
    let now = Instant::now();
    assert!(!mtud.black_hole_detected(now));
    assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), None);
    assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 1);
}
#[test]
fn black_hole_detector_ignores_empty_burst() {
    // With no losses recorded, closing the (nonexistent) burst is a no-op
    let mut mtud = default_mtud();
    let now = Instant::now();
    assert!(!mtud.black_hole_detected(now));
    assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
}
#[test]
fn mtu_discovery_disabled_does_nothing() {
    // A disabled instance never emits probes
    let mut mtud = MtuDiscovery::disabled(1_200, 1_200);
    assert_eq!(mtud.poll_transmit(Instant::now(), 0), None);
}
#[test]
fn mtu_discovery_disabled_lost_four_packet_bursts_triggers_black_hole_detection() {
    // Black hole detection stays active even when discovery itself is disabled
    let mut mtud = MtuDiscovery::disabled(1_400, 1_250);
    let now = Instant::now();
    // The packets are never contiguous, so each one has its own burst
    for pn in [0, 2, 4, 6] {
        mtud.on_non_probe_lost(pn, 1300);
    }
    assert!(mtud.black_hole_detected(now));
    // MTU falls back to the guaranteed minimum
    assert_eq!(mtud.current_mtu, 1250);
    assert_matches!(mtud.state, None);
}
#[test]
fn mtu_discovery_lost_two_packet_bursts_does_not_trigger_black_hole_detection() {
    // Two suspicious bursts stay below the detection threshold
    let mut mtud = default_mtud();
    let now = Instant::now();
    for pn in 0..2u64 {
        mtud.on_non_probe_lost(pn, 1300);
        let detected = mtud.black_hole_detected(now);
        assert!(!detected);
    }
}
#[test]
fn mtu_discovery_lost_four_packet_bursts_triggers_black_hole_detection_and_resets_timer() {
    let mut mtud = default_mtud();
    let now = Instant::now();
    // The packets are never contiguous, so each one has its own burst
    for pn in [0, 2, 4, 6] {
        mtud.on_non_probe_lost(pn, 1300);
    }
    assert!(mtud.black_hole_detected(now));
    assert_eq!(mtud.current_mtu, 1200);
    // Reactivation is scheduled one black-hole cooldown after detection
    match mtud.state.unwrap().phase {
        Phase::Complete(next_mtud_activation) => {
            assert_eq!(next_mtud_activation, now + Duration::from_secs(60));
        }
        _ => panic!("Unexpected MTUD phase!"),
    }
}
#[test]
fn mtu_discovery_after_complete_reactivates_when_interval_elapsed() {
    let mut config = MtuDiscoveryConfig::default();
    config.upper_bound(9_000);
    let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config);
    let now = Instant::now();
    drive_to_completion(&mut mtud, now, 1_500);
    // Polling right after completion does not cause new packets to be sent
    assert_eq!(mtud.poll_transmit(now, 42), None);
    assert!(completed(&mtud));
    assert_eq!(mtud.current_mtu, 1_471);
    // Polling after the interval has passed does (taking the current mtu as lower bound)
    let later = now + Duration::from_secs(600);
    assert_eq!(mtud.poll_transmit(later, 43), Some(5235));
    if let Phase::Searching(search) = mtud.state.unwrap().phase {
        assert_eq!(search.lower_bound, 1_471);
        assert_eq!(search.upper_bound, 9_000);
    } else {
        panic!("Unexpected MTUD phase!")
    }
}
#[test]
fn mtu_discovery_lost_three_probes_lowers_probe_size() {
    let mut mtud = default_mtud();
    // Poll four probes, declaring each one lost
    let mut sizes = Vec::new();
    for i in 0..4 {
        let probe_size = mtud.poll_transmit(Instant::now(), i);
        assert!(probe_size.is_some(), "no probe returned for packet {i}");
        mtud.on_probe_lost();
        sizes.push(probe_size.unwrap());
    }
    // After the first probe is lost, it gets retransmitted twice
    let first_probe_size = sizes[0];
    assert_eq!(sizes[1], first_probe_size);
    assert_eq!(sizes[2], first_probe_size);
    // After the third probe is lost, we decrement our probe size
    let fourth_probe_size = sizes[3];
    assert!(fourth_probe_size < first_probe_size);
    assert_eq!(
        fourth_probe_size,
        first_probe_size - (first_probe_size - 1_200) / 2 - 1
    );
}
#[test]
fn mtu_discovery_with_peer_max_udp_payload_size_clamps_upper_bound() {
    let mut mtud = default_mtud();
    // The peer advertises max_udp_payload_size = 1300, capping the search
    mtud.on_peer_max_udp_payload_size_received(1300);
    let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500);
    assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1300);
    assert_eq!(mtud.current_mtu, 1300);
    let expected_probed_sizes = &[1250, 1275, 1300];
    assert_eq!(probed_sizes, expected_probed_sizes);
    assert!(completed(&mtud));
}

#[test]
fn mtu_discovery_with_previous_peer_max_udp_payload_size_clamps_upper_bound() {
    // A limit remembered from a previous connection (1400) takes effect
    // immediately, so no probing is needed at all
    let mut mtud = MtuDiscovery::new(1500, 1_200, Some(1400), MtuDiscoveryConfig::default());
    assert_eq!(mtud.current_mtu, 1400);
    assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1400);
    let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500);
    assert_eq!(mtud.current_mtu, 1400);
    assert!(probed_sizes.is_empty());
    assert!(completed(&mtud));
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn mtu_discovery_with_peer_max_udp_payload_size_after_search_panics() {
    // Receiving the peer's payload-size limit after the search has already
    // finished trips a debug assertion
    let mut mtud = default_mtud();
    drive_to_completion(&mut mtud, Instant::now(), 1500);
    mtud.on_peer_max_udp_payload_size_received(1300);
}
#[test]
fn mtu_discovery_with_1500_limit() {
    // The simulated link drops anything above 1500 bytes; the binary search
    // converges on 1452
    let mut mtud = default_mtud();
    let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500);
    let expected_probed_sizes = &[1326, 1389, 1420, 1452];
    assert_eq!(probed_sizes, expected_probed_sizes);
    assert_eq!(mtud.current_mtu, 1452);
    assert!(completed(&mtud));
}

#[test]
fn mtu_discovery_with_1500_limit_and_10000_upper_bound() {
    let mut config = MtuDiscoveryConfig::default();
    config.upper_bound(10_000);
    let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config);
    let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500);
    // Probes above the 1500-byte link limit are lost and retried before the
    // search range shrinks, hence the repeated sizes
    let expected_probed_sizes = &[
        5600, 5600, 5600, 3399, 3399, 3399, 2299, 2299, 2299, 1749, 1749, 1749, 1474, 1611,
        1611, 1611, 1542, 1542, 1542, 1507, 1507, 1507,
    ];
    assert_eq!(probed_sizes, expected_probed_sizes);
    assert_eq!(mtud.current_mtu, 1474);
    assert!(completed(&mtud));
}

#[test]
fn mtu_discovery_no_lost_probes_finds_maximum_udp_payload() {
    // On a lossless link the search reaches the maximum UDP payload (65_527)
    let mut config = MtuDiscoveryConfig::default();
    config.upper_bound(MAX_UDP_PAYLOAD);
    let mut mtud = MtuDiscovery::new(1200, 1200, None, config);
    drive_to_completion(&mut mtud, Instant::now(), u16::MAX);
    assert_eq!(mtud.current_mtu, 65527);
    assert!(completed(&mtud));
}
#[test]
fn mtu_discovery_lost_half_of_probes_finds_maximum_udp_payload() {
    let mut config = MtuDiscoveryConfig::default();
    config.upper_bound(MAX_UDP_PAYLOAD);
    let mut mtud = MtuDiscovery::new(1200, 1200, None, config);
    let now = Instant::now();
    let mut iterations = 0;
    // Alternate between acking and losing probes; even with 50% probe loss
    // the search must still converge on the maximum payload
    for i in 1..100 {
        iterations += 1;
        let probe_pn = i * 2 - 1;
        let other_pn = i * 2;
        let result = mtud.poll_transmit(Instant::now(), probe_pn);
        if completed(&mtud) {
            break;
        }
        // "Send" next probe
        assert!(result.is_some());
        assert!(mtud.in_flight_mtu_probe().is_some());
        // Nothing else to send while the probe is in-flight
        assert_matches!(mtud.poll_transmit(now, other_pn), None);
        if i % 2 == 0 {
            // ACK probe and ensure it results in an increase of current_mtu
            let previous_max_size = mtud.current_mtu;
            mtud.on_acked(SpaceId::Data, probe_pn, result.unwrap());
            println!(
                "ACK packet {}. Previous MTU = {previous_max_size}. New MTU = {}",
                result.unwrap(),
                mtud.current_mtu
            );
            // assert!(mtud.current_mtu > previous_max_size);
        } else {
            mtud.on_probe_lost();
        }
    }
    // The search terminates after a fixed number of steps regardless of loss
    assert_eq!(iterations, 25);
    assert_eq!(mtud.current_mtu, 65527);
    assert!(completed(&mtud));
}
// When the configured upper bound ends up below the lower bound, the upper
// bound is raised to match: a search never goes below its starting point.
#[test]
fn search_state_lower_bound_higher_than_upper_bound_clamps_upper_bound() {
    let mut cfg = MtuDiscoveryConfig::default();
    cfg.upper_bound(1400);
    let search = SearchState::new(1500, u16::MAX, &cfg);
    assert_eq!((search.lower_bound, search.upper_bound), (1500, 1500));
}

// A peer payload-size limit below the lower bound collapses the whole search
// range down to that limit.
#[test]
fn search_state_lower_bound_higher_than_peer_max_udp_payload_size_clamps_lower_bound() {
    let mut cfg = MtuDiscoveryConfig::default();
    cfg.upper_bound(9000);
    let search = SearchState::new(1500, 1300, &cfg);
    assert_eq!((search.lower_bound, search.upper_bound), (1300, 1300));
}

// A peer payload-size limit between the bounds only caps the upper bound.
#[test]
fn search_state_upper_bound_higher_than_peer_max_udp_payload_size_clamps_upper_bound() {
    let mut cfg = MtuDiscoveryConfig::default();
    cfg.upper_bound(9000);
    let search = SearchState::new(1200, 1450, &cfg);
    assert_eq!((search.lower_bound, search.upper_bound), (1200, 1450));
}
// Loss of packets larger than have been acknowledged should indicate a black hole
#[test]
fn simple_black_hole_detection() {
    let mut bhd = BlackHoleDetector::new(1200);
    // The acked packet (1300 bytes) is smaller than the lost 1400-byte
    // packets, so the loss bursts stay suspicious
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1300);
    for i in 0..BLACK_HOLE_THRESHOLD {
        bhd.on_non_probe_lost(i as u64 * 2, 1400);
    }
    // But not before `BLACK_HOLE_THRESHOLD + 1` bursts
    assert!(!bhd.black_hole_detected());
    bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 2, 1400);
    assert!(bhd.black_hole_detected());
}

// Loss of packets followed in transmission order by confirmation of a larger packet should not
// indicate a black hole
#[test]
fn non_suspicious_bursts() {
    let mut bhd = BlackHoleDetector::new(1200);
    // A 1500-byte delivery after the 1400-byte losses clears suspicion
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1500);
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        bhd.on_non_probe_lost(i as u64 * 2, 1400);
    }
    assert!(!bhd.black_hole_detected());
}

// Loss of packets smaller than have been acknowledged previously should still indicate a black
// hole
#[test]
fn dynamic_mtu_reduction() {
    let mut bhd = BlackHoleDetector::new(1200);
    // The ack (packet number 0) precedes all the losses in transmission order
    bhd.on_non_probe_acked(0, 1500);
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        bhd.on_non_probe_lost(i as u64 * 2, 1400);
    }
    assert!(bhd.black_hole_detected());
}

// Bursts containing heterogeneous packets are judged based on the smallest
#[test]
fn mixed_non_suspicious_bursts() {
    let mut bhd = BlackHoleDetector::new(1200);
    // Every burst contains a 1300-byte packet, below the 1400-byte delivery,
    // so no burst counts as suspicious
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 3, 1400);
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        bhd.on_non_probe_lost(i as u64 * 3, 1500);
        bhd.on_non_probe_lost(i as u64 * 3 + 1, 1300);
    }
    assert!(!bhd.black_hole_detected());
}

// Multi-packet bursts are only counted once
#[test]
fn bursts_count_once() {
    let mut bhd = BlackHoleDetector::new(1200);
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 3, 1400);
    // The two contiguous losses per iteration form a single burst
    for i in 0..(BLACK_HOLE_THRESHOLD) {
        bhd.on_non_probe_lost(i as u64 * 3, 1500);
        bhd.on_non_probe_lost(i as u64 * 3 + 1, 1500);
    }
    assert!(!bhd.black_hole_detected());
    bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 3, 1500);
    assert!(bhd.black_hole_detected());
}
// Non-suspicious bursts don't interfere with detection of suspicious bursts
#[test]
fn interleaved_bursts() {
    let mut bhd = BlackHoleDetector::new(1200);
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 4, 1400);
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        // The 1500-byte losses exceed the 1400-byte delivery and are
        // suspicious; the interleaved 1300-byte bursts are not
        bhd.on_non_probe_lost(i as u64 * 4, 1500);
        bhd.on_non_probe_lost(i as u64 * 4 + 2, 1300);
    }
    assert!(bhd.black_hole_detected());
}

// Bursts that are non-suspicious before a delivered packet become suspicious past it
#[test]
fn suspicious_after_acked() {
    let mut bhd = BlackHoleDetector::new(1200);
    bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1400);
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        bhd.on_non_probe_lost(i as u64 * 2, 1300);
    }
    assert!(
        !bhd.black_hole_detected(),
        "1300 byte losses preceding a 1400 byte delivery are not suspicious"
    );
    // Identical losses at packet numbers past the delivery are suspicious
    for i in 0..(BLACK_HOLE_THRESHOLD + 1) {
        bhd.on_non_probe_lost((BLACK_HOLE_THRESHOLD as u64 + 1 + i as u64) * 2, 1300);
    }
    assert!(
        bhd.black_hole_detected(),
        "1300 byte losses following a 1400 byte delivery are suspicious"
    );
}

// Acknowledgment of a packet marks prior loss bursts with the same packet size as
// non-suspicious
#[test]
fn retroactively_non_suspicious() {
    let mut bhd = BlackHoleDetector::new(1200);
    for i in 0..BLACK_HOLE_THRESHOLD {
        bhd.on_non_probe_lost(i as u64 * 2, 1400);
    }
    // Delivering a 1400-byte packet retroactively clears the earlier
    // same-sized loss bursts
    bhd.on_non_probe_acked(BLACK_HOLE_THRESHOLD as u64 * 2, 1400);
    bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 2 + 1, 1400);
    assert!(!bhd.black_hole_detected());
}
}

View File

@@ -0,0 +1,308 @@
//! Pacing of packet transmissions.
use crate::{Duration, Instant};
use tracing::warn;
/// A simple token-bucket pacer
///
/// The pacer's capacity is derived on a fraction of the congestion window
/// which can be sent in regular intervals
/// Once the bucket is empty, further transmission is blocked.
/// The bucket refills at a rate slightly faster
/// than one congestion window per RTT, as recommended in
/// <https://tools.ietf.org/html/draft-ietf-quic-recovery-34#section-7.7>
pub(super) struct Pacer {
    /// Maximum number of tokens (bytes) the bucket can hold
    capacity: u64,
    /// Congestion window in effect when `capacity` was last computed
    last_window: u64,
    /// MTU in effect when `capacity` was last computed
    last_mtu: u16,
    /// Bytes currently available to be sent
    tokens: u64,
    /// Time of the last token refill
    prev: Instant,
}
impl Pacer {
    /// Obtains a new [`Pacer`].
    pub(super) fn new(smoothed_rtt: Duration, window: u64, mtu: u16, now: Instant) -> Self {
        let capacity = optimal_capacity(smoothed_rtt, window, mtu);
        Self {
            capacity,
            last_window: window,
            last_mtu: mtu,
            // Start with a full bucket so the first burst is never delayed
            tokens: capacity,
            prev: now,
        }
    }

    /// Record that a packet has been transmitted.
    pub(super) fn on_transmit(&mut self, packet_length: u16) {
        self.tokens = self.tokens.saturating_sub(packet_length.into())
    }

    /// Return how long we need to wait before sending `bytes_to_send`
    ///
    /// If we can send a packet right away, this returns `None`. Otherwise, returns `Some(d)`,
    /// where `d` is the time before this function should be called again.
    ///
    /// The 5/4 ratio used here comes from the suggestion that N = 1.25 in the draft IETF RFC for
    /// QUIC.
    pub(super) fn delay(
        &mut self,
        smoothed_rtt: Duration,
        bytes_to_send: u64,
        mtu: u16,
        window: u64,
        now: Instant,
    ) -> Option<Instant> {
        debug_assert_ne!(
            window, 0,
            "zero-sized congestion control window is nonsense"
        );

        // Recompute the bucket capacity whenever the congestion window or MTU changes
        if window != self.last_window || mtu != self.last_mtu {
            self.capacity = optimal_capacity(smoothed_rtt, window, mtu);

            // Clamp the tokens
            self.tokens = self.capacity.min(self.tokens);
            self.last_window = window;
            self.last_mtu = mtu;
        }

        // if we can already send a packet, there is no need for delay
        if self.tokens >= bytes_to_send {
            return None;
        }

        // we disable pacing for extremely large windows
        if window > u64::from(u32::MAX) {
            return None;
        }

        let window = window as u32;

        // A clock going backwards yields zero elapsed time rather than a panic
        let time_elapsed = now.checked_duration_since(self.prev).unwrap_or_else(|| {
            warn!("received a timestamp earlier than a previously recorded time, ignoring");
            Default::default()
        });

        if smoothed_rtt.as_nanos() == 0 {
            return None;
        }

        // Refill tokens at a rate of 1.25 congestion windows per RTT
        let elapsed_rtts = time_elapsed.as_secs_f64() / smoothed_rtt.as_secs_f64();
        let new_tokens = window as f64 * 1.25 * elapsed_rtts;
        self.tokens = self
            .tokens
            .saturating_add(new_tokens as _)
            .min(self.capacity);

        self.prev = now;

        // if we can already send a packet, there is no need for delay
        if self.tokens >= bytes_to_send {
            return None;
        }

        let unscaled_delay = smoothed_rtt
            .checked_mul((bytes_to_send.max(self.capacity) - self.tokens) as _)
            .unwrap_or(Duration::MAX)
            / window;

        // divisions come before multiplications to prevent overflow
        // this is the time at which the pacing window becomes empty
        Some(self.prev + (unscaled_delay / 5) * 4)
    }
}
/// Calculates a pacer capacity for a certain window and RTT
///
/// The goal is to emit a burst (of size `capacity`) in timer intervals
/// which compromise between
/// - ideally distributing datagrams over time
/// - constantly waking up the connection to produce additional datagrams
///
/// Too short burst intervals means we will never meet them since the timer
/// accuracy in user-space is not high enough. If we miss the interval by more
/// than 25%, we will lose that part of the congestion window since no additional
/// tokens for the extra-elapsed time can be stored.
///
/// Too long burst intervals make pacing less effective.
fn optimal_capacity(smoothed_rtt: Duration, window: u64, mtu: u16) -> u64 {
    // Guard against a zero RTT so the division below is always well-defined
    let rtt_nanos = smoothed_rtt.as_nanos().max(1);
    // Fraction of the congestion window we may emit per burst interval
    let raw_capacity = (window as u128 * BURST_INTERVAL_NANOS / rtt_nanos) as u64;
    // Small bursts are less efficient (no GSO), could increase latency and don't effectively
    // use the channel's buffer capacity. Large bursts might block the connection on sending.
    let floor = MIN_BURST_SIZE * u64::from(mtu);
    let ceiling = MAX_BURST_SIZE * u64::from(mtu);
    raw_capacity.clamp(floor, ceiling)
}

/// The burst interval
///
/// The capacity will be refilled in 4/5 of that time.
/// 2ms is chosen here since framework timers might have 1ms precision.
/// If kernel-level pacing is supported later a higher time here might be
/// more applicable.
const BURST_INTERVAL_NANOS: u128 = 2_000_000; // 2ms

/// Allows some usage of GSO, and doesn't slow down the handshake.
const MIN_BURST_SIZE: u64 = 10;

/// Creating 256 packets took 1ms in a benchmark, so larger bursts don't make sense.
const MAX_BURST_SIZE: u64 = 256;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn does_not_panic_on_bad_instant() {
        // `delay` receives a `now` that precedes the pacer's creation time;
        // the backwards clock must be tolerated without panicking
        let old_instant = Instant::now();
        let new_instant = old_instant + Duration::from_micros(15);
        let rtt = Duration::from_micros(400);
        assert!(
            Pacer::new(rtt, 30000, 1500, new_instant)
                .delay(Duration::from_micros(0), 0, 1500, 1, old_instant)
                .is_none()
        );
        assert!(
            Pacer::new(rtt, 30000, 1500, new_instant)
                .delay(Duration::from_micros(0), 1600, 1500, 1, old_instant)
                .is_none()
        );
        assert!(
            Pacer::new(rtt, 30000, 1500, new_instant)
                .delay(Duration::from_micros(0), 1500, 1500, 3000, old_instant)
                .is_none()
        );
    }

    #[test]
    fn derives_initial_capacity() {
        let window = 2_000_000;
        let mtu = 1500;
        let rtt = Duration::from_millis(50);
        let now = Instant::now();

        // Nominal case: capacity is the window fraction per burst interval
        let pacer = Pacer::new(rtt, window, mtu, now);
        assert_eq!(
            pacer.capacity,
            (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64
        );
        assert_eq!(pacer.tokens, pacer.capacity);

        // Zero RTT: capacity is clamped to the maximum burst size
        let pacer = Pacer::new(Duration::from_millis(0), window, mtu, now);
        assert_eq!(pacer.capacity, MAX_BURST_SIZE * mtu as u64);
        assert_eq!(pacer.tokens, pacer.capacity);

        // Tiny window: capacity is clamped to the minimum burst size
        let pacer = Pacer::new(rtt, 1, mtu, now);
        assert_eq!(pacer.capacity, MIN_BURST_SIZE * mtu as u64);
        assert_eq!(pacer.tokens, pacer.capacity);
    }

    #[test]
    fn adjusts_capacity() {
        let window = 2_000_000;
        let mtu = 1500;
        let rtt = Duration::from_millis(50);
        let now = Instant::now();

        let mut pacer = Pacer::new(rtt, window, mtu, now);
        assert_eq!(
            pacer.capacity,
            (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64
        );
        assert_eq!(pacer.tokens, pacer.capacity);
        let initial_tokens = pacer.tokens;

        // Doubling the window doubles the capacity; tokens stay untouched
        pacer.delay(rtt, mtu as u64, mtu, window * 2, now);
        assert_eq!(
            pacer.capacity,
            (2 * window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64
        );
        assert_eq!(pacer.tokens, initial_tokens);

        // Halving the window halves the capacity and clamps the tokens to it
        pacer.delay(rtt, mtu as u64, mtu, window / 2, now);
        assert_eq!(
            pacer.capacity,
            (window as u128 / 2 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64
        );
        assert_eq!(pacer.tokens, initial_tokens / 2);

        // An MTU change also triggers recomputation
        pacer.delay(rtt, mtu as u64, mtu * 2, window, now);
        assert_eq!(
            pacer.capacity,
            (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64
        );

        // A huge MTU makes the minimum-burst clamp dominate
        pacer.delay(rtt, mtu as u64, 20_000, window, now);
        assert_eq!(pacer.capacity, 20_000_u64 * MIN_BURST_SIZE);
    }

    #[test]
    fn computes_pause_correctly() {
        let window = 2_000_000u64;
        let mtu = 1000;
        let rtt = Duration::from_millis(50);

        let old_instant = Instant::now();
        let mut pacer = Pacer::new(rtt, window, mtu, old_instant);
        let packet_capacity = pacer.capacity / mtu as u64;

        // Drain the bucket one MTU-sized packet at a time
        for _ in 0..packet_capacity {
            assert_eq!(
                pacer.delay(rtt, mtu as u64, mtu, window, old_instant),
                None,
                "When capacity is available packets should be sent immediately"
            );
            pacer.on_transmit(mtu);
        }

        // Once empty, the pacer asks us to wait 4/5 of the burst interval
        let pace_duration = Duration::from_nanos((BURST_INTERVAL_NANOS * 4 / 5) as u64);
        assert_eq!(
            pacer
                .delay(rtt, mtu as u64, mtu, window, old_instant)
                .expect("Send must be delayed")
                .duration_since(old_instant),
            pace_duration
        );

        // Refill half of the tokens
        assert_eq!(
            pacer.delay(
                rtt,
                mtu as u64,
                mtu,
                window,
                old_instant + pace_duration / 2
            ),
            None
        );
        assert_eq!(pacer.tokens, pacer.capacity / 2);

        for _ in 0..packet_capacity / 2 {
            assert_eq!(
                pacer.delay(rtt, mtu as u64, mtu, window, old_instant),
                None,
                "When capacity is available packets should be sent immediately"
            );
            pacer.on_transmit(mtu);
        }

        // Refill all capacity by waiting more than the expected duration
        assert_eq!(
            pacer.delay(
                rtt,
                mtu as u64,
                mtu,
                window,
                old_instant + pace_duration * 3 / 2
            ),
            None
        );
        assert_eq!(pacer.tokens, pacer.capacity);
    }
}

View File

@@ -0,0 +1,282 @@
use bytes::Bytes;
use rand::Rng;
use tracing::{debug, trace, trace_span};
use super::{Connection, SentFrames, spaces::SentPacket};
use crate::{
ConnectionId, Instant, TransportError, TransportErrorCode,
connection::ConnectionSide,
frame::{self, Close},
packet::{FIXED_BIT, Header, InitialHeader, LongType, PacketNumber, PartialEncode, SpaceId},
};
/// Incremental state for writing a single packet into an outgoing datagram
pub(super) struct PacketBuilder {
    /// Absolute buffer offset at which the enclosing datagram starts
    pub(super) datagram_start: usize,
    /// Packet number space this packet belongs to
    pub(super) space: SpaceId,
    /// Partially-encoded header, completed (and protected) in `finish`
    pub(super) partial_encode: PartialEncode,
    /// Whether this packet will elicit an acknowledgment from the peer
    pub(super) ack_eliciting: bool,
    /// The full (untruncated) packet number
    pub(super) exact_number: u64,
    /// Whether the packet uses the short header form
    pub(super) short_header: bool,
    /// Smallest absolute position in the associated buffer that must be occupied by this packet's
    /// frames
    pub(super) min_size: usize,
    /// Largest absolute position in the associated buffer that may be occupied by this packet's
    /// frames
    pub(super) max_size: usize,
    /// AEAD tag length of the keys protecting this packet
    pub(super) tag_len: usize,
    /// Tracing span entered for the duration of the packet's construction
    pub(super) _span: tracing::span::EnteredSpan,
}
impl PacketBuilder {
    /// Write a new packet header to `buffer` and determine the packet's properties
    ///
    /// Marks the connection drained and returns `None` if the confidentiality limit would be
    /// violated.
    pub(super) fn new(
        now: Instant,
        space_id: SpaceId,
        dst_cid: ConnectionId,
        buffer: &mut Vec<u8>,
        buffer_capacity: usize,
        datagram_start: usize,
        ack_eliciting: bool,
        conn: &mut Connection,
    ) -> Option<Self> {
        let version = conn.version;
        // Initiate key update if we're approaching the confidentiality limit
        let sent_with_keys = conn.spaces[space_id].sent_with_keys;
        if space_id == SpaceId::Data {
            if sent_with_keys >= conn.key_phase_size {
                debug!("routine key update due to phase exhaustion");
                conn.force_key_update();
            }
        } else {
            // Non-Data keys (and 0-RTT keys) cannot be rotated; the connection
            // must be closed before their confidentiality limit is exceeded
            let confidentiality_limit = conn.spaces[space_id]
                .crypto
                .as_ref()
                .map_or_else(
                    || &conn.zero_rtt_crypto.as_ref().unwrap().packet,
                    |keys| &keys.packet.local,
                )
                .confidentiality_limit();
            if sent_with_keys.saturating_add(1) == confidentiality_limit {
                // We still have time to attempt a graceful close
                conn.close_inner(
                    now,
                    Close::Connection(frame::ConnectionClose {
                        error_code: TransportErrorCode::AEAD_LIMIT_REACHED,
                        frame_type: None,
                        reason: Bytes::from_static(b"confidentiality limit reached"),
                    }),
                )
            } else if sent_with_keys > confidentiality_limit {
                // Confidentiality limited violated and there's nothing we can do
                conn.kill(
                    TransportError::AEAD_LIMIT_REACHED("confidentiality limit reached").into(),
                );
                return None;
            }
        }

        let space = &mut conn.spaces[space_id];

        // Data-space packet numbers are allocated through a filter (which may
        // skip numbers); all other spaces use strictly sequential numbers
        let exact_number = match space_id {
            SpaceId::Data => conn.packet_number_filter.allocate(&mut conn.rng, space),
            _ => space.get_tx_number(),
        };

        let span = trace_span!("send", space = ?space_id, pn = exact_number).entered();

        // The wire encoding truncates the packet number relative to the
        // largest acknowledged one
        let number = PacketNumber::new(exact_number, space.largest_acked_packet.unwrap_or(0));
        let header = match space_id {
            SpaceId::Data if space.crypto.is_some() => Header::Short {
                dst_cid,
                number,
                spin: if conn.spin_enabled {
                    conn.spin
                } else {
                    conn.rng.random()
                },
                key_phase: conn.key_phase,
            },
            // Data space without 1-RTT keys means we're sending 0-RTT
            SpaceId::Data => Header::Long {
                ty: LongType::ZeroRtt,
                src_cid: conn.handshake_cid,
                dst_cid,
                number,
                version,
            },
            SpaceId::Handshake => Header::Long {
                ty: LongType::Handshake,
                src_cid: conn.handshake_cid,
                dst_cid,
                number,
                version,
            },
            SpaceId::Initial => Header::Initial(InitialHeader {
                src_cid: conn.handshake_cid,
                dst_cid,
                token: match &conn.side {
                    ConnectionSide::Client { token, .. } => token.clone(),
                    ConnectionSide::Server { .. } => Bytes::new(),
                },
                number,
                version,
            }),
        };
        let partial_encode = header.encode(buffer);
        // Randomly flip the fixed bit when the peer advertised grease_quic_bit
        if conn.peer_params.grease_quic_bit && conn.rng.random() {
            buffer[partial_encode.start] ^= FIXED_BIT;
        }

        let (sample_size, tag_len) = if let Some(ref crypto) = space.crypto {
            (
                crypto.header.local.sample_size(),
                crypto.packet.local.tag_len(),
            )
        } else if space_id == SpaceId::Data {
            let zero_rtt = conn.zero_rtt_crypto.as_ref().unwrap();
            (zero_rtt.header.sample_size(), zero_rtt.packet.tag_len())
        } else {
            unreachable!();
        };

        // Each packet must be large enough for header protection sampling, i.e. the combined
        // lengths of the encoded packet number and protected payload must be at least 4 bytes
        // longer than the sample required for header protection. Further, each packet should be at
        // least tag_len + 6 bytes larger than the destination CID on incoming packets so that the
        // peer may send stateless resets that are indistinguishable from regular traffic.

        // pn_len + payload_len + tag_len >= sample_size + 4
        // payload_len >= sample_size + 4 - pn_len - tag_len
        let min_size = Ord::max(
            buffer.len() + (sample_size + 4).saturating_sub(number.len() + tag_len),
            partial_encode.start + dst_cid.len() + 6,
        );
        let max_size = buffer_capacity - tag_len;
        debug_assert!(max_size >= min_size);

        Some(Self {
            datagram_start,
            space: space_id,
            partial_encode,
            exact_number,
            short_header: header.is_short(),
            min_size,
            max_size,
            tag_len,
            ack_eliciting,
            _span: span,
        })
    }

    /// Append the minimum amount of padding to the packet such that, after encryption, the
    /// enclosing datagram will occupy at least `min_size` bytes
    pub(super) fn pad_to(&mut self, min_size: u16) {
        // The datagram might already have a larger minimum size than the caller is requesting, if
        // e.g. we're coalescing packets and have populated more than `min_size` bytes with packets
        // already.
        // NOTE(review): assumes `datagram_start + min_size >= tag_len`; the subtraction
        // would underflow otherwise — confirm callers guarantee this.
        self.min_size = Ord::max(
            self.min_size,
            self.datagram_start + (min_size as usize) - self.tag_len,
        );
    }

    /// Encrypt the packet and record it for loss detection and congestion control
    ///
    /// `sent` carries the frames written into the packet; when it is `None` the packet is
    /// finished but not tracked.
    pub(super) fn finish_and_track(
        self,
        now: Instant,
        conn: &mut Connection,
        sent: Option<SentFrames>,
        buffer: &mut Vec<u8>,
    ) {
        let ack_eliciting = self.ack_eliciting;
        let exact_number = self.exact_number;
        let space_id = self.space;
        let (size, padded) = self.finish(conn, now, buffer);
        let sent = match sent {
            Some(sent) => sent,
            None => return,
        };

        // Packets that are neither padded nor ack-eliciting record a size of
        // zero — presumably excluding them from in-flight accounting; confirm
        // against `SentPacket` consumers
        let size = match padded || ack_eliciting {
            true => size as u16,
            false => 0,
        };

        let packet = SentPacket {
            path_generation: conn.path.generation(),
            largest_acked: sent.largest_acked,
            time_sent: now,
            size,
            ack_eliciting,
            retransmits: sent.retransmits,
            stream_frames: sent.stream_frames,
        };

        conn.path
            .sent(exact_number, packet, &mut conn.spaces[space_id]);
        conn.stats.path.sent_packets += 1;
        conn.reset_keep_alive(now);
        if size != 0 {
            if ack_eliciting {
                conn.spaces[space_id].time_of_last_ack_eliciting_packet = Some(now);
                if conn.permit_idle_reset {
                    conn.reset_idle_timeout(now, space_id);
                }
                conn.permit_idle_reset = false;
            }
            conn.set_loss_detection_timer(now);
            conn.path.pacing.on_transmit(size);
        }
    }

    /// Encrypt packet, returning the length of the packet and whether padding was added
    pub(super) fn finish(
        self,
        conn: &mut Connection,
        now: Instant,
        buffer: &mut Vec<u8>,
    ) -> (usize, bool) {
        // Pad up to `min_size`, satisfying the sampling and stateless-reset
        // requirements computed in `new`/`pad_to`
        let pad = buffer.len() < self.min_size;
        if pad {
            trace!("PADDING * {}", self.min_size - buffer.len());
            buffer.resize(self.min_size, 0);
        }

        let space = &conn.spaces[self.space];
        let (header_crypto, packet_crypto) = if let Some(ref crypto) = space.crypto {
            (&*crypto.header.local, &*crypto.packet.local)
        } else if self.space == SpaceId::Data {
            let zero_rtt = conn.zero_rtt_crypto.as_ref().unwrap();
            (&*zero_rtt.header, &*zero_rtt.packet)
        } else {
            unreachable!("tried to send {:?} packet without keys", self.space);
        };

        debug_assert_eq!(
            packet_crypto.tag_len(),
            self.tag_len,
            "Mismatching crypto tag len"
        );

        // Reserve room for the AEAD tag, then apply packet and header protection
        buffer.resize(buffer.len() + packet_crypto.tag_len(), 0);
        let encode_start = self.partial_encode.start;
        let packet_buf = &mut buffer[encode_start..];
        self.partial_encode.finish(
            packet_buf,
            header_crypto,
            Some((self.exact_number, packet_crypto)),
        );

        let len = buffer.len() - encode_start;
        // Data-space packets sent while 1-RTT keys are absent are 0-RTT packets
        conn.config.qlog_sink.emit_packet_sent(
            self.exact_number,
            len,
            self.space,
            self.space == SpaceId::Data && conn.spaces[SpaceId::Data].crypto.is_none(),
            now,
            conn.orig_rem_cid,
        );

        (len, pad)
    }
}

View File

@@ -0,0 +1,173 @@
use tracing::{debug, trace};
use crate::Instant;
use crate::connection::spaces::PacketSpace;
use crate::crypto::{HeaderKey, KeyPair, PacketKey};
use crate::packet::{Packet, PartialDecode, SpaceId};
use crate::token::ResetToken;
use crate::{RESET_TOKEN_SIZE, TransportError};
/// Removes header protection of a packet, or returns `None` if the packet was dropped
pub(super) fn unprotect_header(
    partial_decode: PartialDecode,
    spaces: &[PacketSpace; 3],
    zero_rtt_crypto: Option<&ZeroRttCrypto>,
    stateless_reset_token: Option<ResetToken>,
) -> Option<UnprotectHeaderResult> {
    // Pick the header protection keys matching the packet's space, dropping
    // packets for which we hold no keys
    let header_crypto = if partial_decode.is_0rtt() {
        if let Some(crypto) = zero_rtt_crypto {
            Some(&*crypto.header)
        } else {
            debug!("dropping unexpected 0-RTT packet");
            return None;
        }
    } else if let Some(space) = partial_decode.space() {
        if let Some(ref crypto) = spaces[space].crypto {
            Some(&*crypto.header.remote)
        } else {
            debug!(
                "discarding unexpected {:?} packet ({} bytes)",
                space,
                partial_decode.len(),
            );
            return None;
        }
    } else {
        // Unprotected packet
        None
    };

    // A stateless reset carries the reset token at the very end of the
    // datagram; check up front so that a packet which fails decoding can
    // still be recognized as a reset
    let packet = partial_decode.data();
    let stateless_reset = packet.len() >= RESET_TOKEN_SIZE + 5
        && stateless_reset_token.as_deref() == Some(&packet[packet.len() - RESET_TOKEN_SIZE..]);

    match partial_decode.finish(header_crypto) {
        Ok(packet) => Some(UnprotectHeaderResult {
            packet: Some(packet),
            stateless_reset,
        }),
        Err(_) if stateless_reset => Some(UnprotectHeaderResult {
            packet: None,
            stateless_reset: true,
        }),
        Err(e) => {
            trace!("unable to complete packet decoding: {}", e);
            None
        }
    }
}
/// Outcome of [`unprotect_header`]
pub(super) struct UnprotectHeaderResult {
    /// The packet with the now unprotected header (`None` in the case of stateless reset packets
    /// that fail to be decoded)
    pub(super) packet: Option<Packet>,
    /// Whether the packet was a stateless reset packet
    pub(super) stateless_reset: bool,
}
/// Decrypts a packet's body in-place
///
/// Returns `Ok(None)` for unprotected packets, which carry no packet number. On failure,
/// `Err(None)` indicates a failure without an associated transport error (e.g. an
/// undecryptable packet), while `Err(Some(_))` carries a connection-fatal transport error.
pub(super) fn decrypt_packet_body(
    packet: &mut Packet,
    spaces: &[PacketSpace; 3],
    zero_rtt_crypto: Option<&ZeroRttCrypto>,
    conn_key_phase: bool,
    prev_crypto: Option<&PrevCrypto>,
    next_crypto: Option<&KeyPair<Box<dyn PacketKey>>>,
) -> Result<Option<DecryptPacketResult>, Option<TransportError>> {
    if !packet.header.is_protected() {
        // Unprotected packets also don't have packet numbers
        return Ok(None);
    }
    let space = packet.header.space();
    let rx_packet = spaces[space].rx_packet;
    // Recover the full packet number from its truncated wire encoding
    let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1);
    let packet_key_phase = packet.header.key_phase();

    let mut crypto_update = false;
    // Select the packet protection keys based on space and key phase
    let crypto = if packet.header.is_0rtt() {
        &zero_rtt_crypto.unwrap().packet
    } else if packet_key_phase == conn_key_phase || space != SpaceId::Data {
        &spaces[space].crypto.as_ref().unwrap().packet.remote
    } else if let Some(prev) = prev_crypto.and_then(|crypto| {
        // If this packet comes prior to acknowledgment of the key update by the peer,
        if crypto.end_packet.map_or(true, |(pn, _)| number < pn) {
            // use the previous keys.
            Some(crypto)
        } else {
            // Otherwise, this must be a remotely-initiated key update, so fall through to the
            // final case.
            None
        }
    }) {
        &prev.crypto.remote
    } else {
        // We're in the Data space with a key phase mismatch and either there is no locally
        // initiated key update or the locally initiated key update was acknowledged by a
        // lower-numbered packet. The key phase mismatch must therefore represent a new
        // remotely-initiated key update.
        crypto_update = true;
        &next_crypto.unwrap().remote
    };

    crypto
        .decrypt(number, &packet.header_data, &mut packet.payload)
        .map_err(|_| {
            trace!("decryption failed with packet number {}", number);
            None
        })?;

    if !packet.reserved_bits_valid() {
        return Err(Some(TransportError::PROTOCOL_VIOLATION(
            "reserved bits set",
        )));
    }

    let mut outgoing_key_update_acked = false;
    if let Some(prev) = prev_crypto {
        // A packet in the current phase while `end_packet` is still unset
        // shows the peer has adopted the keys from our local key update
        if prev.end_packet.is_none() && packet_key_phase == conn_key_phase {
            outgoing_key_update_acked = true;
        }
    }

    if crypto_update {
        // Validate incoming key update
        if number <= rx_packet || prev_crypto.is_some_and(|x| x.update_unacked) {
            return Err(Some(TransportError::KEY_UPDATE_ERROR("")));
        }
    }

    Ok(Some(DecryptPacketResult {
        number,
        outgoing_key_update_acked,
        incoming_key_update: crypto_update,
    }))
}
/// Outcome of a successful [`decrypt_packet_body`] call
pub(super) struct DecryptPacketResult {
    /// The packet number
    pub(super) number: u64,
    /// Whether a locally initiated key update has been acknowledged by the peer
    pub(super) outgoing_key_update_acked: bool,
    /// Whether the peer has initiated a key update
    pub(super) incoming_key_update: bool,
}
/// Key material retained from the previous key phase
pub(super) struct PrevCrypto {
    /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by
    /// the peer prior to its own key update.
    pub(super) crypto: KeyPair<Box<dyn PacketKey>>,
    /// The incoming packet that ends the interval for which these keys are applicable, and the time
    /// of its receipt.
    ///
    /// Incoming packets should be decrypted using these keys iff this is `None` or their packet
    /// number is lower. `None` indicates that we have not yet received a packet using newer keys,
    /// which implies that the update was locally initiated.
    pub(super) end_packet: Option<(u64, Instant)>,
    /// Whether the following key phase is from a remotely initiated update that we haven't acked
    pub(super) update_unacked: bool,
}
/// Keys used to protect and unprotect 0-RTT packets
pub(super) struct ZeroRttCrypto {
    /// Header protection key
    pub(super) header: Box<dyn HeaderKey>,
    /// Packet payload protection key
    pub(super) packet: Box<dyn PacketKey>,
}

View File

@@ -0,0 +1,456 @@
use std::{cmp, net::SocketAddr};
use tracing::trace;
use super::{
mtud::MtuDiscovery,
pacing::Pacer,
spaces::{PacketSpace, SentPacket},
};
use crate::{Duration, Instant, TIMER_GRANULARITY, TransportConfig, congestion, packet::SpaceId};
#[cfg(feature = "qlog")]
use qlog::events::quic::MetricsUpdated;
/// Description of a particular network path
pub(super) struct PathData {
    /// The remote socket address this path transmits to
    pub(super) remote: SocketAddr,
    /// Round-trip time estimator for this path
    pub(super) rtt: RttEstimator,
    /// Whether we're enabling ECN on outgoing packets
    pub(super) sending_ecn: bool,
    /// Congestion controller state
    pub(super) congestion: Box<dyn congestion::Controller>,
    /// Pacing state
    pub(super) pacing: Pacer,
    /// Outstanding path validation challenge value, if any — presumably the
    /// PATH_CHALLENGE payload; confirm against the validation logic
    pub(super) challenge: Option<u64>,
    /// Whether a challenge transmission is pending — TODO confirm exact semantics
    pub(super) challenge_pending: bool,
    /// Whether we're certain the peer can both send and receive on this address
    ///
    /// Initially equal to `use_stateless_retry` for servers, and becomes false again on every
    /// migration. Always true for clients.
    pub(super) validated: bool,
    /// Total size of all UDP datagrams sent on this path
    pub(super) total_sent: u64,
    /// Total size of all UDP datagrams received on this path
    pub(super) total_recvd: u64,
    /// The state of the MTU discovery process
    pub(super) mtud: MtuDiscovery,
    /// Packet number of the first packet sent after an RTT sample was collected on this path
    ///
    /// Used in persistent congestion determination.
    pub(super) first_packet_after_rtt_sample: Option<(SpaceId, u64)>,
    /// In-flight accounting for this path (see `InFlight`)
    pub(super) in_flight: InFlight,
    /// Number of the first packet sent on this path
    ///
    /// Used to determine whether a packet was sent on an earlier path. Insufficient to determine if
    /// a packet was sent on a later path.
    first_packet: Option<u64>,
    /// Snapshot of the qlog recovery metrics
    #[cfg(feature = "qlog")]
    recovery_metrics: RecoveryMetrics,
    /// Tag uniquely identifying a path in a connection
    generation: u64,
}
impl PathData {
    /// Construct state for a brand-new path, with congestion control, pacing, and MTU
    /// discovery freshly initialized from `config`
    pub(super) fn new(
        remote: SocketAddr,
        allow_mtud: bool,
        peer_max_udp_payload_size: Option<u16>,
        generation: u64,
        now: Instant,
        config: &TransportConfig,
    ) -> Self {
        let congestion = config
            .congestion_controller_factory
            .clone()
            .build(now, config.get_initial_mtu());
        Self {
            remote,
            rtt: RttEstimator::new(config.initial_rtt),
            // Optimistically enabled; disabled later if the path mishandles ECN markings
            sending_ecn: true,
            pacing: Pacer::new(
                config.initial_rtt,
                congestion.initial_window(),
                config.get_initial_mtu(),
                now,
            ),
            congestion,
            challenge: None,
            challenge_pending: false,
            validated: false,
            total_sent: 0,
            total_recvd: 0,
            // MTU discovery runs only when both the caller allows it and config enables it
            mtud: config
                .mtu_discovery_config
                .as_ref()
                .filter(|_| allow_mtud)
                .map_or(
                    MtuDiscovery::disabled(config.get_initial_mtu(), config.min_mtu),
                    |mtud_config| {
                        MtuDiscovery::new(
                            config.get_initial_mtu(),
                            config.min_mtu,
                            peer_max_udp_payload_size,
                            mtud_config.clone(),
                        )
                    },
                ),
            first_packet_after_rtt_sample: None,
            in_flight: InFlight::new(),
            first_packet: None,
            #[cfg(feature = "qlog")]
            recovery_metrics: RecoveryMetrics::default(),
            generation,
        }
    }

    /// Construct state for a new path that inherits estimates (RTT, congestion window,
    /// discovered MTU) from `prev`, e.g. after a migration to a new remote address
    pub(super) fn from_previous(
        remote: SocketAddr,
        prev: &Self,
        generation: u64,
        now: Instant,
    ) -> Self {
        let congestion = prev.congestion.clone_box();
        let smoothed_rtt = prev.rtt.get();
        Self {
            remote,
            rtt: prev.rtt,
            pacing: Pacer::new(smoothed_rtt, congestion.window(), prev.current_mtu(), now),
            sending_ecn: true,
            congestion,
            challenge: None,
            challenge_pending: false,
            // The new address has not proven reachability yet
            validated: false,
            total_sent: 0,
            total_recvd: 0,
            mtud: prev.mtud.clone(),
            first_packet_after_rtt_sample: prev.first_packet_after_rtt_sample,
            in_flight: InFlight::new(),
            first_packet: None,
            #[cfg(feature = "qlog")]
            recovery_metrics: prev.recovery_metrics.clone(),
            generation,
        }
    }

    /// Resets RTT, congestion control and MTU states.
    ///
    /// This is useful when it is known the underlying path has changed.
    pub(super) fn reset(&mut self, now: Instant, config: &TransportConfig) {
        self.rtt = RttEstimator::new(config.initial_rtt);
        self.congestion = config
            .congestion_controller_factory
            .clone()
            .build(now, config.get_initial_mtu());
        self.mtud.reset(config.get_initial_mtu(), config.min_mtu);
    }

    /// Indicates whether we're a server that hasn't validated the peer's address and hasn't
    /// received enough data from the peer to permit sending `bytes_to_send` additional bytes
    ///
    /// Implements the 3x anti-amplification limit: until validated, we may send at most
    /// three times the bytes we have received on this path.
    pub(super) fn anti_amplification_blocked(&self, bytes_to_send: u64) -> bool {
        !self.validated && self.total_recvd * 3 < self.total_sent + bytes_to_send
    }

    /// Returns the path's current MTU
    pub(super) fn current_mtu(&self) -> u16 {
        self.mtud.current_mtu()
    }

    /// Account for transmission of `packet` with number `pn` in `space`
    pub(super) fn sent(&mut self, pn: u64, packet: SentPacket, space: &mut PacketSpace) {
        self.in_flight.insert(&packet);
        if self.first_packet.is_none() {
            self.first_packet = Some(pn);
        }
        // `space.sent` may evict a previously tracked packet; stop counting it as in flight
        if let Some(forgotten) = space.sent(pn, packet) {
            self.remove_in_flight(&forgotten);
        }
    }

    /// Remove `packet` with number `pn` from this path's congestion control counters, or return
    /// `false` if `pn` was sent before this path was established.
    pub(super) fn remove_in_flight(&mut self, packet: &SentPacket) -> bool {
        if packet.path_generation != self.generation {
            return false;
        }
        self.in_flight.remove(packet);
        true
    }

    /// Produce a qlog `MetricsUpdated` event containing only metrics that changed since
    /// the last snapshot, updating the snapshot; `None` if nothing changed
    #[cfg(feature = "qlog")]
    pub(super) fn qlog_recovery_metrics(&mut self, pto_count: u32) -> Option<MetricsUpdated> {
        let controller_metrics = self.congestion.metrics();
        let metrics = RecoveryMetrics {
            min_rtt: Some(self.rtt.min),
            smoothed_rtt: Some(self.rtt.get()),
            latest_rtt: Some(self.rtt.latest),
            rtt_variance: Some(self.rtt.var),
            pto_count: Some(pto_count),
            bytes_in_flight: Some(self.in_flight.bytes),
            packets_in_flight: Some(self.in_flight.ack_eliciting),
            congestion_window: Some(controller_metrics.congestion_window),
            ssthresh: controller_metrics.ssthresh,
            pacing_rate: controller_metrics.pacing_rate,
        };
        let event = metrics.to_qlog_event(&self.recovery_metrics);
        self.recovery_metrics = metrics;
        event
    }

    /// Tag uniquely identifying this path within its connection
    pub(super) fn generation(&self) -> u64 {
        self.generation
    }
}
/// Congestion metrics as described in [`recovery_metrics_updated`].
///
/// Every field is an `Option` so that values unchanged since the previous snapshot can be
/// omitted from emitted events.
///
/// [`recovery_metrics_updated`]: https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-quic-events.html#name-recovery_metrics_updated
#[cfg(feature = "qlog")]
#[derive(Default, Clone, PartialEq)]
#[non_exhaustive]
struct RecoveryMetrics {
    pub min_rtt: Option<Duration>,
    pub smoothed_rtt: Option<Duration>,
    pub latest_rtt: Option<Duration>,
    pub rtt_variance: Option<Duration>,
    pub pto_count: Option<u32>,
    pub bytes_in_flight: Option<u64>,
    pub packets_in_flight: Option<u64>,
    pub congestion_window: Option<u64>,
    pub ssthresh: Option<u64>,
    pub pacing_rate: Option<u64>,
}
#[cfg(feature = "qlog")]
impl RecoveryMetrics {
    /// Produce a copy of `self` where every field equal to its value in `previous` is
    /// cleared to `None`, leaving only metrics updated since the last snapshot
    fn retain_updated(&self, previous: &Self) -> Self {
        /// Yields `current` only when it differs from `prev`
        fn changed<T: Copy + PartialEq>(prev: Option<T>, current: Option<T>) -> Option<T> {
            match prev == current {
                true => None,
                false => current,
            }
        }
        Self {
            min_rtt: changed(previous.min_rtt, self.min_rtt),
            smoothed_rtt: changed(previous.smoothed_rtt, self.smoothed_rtt),
            latest_rtt: changed(previous.latest_rtt, self.latest_rtt),
            rtt_variance: changed(previous.rtt_variance, self.rtt_variance),
            pto_count: changed(previous.pto_count, self.pto_count),
            bytes_in_flight: changed(previous.bytes_in_flight, self.bytes_in_flight),
            packets_in_flight: changed(previous.packets_in_flight, self.packets_in_flight),
            congestion_window: changed(previous.congestion_window, self.congestion_window),
            ssthresh: changed(previous.ssthresh, self.ssthresh),
            pacing_rate: changed(previous.pacing_rate, self.pacing_rate),
        }
    }

    /// Emit a `MetricsUpdated` event containing only values updated since `previous`,
    /// or `None` when nothing changed at all
    fn to_qlog_event(&self, previous: &Self) -> Option<MetricsUpdated> {
        let delta = self.retain_updated(previous);
        if delta == Self::default() {
            return None;
        }
        let as_f32 = |rtt: Duration| rtt.as_secs_f32();
        Some(MetricsUpdated {
            min_rtt: delta.min_rtt.map(as_f32),
            smoothed_rtt: delta.smoothed_rtt.map(as_f32),
            latest_rtt: delta.latest_rtt.map(as_f32),
            rtt_variance: delta.rtt_variance.map(as_f32),
            // qlog represents the PTO count in 16 bits; clamp rather than wrap
            pto_count: delta
                .pto_count
                .map(|count| u16::try_from(count).unwrap_or(u16::MAX)),
            bytes_in_flight: delta.bytes_in_flight,
            packets_in_flight: delta.packets_in_flight,
            congestion_window: delta.congestion_window,
            ssthresh: delta.ssthresh,
            pacing_rate: delta.pacing_rate,
        })
    }
}
/// RTT estimation for a particular network path
#[derive(Copy, Clone)]
pub struct RttEstimator {
    /// The most recent RTT measurement made when receiving an ack for a previously unacked packet
    ///
    /// Seeded with the configured initial RTT until the first sample arrives.
    latest: Duration,
    /// The smoothed RTT of the connection, computed as described in RFC6298
    ///
    /// `None` until the first RTT sample has been taken.
    smoothed: Option<Duration>,
    /// The RTT variance, computed as described in RFC6298
    var: Duration,
    /// The minimum RTT seen in the connection, ignoring ack delay.
    min: Duration,
}
impl RttEstimator {
fn new(initial_rtt: Duration) -> Self {
Self {
latest: initial_rtt,
smoothed: None,
var: initial_rtt / 2,
min: initial_rtt,
}
}
/// The current best RTT estimation.
pub fn get(&self) -> Duration {
self.smoothed.unwrap_or(self.latest)
}
/// Conservative estimate of RTT
///
/// Takes the maximum of smoothed and latest RTT, as recommended
/// in 6.1.2 of the recovery spec (draft 29).
pub fn conservative(&self) -> Duration {
self.get().max(self.latest)
}
/// Minimum RTT registered so far for this estimator.
pub fn min(&self) -> Duration {
self.min
}
// PTO computed as described in RFC9002#6.2.1
pub(crate) fn pto_base(&self) -> Duration {
self.get() + cmp::max(4 * self.var, TIMER_GRANULARITY)
}
pub(crate) fn update(&mut self, ack_delay: Duration, rtt: Duration) {
self.latest = rtt;
// min_rtt ignores ack delay.
self.min = cmp::min(self.min, self.latest);
// Based on RFC6298.
if let Some(smoothed) = self.smoothed {
let adjusted_rtt = if self.min + ack_delay <= self.latest {
self.latest - ack_delay
} else {
self.latest
};
let var_sample = if smoothed > adjusted_rtt {
smoothed - adjusted_rtt
} else {
adjusted_rtt - smoothed
};
self.var = (3 * self.var + var_sample) / 4;
self.smoothed = Some((7 * smoothed + adjusted_rtt) / 8);
} else {
self.smoothed = Some(self.latest);
self.var = self.latest / 2;
self.min = self.latest;
}
}
}
/// Queue of PATH_RESPONSE frames owed to the peer, at most one per remote address
#[derive(Default)]
pub(crate) struct PathResponses {
    // Outstanding responses; newest entries are pushed to the back
    pending: Vec<PathResponse>,
}
impl PathResponses {
    /// Queue a response for a PATH_CHALLENGE carrying `token`, received in packet
    /// number `packet` from `remote`
    pub(crate) fn push(&mut self, packet: u64, token: u64, remote: SocketAddr) {
        /// Arbitrary permissive limit to prevent abuse
        const MAX_PATH_RESPONSES: usize = 16;
        let fresh = PathResponse {
            packet,
            token,
            remote,
        };
        match self.pending.iter_mut().find(|queued| queued.remote == remote) {
            Some(queued) => {
                // Replace an already-queued response when this challenge is at least as recent
                if queued.packet <= packet {
                    *queued = fresh;
                }
            }
            None if self.pending.len() < MAX_PATH_RESPONSES => self.pending.push(fresh),
            None => {
                // We don't expect to ever hit this with well-behaved peers, so we don't bother
                // dropping older challenges.
                trace!("ignoring excessive PATH_CHALLENGE");
            }
        }
    }

    /// Take a queued response destined for an address *other* than `remote`, if any
    pub(crate) fn pop_off_path(&mut self, remote: SocketAddr) -> Option<(u64, SocketAddr)> {
        let newest = *self.pending.last()?;
        if newest.remote == remote {
            // We don't bother searching further because we expect that the on-path response will
            // get drained in the immediate future by a call to `pop_on_path`
            return None;
        }
        self.pending.pop();
        Some((newest.token, newest.remote))
    }

    /// Take a queued response destined for exactly `remote`, if any
    pub(crate) fn pop_on_path(&mut self, remote: SocketAddr) -> Option<u64> {
        let newest = *self.pending.last()?;
        if newest.remote != remote {
            // We don't bother searching further because we expect that the off-path response will
            // get drained in the immediate future by a call to `pop_off_path`
            return None;
        }
        self.pending.pop();
        Some(newest.token)
    }

    /// Whether no responses are currently queued
    pub(crate) fn is_empty(&self) -> bool {
        self.pending.is_empty()
    }
}
/// A pending PATH_RESPONSE, echoing a PATH_CHALLENGE received from `remote`
#[derive(Copy, Clone)]
struct PathResponse {
    /// The packet number the corresponding PATH_CHALLENGE was received in
    packet: u64,
    /// The challenge token to echo back verbatim
    token: u64,
    /// The address the corresponding PATH_CHALLENGE was received from
    remote: SocketAddr,
}
/// Summary statistics of packets that have been sent on a particular path, but which have not yet
/// been acked or deemed lost
pub(super) struct InFlight {
    /// Sum of the sizes of all sent packets considered "in flight" by congestion control
    ///
    /// The size does not include IP or UDP overhead. Packets only containing ACK frames do not
    /// count towards this to ensure congestion control does not impede congestion feedback.
    pub(super) bytes: u64,
    /// Number of packets in flight containing frames other than ACK and PADDING
    ///
    /// This can be 0 even when bytes is not 0 because PADDING frames cause a packet to be
    /// considered "in flight" by congestion control. However, if this is nonzero, bytes will always
    /// also be nonzero.
    pub(super) ack_eliciting: u64,
}
impl InFlight {
    /// Counters for a path with nothing currently outstanding
    fn new() -> Self {
        Self {
            bytes: 0,
            ack_eliciting: 0,
        }
    }

    /// Account for `packet` entering flight
    fn insert(&mut self, packet: &SentPacket) {
        self.ack_eliciting += u64::from(packet.ack_eliciting);
        self.bytes += u64::from(packet.size);
    }

    /// Update counters to account for a packet becoming acknowledged, lost, or abandoned
    fn remove(&mut self, packet: &SentPacket) {
        self.ack_eliciting -= u64::from(packet.ack_eliciting);
        self.bytes -= u64::from(packet.size);
    }
}

View File

@@ -0,0 +1,190 @@
// Function bodies in this module are regularly cfg'd out
#![allow(unused_variables)]
#[cfg(feature = "qlog")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "qlog")]
use qlog::{
events::{
Event, EventData,
quic::{
PacketHeader, PacketLost, PacketLostTrigger, PacketReceived, PacketSent, PacketType,
},
},
streamer::QlogStreamer,
};
#[cfg(feature = "qlog")]
use tracing::warn;
use crate::{
ConnectionId, Instant,
connection::{PathData, SentPacket},
packet::SpaceId,
};
/// Shareable handle to a single qlog output stream
///
/// Cloning is cheap: all clones serialize to the same underlying streamer, guarded by a mutex.
#[cfg(feature = "qlog")]
#[derive(Clone)]
pub struct QlogStream(pub(crate) Arc<Mutex<QlogStreamer>>);
#[cfg(feature = "qlog")]
impl QlogStream {
    /// Serialize `event` to the shared stream, grouped under the connection's original
    /// remote CID so that events from one connection can be correlated
    fn emit_event(&self, orig_rem_cid: ConnectionId, event: EventData, now: Instant) {
        // Time will be overwritten by `add_event_with_instant`
        let mut qlog_event = Event::with_time(0.0, event);
        qlog_event.group_id = Some(orig_rem_cid.to_string());
        let mut streamer = self.0.lock().unwrap();
        match streamer.add_event_with_instant(qlog_event, now) {
            Err(e) => warn!("could not emit qlog event: {e}"),
            _ => {}
        }
    }
}
/// A [`QlogStream`] that may be either dynamically disabled or compiled out entirely
///
/// With the `qlog` feature disabled this type is zero-sized and every emit is a no-op.
#[derive(Clone, Default)]
pub(crate) struct QlogSink {
    // `None` means qlog was not enabled for this connection at runtime
    #[cfg(feature = "qlog")]
    stream: Option<QlogStream>,
}
impl QlogSink {
    /// Whether events written to this sink will actually be recorded
    pub(crate) fn is_enabled(&self) -> bool {
        #[cfg(feature = "qlog")]
        {
            self.stream.is_some()
        }
        #[cfg(not(feature = "qlog"))]
        {
            false
        }
    }

    /// Record a recovery `MetricsUpdated` event, if any recovery metric changed
    ///
    /// No-op when qlog is compiled out, disabled at runtime, or when `path` reports
    /// no change since the last snapshot.
    pub(super) fn emit_recovery_metrics(
        &self,
        pto_count: u32,
        path: &mut PathData,
        now: Instant,
        orig_rem_cid: ConnectionId,
    ) {
        #[cfg(feature = "qlog")]
        {
            let Some(stream) = self.stream.as_ref() else {
                return;
            };
            let Some(metrics) = path.qlog_recovery_metrics(pto_count) else {
                return;
            };
            stream.emit_event(orig_rem_cid, EventData::MetricsUpdated(metrics), now);
        }
    }

    /// Record a `PacketLost` event for packet `pn`
    pub(super) fn emit_packet_lost(
        &self,
        pn: u64,
        info: &SentPacket,
        lost_send_time: Instant,
        space: SpaceId,
        now: Instant,
        orig_rem_cid: ConnectionId,
    ) {
        #[cfg(feature = "qlog")]
        {
            let Some(stream) = self.stream.as_ref() else {
                return;
            };
            let event = PacketLost {
                header: Some(PacketHeader {
                    packet_number: Some(pn),
                    packet_type: packet_type(space, false),
                    length: Some(info.size),
                    ..Default::default()
                }),
                frames: None,
                // Packets sent at or before the loss-time cutoff were declared lost by the
                // time threshold; later ones by the packet reordering threshold
                trigger: Some(match info.time_sent <= lost_send_time {
                    true => PacketLostTrigger::TimeThreshold,
                    false => PacketLostTrigger::ReorderingThreshold,
                }),
            };
            stream.emit_event(orig_rem_cid, EventData::PacketLost(event), now);
        }
    }

    /// Record a `PacketSent` event for packet `pn` of `len` bytes
    pub(super) fn emit_packet_sent(
        &self,
        pn: u64,
        len: usize,
        space: SpaceId,
        is_0rtt: bool,
        now: Instant,
        orig_rem_cid: ConnectionId,
    ) {
        #[cfg(feature = "qlog")]
        {
            let Some(stream) = self.stream.as_ref() else {
                return;
            };
            let event = PacketSent {
                header: PacketHeader {
                    packet_number: Some(pn),
                    packet_type: packet_type(space, is_0rtt),
                    // NOTE(review): assumes the packet length fits in u16; presumably bounded
                    // by the path MTU — confirm against callers
                    length: Some(len as u16),
                    ..Default::default()
                },
                ..Default::default()
            };
            stream.emit_event(orig_rem_cid, EventData::PacketSent(event), now);
        }
    }

    /// Record a `PacketReceived` event for packet `pn`
    pub(super) fn emit_packet_received(
        &self,
        pn: u64,
        space: SpaceId,
        is_0rtt: bool,
        now: Instant,
        orig_rem_cid: ConnectionId,
    ) {
        #[cfg(feature = "qlog")]
        {
            let Some(stream) = self.stream.as_ref() else {
                return;
            };
            let event = PacketReceived {
                header: PacketHeader {
                    packet_number: Some(pn),
                    packet_type: packet_type(space, is_0rtt),
                    ..Default::default()
                },
                ..Default::default()
            };
            stream.emit_event(orig_rem_cid, EventData::PacketReceived(event), now);
        }
    }
}
#[cfg(feature = "qlog")]
impl From<Option<QlogStream>> for QlogSink {
    /// Wrap an optional qlog stream; `None` yields a disabled sink
    fn from(stream: Option<QlogStream>) -> Self {
        Self { stream }
    }
}
/// Map a packet number space (plus whether the packet was 0-RTT) to qlog's packet type
#[cfg(feature = "qlog")]
fn packet_type(space: SpaceId, is_0rtt: bool) -> PacketType {
    match (space, is_0rtt) {
        (SpaceId::Initial, _) => PacketType::Initial,
        (SpaceId::Handshake, _) => PacketType::Handshake,
        (SpaceId::Data, true) => PacketType::ZeroRtt,
        (SpaceId::Data, false) => PacketType::OneRtt,
    }
}

View File

@@ -0,0 +1,394 @@
use std::{collections::VecDeque, ops::Range};
use bytes::{Buf, Bytes};
use crate::{VarInt, range_set::RangeSet};
/// Buffer of outgoing retransmittable stream data
#[derive(Default, Debug)]
pub(super) struct SendBuffer {
    /// Data queued by the application but not yet acknowledged. May or may not have been sent.
    unacked_segments: VecDeque<Bytes>,
    /// Total size of `unacked_segments`
    unacked_len: usize,
    /// The first offset that hasn't been written by the application, i.e. the offset past the end
    /// of `unacked_segments`
    offset: u64,
    /// The first offset that hasn't been sent
    ///
    /// Always lies in (offset - unacked_len)..offset
    unsent: u64,
    /// Acknowledged ranges which couldn't be discarded yet as they don't include the earliest
    /// offset in `unacked_segments`
    // TODO: Recover storage from these by compacting (#700)
    acks: RangeSet,
    /// Previously transmitted ranges deemed lost
    retransmits: RangeSet,
}
impl SendBuffer {
    /// Construct an empty buffer at the initial offset
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Append application data to the end of the stream
    pub(super) fn write(&mut self, data: Bytes) {
        self.unacked_len += data.len();
        self.offset += data.len() as u64;
        self.unacked_segments.push_back(data);
    }

    /// Discard a range of acknowledged stream data
    pub(super) fn ack(&mut self, mut range: Range<u64>) {
        // Clamp the range to data which is still tracked
        let base_offset = self.offset - self.unacked_len as u64;
        range.start = base_offset.max(range.start);
        range.end = base_offset.max(range.end);
        self.acks.insert(range);
        // Drop buffered segments as long as the acknowledged ranges form a contiguous
        // prefix of the tracked data
        while self.acks.min() == Some(self.offset - self.unacked_len as u64) {
            let prefix = self.acks.pop_min().unwrap();
            let mut to_advance = (prefix.end - prefix.start) as usize;
            self.unacked_len -= to_advance;
            while to_advance > 0 {
                let front = self
                    .unacked_segments
                    .front_mut()
                    .expect("Expected buffered data");
                if front.len() <= to_advance {
                    to_advance -= front.len();
                    self.unacked_segments.pop_front();
                    // Reclaim deque storage once it's mostly empty
                    if self.unacked_segments.len() * 4 < self.unacked_segments.capacity() {
                        self.unacked_segments.shrink_to_fit();
                    }
                } else {
                    front.advance(to_advance);
                    to_advance = 0;
                }
            }
        }
    }

    /// Compute the next range to transmit on this stream and update state to account for that
    /// transmission.
    ///
    /// `max_len` here includes the space which is available to transmit the
    /// offset and length of the data to send. The caller has to guarantee that
    /// there is at least enough space available to write maximum-sized metadata
    /// (8 byte offset + 8 byte length).
    ///
    /// The method returns a tuple:
    /// - The first return value indicates the range of data to send
    /// - The second return value indicates whether the length needs to be encoded
    ///   in the STREAM frames metadata (`true`), or whether it can be omitted
    ///   since the selected range will fill the whole packet.
    pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
        debug_assert!(max_len >= 8 + 8);
        let mut encode_length = false;
        if let Some(range) = self.retransmits.pop_min() {
            // Retransmit sent data
            // When the offset is known, we know how many bytes are required to encode it.
            // Offset 0 requires no space
            if range.start != 0 {
                // SAFETY: stream offsets never exceed the VarInt range (< 2^62), which QUIC
                // flow control enforces — upstream invariant
                max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
            }
            // If the lost range can't fill the packet, a length field is needed
            if range.end - range.start < max_len as u64 {
                encode_length = true;
                max_len -= 8;
            }
            let end = range.end.min((max_len as u64).saturating_add(range.start));
            if end != range.end {
                // Requeue whatever part of the lost range didn't fit
                self.retransmits.insert(end..range.end);
            }
            return (range.start..end, encode_length);
        }
        // Transmit new data
        // When the offset is known, we know how many bytes are required to encode it.
        // Offset 0 requires no space
        if self.unsent != 0 {
            // SAFETY: stream offsets never exceed the VarInt range (< 2^62), which QUIC
            // flow control enforces — upstream invariant
            max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
        }
        if self.offset - self.unsent < max_len as u64 {
            encode_length = true;
            max_len -= 8;
        }
        let end = self
            .offset
            .min((max_len as u64).saturating_add(self.unsent));
        let result = self.unsent..end;
        self.unsent = end;
        (result, encode_length)
    }

    /// Returns data which is associated with a range
    ///
    /// This function can return a subset of the range, if the data is stored
    /// in noncontiguous fashion in the send buffer. In this case callers
    /// should call the function again with an incremented start offset to
    /// retrieve more data.
    pub(super) fn get(&self, offsets: Range<u64>) -> &[u8] {
        let base_offset = self.offset - self.unacked_len as u64;
        // Walk the segments to locate the one containing `offsets.start`
        let mut segment_offset = base_offset;
        for segment in self.unacked_segments.iter() {
            if offsets.start >= segment_offset
                && offsets.start < segment_offset + segment.len() as u64
            {
                let start = (offsets.start - segment_offset) as usize;
                let end = (offsets.end - segment_offset) as usize;
                // Truncate to the segment boundary; callers retry for the remainder
                return &segment[start..end.min(segment.len())];
            }
            segment_offset += segment.len() as u64;
        }
        &[]
    }

    /// Queue a range of sent but unacknowledged data to be retransmitted
    pub(super) fn retransmit(&mut self, range: Range<u64>) {
        debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
        self.retransmits.insert(range);
    }

    /// Mark all buffered data as unsent, for replay after a rejected 0-RTT attempt
    pub(super) fn retransmit_all_for_0rtt(&mut self) {
        debug_assert_eq!(self.offset, self.unacked_len as u64);
        self.unsent = 0;
    }

    /// First stream offset unwritten by the application, i.e. the offset that the next write will
    /// begin at
    pub(super) fn offset(&self) -> u64 {
        self.offset
    }

    /// Whether all sent data has been acknowledged
    pub(super) fn is_fully_acked(&self) -> bool {
        self.unacked_len == 0
    }

    /// Whether there's data to send
    ///
    /// There may be sent unacknowledged data even when this is false.
    pub(super) fn has_unsent_data(&self) -> bool {
        self.unsent != self.offset || !self.retransmits.is_empty()
    }

    /// Compute the amount of data that hasn't been acknowledged
    pub(super) fn unacked(&self) -> u64 {
        self.unacked_len as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A frame too small for the remaining data must carry an explicit length field
    #[test]
    fn fragment_with_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => 13 byte data isn't enough
        // with 8 bytes reserved for length 11 payload bytes will fit
        assert_eq!(buf.poll_transmit(19), (0..11, true));
        assert_eq!(
            buf.poll_transmit(MSG.len() + 16 - 11),
            (11..MSG.len() as u64, true)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    // A frame that fills the whole packet may omit the length field
    #[test]
    fn fragment_without_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with some extra data!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => can be filled by 34 bytes payload
        assert_eq!(buf.poll_transmit(19), (0..19, false));
        assert_eq!(
            buf.poll_transmit(MSG.len() - 19 + 1),
            (19..MSG.len() as u64, false)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    // poll_transmit must reserve exactly as many bytes as the VarInt-encoded offset needs
    #[test]
    fn reserves_encoded_offset() {
        let mut buf = SendBuffer::new();
        // Pretend we have more than 1 GB of data in the buffer
        let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
        for _ in 0..1025 {
            buf.write(chunk.clone());
        }
        const SIZE1: u64 = 64;
        const SIZE2: u64 = 16 * 1024;
        const SIZE3: u64 = 1024 * 1024 * 1024;
        // Offset 0 requires no space
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        let mut transmitted = 16u64;
        // Offset 16 requires 1 byte
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        buf.retransmit(transmitted..SIZE1);
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        transmitted = SIZE1;
        // Offset 64 requires 2 bytes
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        buf.retransmit(transmitted..SIZE2);
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        transmitted = SIZE2;
        // Offset 16384 requires 4 bytes
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        buf.retransmit(transmitted..SIZE3);
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        transmitted = SIZE3;
        // Offset 1GB requires 8 bytes
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
        buf.retransmit(transmitted..transmitted + chunk.len() as u64);
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
    }

    // Reads and acks across multiple noncontiguous Bytes segments
    #[test]
    fn multiple_segments() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        const MSG_LEN: u64 = MSG.len() as u64;
        const SEG1: &[u8] = b"He";
        buf.write(SEG1.into());
        const SEG2: &[u8] = b"llo,";
        buf.write(SEG2.into());
        const SEG3: &[u8] = b" w";
        buf.write(SEG3.into());
        const SEG4: &[u8] = b"o";
        buf.write(SEG4.into());
        const SEG5: &[u8] = b"rld!";
        buf.write(SEG5.into());
        assert_eq!(aggregate_unacked(&buf), MSG);
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        // `get` truncates at segment boundaries
        assert_eq!(buf.get(0..5), SEG1);
        assert_eq!(buf.get(2..8), SEG2);
        assert_eq!(buf.get(6..8), SEG3);
        assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true));
        assert_eq!(buf.get(8..MSG_LEN), SEG4);
        assert_eq!(buf.get(9..MSG_LEN), SEG5);
        assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true));
        // Now drain the segments
        buf.ack(0..1);
        assert_eq!(aggregate_unacked(&buf), &MSG[1..]);
        buf.ack(0..3);
        assert_eq!(aggregate_unacked(&buf), &MSG[3..]);
        buf.ack(3..5);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        // Out-of-order ack is held until the gap closes
        buf.ack(7..9);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(4..7);
        assert_eq!(aggregate_unacked(&buf), &MSG[9..]);
        buf.ack(0..MSG_LEN);
        assert_eq!(aggregate_unacked(&buf), &[] as &[u8]);
    }

    // Lost ranges are resent before any fresh data
    #[test]
    fn retransmit() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        // Transmit two frames
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        // Lose the first, but not the second
        buf.retransmit(0..16);
        // Ensure we only retransmit the lost frame, then continue sending fresh data
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
        // Lose the second frame
        buf.retransmit(16..23);
        assert_eq!(buf.poll_transmit(16), (16..23, true));
    }

    // Acknowledging a prefix frees the corresponding buffered data
    #[test]
    fn ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        buf.ack(0..8);
        assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
    }

    // An ack for a later range is buffered until the earlier range arrives
    #[test]
    fn reordered_ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.ack(16..23);
        assert_eq!(aggregate_unacked(&buf), MSG);
        buf.ack(0..16);
        assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
        assert!(buf.acks.is_empty());
    }

    // Flatten the buffer's unacked segments into one contiguous Vec for comparison
    fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
        let mut result = Vec::new();
        for segment in buf.unacked_segments.iter() {
            result.extend_from_slice(&segment[..]);
        }
        result
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,174 @@
//! Connection statistics
use crate::{Dir, Duration, frame::Frame};
/// Statistics about UDP datagrams transmitted or received on a connection
///
/// All QUIC packets are carried by UDP datagrams. Hence, these statistics cover all traffic on a connection.
#[derive(Default, Debug, Copy, Clone)]
#[non_exhaustive]
pub struct UdpStats {
    /// The amount of UDP datagrams observed
    pub datagrams: u64,
    /// The total amount of bytes which have been transferred inside UDP datagrams
    pub bytes: u64,
    /// The amount of I/O operations executed
    ///
    /// Can be less than `datagrams` when GSO, GRO, and/or batched system calls are in use.
    pub ios: u64,
}

impl UdpStats {
    /// Record a single send I/O operation that carried `datagrams` datagrams
    /// totalling `bytes` bytes
    pub(crate) fn on_sent(&mut self, datagrams: u64, bytes: usize) {
        self.ios += 1;
        self.datagrams += datagrams;
        self.bytes += bytes as u64;
    }
}
/// Number of frames transmitted or received of each frame type
#[derive(Default, Copy, Clone)]
#[non_exhaustive]
#[allow(missing_docs)]
pub struct FrameStats {
    pub acks: u64,
    pub ack_frequency: u64,
    pub crypto: u64,
    pub connection_close: u64,
    pub data_blocked: u64,
    pub datagram: u64,
    // Deliberately narrow: incremented with `saturating_add` in `record`, so it caps at 255
    pub handshake_done: u8,
    pub immediate_ack: u64,
    pub max_data: u64,
    pub max_stream_data: u64,
    pub max_streams_bidi: u64,
    pub max_streams_uni: u64,
    pub new_connection_id: u64,
    pub new_token: u64,
    pub path_challenge: u64,
    pub path_response: u64,
    pub ping: u64,
    pub reset_stream: u64,
    pub retire_connection_id: u64,
    pub stream_data_blocked: u64,
    pub streams_blocked_bidi: u64,
    pub streams_blocked_uni: u64,
    pub stop_sending: u64,
    pub stream: u64,
}
impl FrameStats {
    /// Increment the counter corresponding to `frame`'s type
    pub(crate) fn record(&mut self, frame: &Frame) {
        match frame {
            Frame::Padding => {}
            Frame::Ack(_) => self.acks += 1,
            Frame::AckFrequency(_) => self.ack_frequency += 1,
            Frame::Close(_) => self.connection_close += 1,
            Frame::Crypto(_) => self.crypto += 1,
            Frame::DataBlocked { .. } => self.data_blocked += 1,
            Frame::Datagram(_) => self.datagram += 1,
            // The counter is a u8; saturate instead of wrapping
            Frame::HandshakeDone => self.handshake_done = self.handshake_done.saturating_add(1),
            Frame::ImmediateAck => self.immediate_ack += 1,
            Frame::MaxData(_) => self.max_data += 1,
            Frame::MaxStreamData { .. } => self.max_stream_data += 1,
            Frame::MaxStreams { dir, .. } => {
                let counter = match *dir == Dir::Bi {
                    true => &mut self.max_streams_bidi,
                    false => &mut self.max_streams_uni,
                };
                *counter += 1;
            }
            Frame::NewConnectionId(_) => self.new_connection_id += 1,
            Frame::NewToken(_) => self.new_token += 1,
            Frame::PathChallenge(_) => self.path_challenge += 1,
            Frame::PathResponse(_) => self.path_response += 1,
            Frame::Ping => self.ping += 1,
            Frame::ResetStream(_) => self.reset_stream += 1,
            Frame::RetireConnectionId { .. } => self.retire_connection_id += 1,
            Frame::StopSending(_) => self.stop_sending += 1,
            Frame::Stream(_) => self.stream += 1,
            Frame::StreamDataBlocked { .. } => self.stream_data_blocked += 1,
            Frame::StreamsBlocked { dir, .. } => {
                let counter = match *dir == Dir::Bi {
                    true => &mut self.streams_blocked_bidi,
                    false => &mut self.streams_blocked_uni,
                };
                *counter += 1;
            }
        }
    }
}
impl std::fmt::Debug for FrameStats {
    /// Renders counters labelled with their wire-level frame names (e.g. `ACK`, `MAX_DATA`)
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fields: [(&str, u64); 24] = [
            ("ACK", self.acks),
            ("ACK_FREQUENCY", self.ack_frequency),
            ("CONNECTION_CLOSE", self.connection_close),
            ("CRYPTO", self.crypto),
            ("DATA_BLOCKED", self.data_blocked),
            ("DATAGRAM", self.datagram),
            ("HANDSHAKE_DONE", u64::from(self.handshake_done)),
            ("IMMEDIATE_ACK", self.immediate_ack),
            ("MAX_DATA", self.max_data),
            ("MAX_STREAM_DATA", self.max_stream_data),
            ("MAX_STREAMS_BIDI", self.max_streams_bidi),
            ("MAX_STREAMS_UNI", self.max_streams_uni),
            ("NEW_CONNECTION_ID", self.new_connection_id),
            ("NEW_TOKEN", self.new_token),
            ("PATH_CHALLENGE", self.path_challenge),
            ("PATH_RESPONSE", self.path_response),
            ("PING", self.ping),
            ("RESET_STREAM", self.reset_stream),
            ("RETIRE_CONNECTION_ID", self.retire_connection_id),
            ("STREAM_DATA_BLOCKED", self.stream_data_blocked),
            ("STREAMS_BLOCKED_BIDI", self.streams_blocked_bidi),
            ("STREAMS_BLOCKED_UNI", self.streams_blocked_uni),
            ("STOP_SENDING", self.stop_sending),
            ("STREAM", self.stream),
        ];
        let mut out = f.debug_struct("FrameStats");
        for (name, value) in fields {
            out.field(name, &value);
        }
        out.finish()
    }
}
/// Statistics related to a transmission path
#[derive(Debug, Default, Copy, Clone)]
#[non_exhaustive]
pub struct PathStats {
    /// Current best estimate of this connection's latency (round-trip-time)
    pub rtt: Duration,
    /// Current congestion window of the connection
    pub cwnd: u64,
    /// Congestion events on the connection
    pub congestion_events: u64,
    /// The amount of packets lost on this path
    pub lost_packets: u64,
    /// The amount of bytes lost on this path
    pub lost_bytes: u64,
    /// The amount of packets sent on this path
    pub sent_packets: u64,
    /// The amount of PLPMTUD probe packets sent on this path (also counted by `sent_packets`)
    pub sent_plpmtud_probes: u64,
    /// The amount of PLPMTUD probe packets lost on this path (ignored by `lost_packets` and
    /// `lost_bytes`)
    pub lost_plpmtud_probes: u64,
    /// The number of times a black hole was detected in the path
    pub black_holes_detected: u64,
    /// Largest UDP payload size the path currently supports
    pub current_mtu: u16,
}
/// Connection statistics
///
/// Aggregates the UDP-level, frame-level, and path-level counters for a connection.
#[derive(Debug, Default, Copy, Clone)]
#[non_exhaustive]
pub struct ConnectionStats {
    /// Statistics about UDP datagrams transmitted on a connection
    pub udp_tx: UdpStats,
    /// Statistics about UDP datagrams received on a connection
    pub udp_rx: UdpStats,
    /// Statistics about frames transmitted on a connection
    pub frame_tx: FrameStats,
    /// Statistics about frames received on a connection
    pub frame_rx: FrameStats,
    /// Statistics related to the current transmission path
    pub path: PathStats,
}

View File

@@ -0,0 +1,528 @@
use std::{
collections::{BinaryHeap, hash_map},
io,
};
use bytes::Bytes;
use thiserror::Error;
use tracing::trace;
use super::spaces::{Retransmits, ThinRetransmits};
use crate::{
Dir, StreamId, VarInt,
connection::streams::state::{get_or_insert_recv, get_or_insert_send},
frame,
};
mod recv;
use recv::Recv;
pub use recv::{Chunks, ReadError, ReadableError};
mod send;
pub(crate) use send::{ByteSlice, BytesArray};
use send::{BytesSource, Send, SendState};
pub use send::{FinishError, WriteError, Written};
mod state;
#[allow(unreachable_pub)] // fuzzing only
pub use state::StreamsState;
/// Access to streams
///
/// Connection-level stream operations: opening locally initiated streams, accepting
/// remotely initiated ones, and querying stream counts.
pub struct Streams<'a> {
    pub(super) state: &'a mut StreamsState,
    pub(super) conn_state: &'a super::State,
}
#[allow(clippy::needless_lifetimes)] // Needed for cfg(fuzzing)
impl<'a> Streams<'a> {
    #[cfg(fuzzing)]
    pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
        Self { state, conn_state }
    }
    /// Open a single stream if possible
    ///
    /// Returns `None` if the streams in the given direction are currently exhausted.
    pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
        // No new streams once the connection is closed
        if self.conn_state.is_closed() {
            return None;
        }
        // TODO: Queue STREAM_ID_BLOCKED if this fails
        if self.state.next[dir as usize] >= self.state.max[dir as usize] {
            return None;
        }
        self.state.next[dir as usize] += 1;
        // Stream indices are allocated sequentially per direction; the new stream gets the
        // value `next` held before the increment above
        let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
        self.state.insert(false, id);
        self.state.send_streams += 1;
        Some(id)
    }
    /// Accept a remotely initiated stream of a certain directionality, if possible
    ///
    /// Returns `None` if there are no new incoming streams for this connection.
    /// Has no impact on the data flow-control or stream concurrency limits.
    pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
        // `next_remote` counts streams the peer has opened; `next_reported_remote` counts how
        // many of those have been handed to the application. Equal means nothing is pending.
        if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
            return None;
        }
        let x = self.state.next_reported_remote[dir as usize];
        self.state.next_reported_remote[dir as usize] = x + 1;
        // A peer-initiated bidirectional stream also gives us a send half
        if dir == Dir::Bi {
            self.state.send_streams += 1;
        }
        Some(StreamId::new(!self.state.side, dir, x))
    }
    #[cfg(fuzzing)]
    pub fn state(&mut self) -> &mut StreamsState {
        self.state
    }
    /// The number of streams that may have unacknowledged data.
    pub fn send_streams(&self) -> usize {
        self.state.send_streams
    }
    /// The number of remotely initiated open streams of a certain directionality.
    ///
    /// Includes remotely initiated streams, which have not been accepted via [`accept`](Self::accept).
    /// These streams count against the respective concurrency limit reported by
    /// [`Connection::max_concurrent_streams`](super::Connection::max_concurrent_streams).
    pub fn remote_open_streams(&self, dir: Dir) -> u64 {
        // total opened - total closed = total opened - ( total permitted - total permitted unclosed )
        self.state.next_remote[dir as usize]
            - (self.state.max_remote[dir as usize]
                - self.state.allocated_remote_count[dir as usize])
    }
}
/// Access to streams
///
/// Operations on the receiving half of a single stream: reading data, stopping the stream,
/// and observing a peer-initiated reset.
pub struct RecvStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
}
impl RecvStream<'_> {
    /// Read from the given recv stream
    ///
    /// `max_length` limits the maximum size of the returned `Bytes` value; passing `usize::MAX`
    /// will yield the best performance. `ordered` will make sure the returned chunk's offset will
    /// have an offset exactly equal to the previously returned offset plus the previously returned
    /// bytes' length.
    ///
    /// Yields `Ok(None)` if the stream was finished. Otherwise, yields a segment of data and its
    /// offset in the stream. If `ordered` is `false`, segments may be received in any order, and
    /// the `Chunk`'s `offset` field can be used to determine ordering in the caller.
    ///
    /// While most applications will prefer to consume stream data in order, unordered reads can
    /// improve performance when packet loss occurs and data cannot be retransmitted before the flow
    /// control window is filled. On any given stream, you can switch from ordered to unordered
    /// reads, but ordered reads on streams that have seen previous unordered reads will return
    /// `ReadError::IllegalOrderedRead`.
    pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> {
        Chunks::new(self.id, ordered, self.state, self.pending)
    }
    /// Stop accepting data on the given receive stream
    ///
    /// Discards unread data and notifies the peer to stop transmitting. Once stopped, further
    /// attempts to operate on a stream will yield `ClosedStream` errors.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let mut entry = match self.state.recv.entry(self.id) {
            hash_map::Entry::Occupied(s) => s,
            hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
        };
        // Materialize the receive-side state lazily before stopping it
        let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());
        let (read_credits, stop_sending) = stream.stop()?;
        if stop_sending.should_transmit() {
            self.pending.stop_sending.push(frame::StopSending {
                id: self.id,
                error_code,
            });
        }
        // We need to keep stopped streams around until they're finished or reset so we can update
        // connection-level flow control to account for discarded data. Otherwise, we can discard
        // state immediately.
        if !stream.final_offset_unknown() {
            let recv = entry.remove().expect("must have recv when stopping");
            self.state.stream_recv_freed(self.id, recv);
        }
        // Credit released for the discarded unread data may warrant a MAX_DATA update
        if self.state.add_read_credits(read_credits).should_transmit() {
            self.pending.max_data = true;
        }
        Ok(())
    }
    /// Check whether this stream has been reset by the peer, returning the reset error code if so
    ///
    /// After returning `Ok(Some(_))` once, stream state will be discarded and all future calls will
    /// return `Err(ClosedStream)`.
    pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
        let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
            return Err(ClosedStream { _private: () });
        };
        let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
            return Ok(None);
        };
        if s.stopped {
            return Err(ClosedStream { _private: () });
        }
        let Some(code) = s.reset_code() else {
            return Ok(None);
        };
        // Clean up state after application observes the reset, since there's no reason for the
        // application to attempt to read or stop the stream once it knows it's reset
        let (_, recv) = entry.remove_entry();
        self.state
            .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
        self.state.queue_max_stream_id(self.pending);
        Ok(Some(code))
    }
}
/// Access to streams
///
/// Operations on the sending half of a single stream: writing data, finishing, resetting,
/// and managing transmission priority.
pub struct SendStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
    pub(super) conn_state: &'a super::State,
}
#[allow(clippy::needless_lifetimes)] // Needed for cfg(fuzzing)
impl<'a> SendStream<'a> {
    #[cfg(fuzzing)]
    pub fn new(
        id: StreamId,
        state: &'a mut StreamsState,
        pending: &'a mut Retransmits,
        conn_state: &'a super::State,
    ) -> Self {
        Self {
            id,
            state,
            pending,
            conn_state,
        }
    }
    /// Send data on the given stream
    ///
    /// Returns the number of bytes successfully written.
    pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
    }
    /// Send data on the given stream
    ///
    /// Returns the number of bytes and chunks successfully written.
    /// Note that this method might also write a partial chunk. In this case
    /// [`Written::chunks`] will not count this chunk as fully written. However
    /// the chunk will be advanced and contain only non-written data after the call.
    pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
        self.write_source(&mut BytesArray::from_chunks(data))
    }
    /// Common implementation for `write` and `write_chunks`
    fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
        if self.conn_state.is_closed() {
            trace!(%self.id, "write blocked; connection draining");
            return Err(WriteError::Blocked);
        }
        // Budget imposed by connection-level flow control and the send window
        let limit = self.state.write_limit();
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(WriteError::ClosedStream)?;
        if limit == 0 {
            trace!(
                stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
                "write blocked by connection-level flow control or send window"
            );
            // Record this stream as blocked on connection-level limits, without duplicates
            if !stream.connection_blocked {
                stream.connection_blocked = true;
                self.state.connection_blocked.push(self.id);
            }
            return Err(WriteError::Blocked);
        }
        let was_pending = stream.is_pending();
        let written = stream.write(source, limit)?;
        self.state.data_sent += written.bytes as u64;
        self.state.unacked_data += written.bytes as u64;
        trace!(stream = %self.id, "wrote {} bytes", written.bytes);
        // A stream that just became pending must be (re)queued at its current priority
        if !was_pending {
            self.state.pending.push_pending(self.id, stream.priority);
        }
        Ok(written)
    }
    /// Check if this stream was stopped, get the reason if it was
    pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
        match self.state.send.get(&self.id).as_ref() {
            Some(Some(s)) => Ok(s.stop_reason),
            // Entry exists but send state was dropped: not stopped
            Some(None) => Ok(None),
            None => Err(ClosedStream { _private: () }),
        }
    }
    /// Finish a send stream, signalling that no more data will be sent.
    ///
    /// If this fails, no [`StreamEvent::Finished`] will be generated.
    ///
    /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished
    pub fn finish(&mut self) -> Result<(), FinishError> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(FinishError::ClosedStream)?;
        let was_pending = stream.is_pending();
        stream.finish()?;
        // The FIN still needs to be transmitted, so ensure the stream is queued
        if !was_pending {
            self.state.pending.push_pending(self.id, stream.priority);
        }
        Ok(())
    }
    /// Abandon transmitting data on a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(ClosedStream { _private: () })?;
        if matches!(stream.state, SendState::ResetSent) {
            // Redundant reset call
            return Err(ClosedStream { _private: () });
        }
        // Restore the portion of the send window consumed by the data that we aren't about to
        // send. We leave flow control alone because the peer's responsible for issuing additional
        // credit based on the final offset communicated in the RESET_STREAM frame we send.
        self.state.unacked_data -= stream.pending.unacked();
        stream.reset();
        self.pending.reset_stream.push((self.id, error_code));
        // Don't reopen an already-closed stream we haven't forgotten yet
        Ok(())
    }
    /// Set the priority of a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(ClosedStream { _private: () })?;
        stream.priority = priority;
        Ok(())
    }
    /// Get the priority of a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn priority(&self) -> Result<i32, ClosedStream> {
        let stream = self
            .state
            .send
            .get(&self.id)
            .ok_or(ClosedStream { _private: () })?;
        // A vacated entry (`None`) reports the default priority (0)
        Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
    }
}
/// A queue of streams with pending outgoing data, sorted by priority
struct PendingStreamsQueue {
    /// Max-heap of pending streams; ordering is defined by `PendingStream`'s derived `Ord`
    /// (priority first, then recency)
    streams: BinaryHeap<PendingStream>,
    /// The next stream to write out. This is `Some` when `TransportConfig::send_fairness(false)` and writing a stream is
    /// interrupted while the stream still has some pending data. See `reinsert_pending()`.
    next: Option<PendingStream>,
    /// A monotonically decreasing counter, used to implement round-robin scheduling for streams of the same priority.
    /// Underflowing is not a practical concern, as it is initialized to u64::MAX and only decremented by 1 in `push_pending`
    recency: u64,
}
impl PendingStreamsQueue {
    /// Create an empty queue
    fn new() -> Self {
        Self {
            recency: u64::MAX,
            next: None,
            streams: BinaryHeap::new(),
        }
    }
    /// Reinsert a stream that was pending and still contains unsent data.
    fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
        assert!(self.next.is_none());
        let head = PendingStream {
            priority,
            // The exact recency is irrelevant here: `next` bypasses the heap entirely
            recency: self.recency,
            id,
        };
        self.next = Some(head);
    }
    /// Push a pending stream ID with the given priority, queued after any already-queued streams for the priority
    fn push_pending(&mut self, id: StreamId, priority: i32) {
        // With fairness disabled, a reinserted stream keeps its head-of-line spot even when a
        // higher-priority stream shows up: completing a partially written stream minimizes
        // fragmentation.
        // The recency counter only ever decreases, so this entry sorts behind every
        // already-queued stream of equal priority. That yields round-robin behaviour for
        // streams that remain pending after being handled, since those are popped from the
        // `BinaryHeap`, serviced, and immediately pushed back here.
        self.recency -= 1;
        let entry = PendingStream {
            priority,
            recency: self.recency,
            id,
        };
        self.streams.push(entry);
    }
    /// Remove the next stream to service, preferring a reinserted head-of-line stream
    fn pop(&mut self) -> Option<PendingStream> {
        match self.next.take() {
            Some(stream) => Some(stream),
            None => self.streams.pop(),
        }
    }
    /// Discard every queued entry
    fn clear(&mut self) {
        self.streams.clear();
        self.next = None;
    }
    /// Visit every queued entry, including a reinserted head-of-line stream
    fn iter(&self) -> impl Iterator<Item = &PendingStream> {
        self.next.iter().chain(self.streams.iter())
    }
    #[cfg(test)]
    fn len(&self) -> usize {
        usize::from(self.next.is_some()) + self.streams.len()
    }
}
/// The [`StreamId`] of a stream with pending data queued, ordered by its priority and recency
///
/// The derived `Ord` compares fields in declaration order, so field order below is load-bearing.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
struct PendingStream {
    /// The priority of the stream
    // Note that this field should be kept above the `recency` field, in order for the `Ord` derive to be correct
    // (See https://doc.rust-lang.org/stable/std/cmp/trait.Ord.html#derivable)
    priority: i32,
    /// A tie-breaker for streams of the same priority, used to improve fairness by implementing round-robin scheduling:
    /// Larger values are prioritized, so it is initialised to `u64::MAX`, and when a stream writes data, we know
    /// that it currently has the highest recency value, so it is deprioritized by setting its recency to 1 less than the
    /// previous lowest recency value, such that all other streams of this priority will get processed once before we get back
    /// round to this one
    recency: u64,
    /// The ID of the stream
    // The way this type is used ensures that every instance has a unique `recency` value, so this field should be kept below
    // the `priority` and `recency` fields, so that it does not interfere with the behaviour of the `Ord` derive
    id: StreamId,
}
/// Application events about streams
///
/// Surfaced to the application to signal changes in stream readability, writability,
/// availability, and lifecycle.
#[derive(Debug, PartialEq, Eq)]
pub enum StreamEvent {
    /// One or more new streams has been opened and might be readable
    Opened {
        /// Directionality for which streams have been opened
        dir: Dir,
    },
    /// A currently open stream likely has data or errors waiting to be read
    Readable {
        /// Which stream is now readable
        id: StreamId,
    },
    /// A formerly write-blocked stream might be ready for a write or have been stopped
    ///
    /// Only generated for streams that are currently open.
    Writable {
        /// Which stream is now writable
        id: StreamId,
    },
    /// A finished stream has been fully acknowledged or stopped
    Finished {
        /// Which stream has been finished
        id: StreamId,
    },
    /// The peer asked us to stop sending on an outgoing stream
    Stopped {
        /// Which stream has been stopped
        id: StreamId,
        /// Error code supplied by the peer
        error_code: VarInt,
    },
    /// At least one new stream of a certain directionality may be opened
    Available {
        /// Directionality for which streams are newly available
        dir: Dir,
    },
}
/// Indicates whether a frame needs to be transmitted
///
/// This type wraps around bool and uses the `#[must_use]` attribute in order
/// to prevent accidental loss of the frame transmission requirement.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);
impl ShouldTransmit {
    /// Returns whether a frame should be transmitted
    pub fn should_transmit(self) -> bool {
        self.0
    }
}
/// Error indicating that a stream has not been opened or has already been finished or reset
#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
#[error("closed stream")]
pub struct ClosedStream {
    // Private field prevents construction outside the defining module while keeping the
    // type publicly nameable
    _private: (),
}
impl From<ClosedStream> for io::Error {
    /// Map a closed-stream error onto `io::ErrorKind::NotConnected`, preserving the
    /// original error as the source
    fn from(err: ClosedStream) -> Self {
        io::Error::new(io::ErrorKind::NotConnected, err)
    }
}
/// Identifies one directional half of a stream
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum StreamHalf {
    /// The transmitting half
    Send,
    /// The receiving half
    Recv,
}

View File

@@ -0,0 +1,543 @@
use std::collections::hash_map::Entry;
use std::mem;
use thiserror::Error;
use tracing::debug;
use super::state::get_or_insert_recv;
use super::{ClosedStream, Retransmits, ShouldTransmit, StreamId, StreamsState};
use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead};
use crate::connection::streams::state::StreamRecv;
use crate::{TransportError, VarInt, frame};
/// Receive-side state for a single stream
#[derive(Debug, Default)]
pub(super) struct Recv {
    // NB: when adding or removing fields, remember to update `reinit`.
    /// Whether data is still being received, or a reset has been observed
    state: RecvState,
    /// Buffers and reassembles received data segments
    pub(super) assembler: Assembler,
    /// The largest stream-level flow control offset advertised to the peer so far
    /// (initialized to the initial window; updated by `record_sent_max_stream_data`)
    sent_max_stream_data: u64,
    /// High-water mark: the largest stream offset received so far
    pub(super) end: u64,
    /// Whether the application has stopped this stream
    pub(super) stopped: bool,
}
impl Recv {
    /// Create receive state for a new stream with `initial_max_data` bytes of initial
    /// stream-level flow control credit
    pub(super) fn new(initial_max_data: u64) -> Box<Self> {
        Box::new(Self {
            state: RecvState::default(),
            assembler: Assembler::new(),
            sent_max_stream_data: initial_max_data,
            end: 0,
            stopped: false,
        })
    }
    /// Reset to the initial state
    pub(super) fn reinit(&mut self, initial_max_data: u64) {
        self.state = RecvState::default();
        self.assembler.reinit();
        self.sent_max_stream_data = initial_max_data;
        self.end = 0;
        self.stopped = false;
    }
    /// Process a STREAM frame
    ///
    /// Return value is `(number_of_new_bytes_ingested, stream_is_closed)`
    pub(super) fn ingest(
        &mut self,
        frame: frame::Stream,
        payload_len: usize,
        received: u64,
        max_data: u64,
    ) -> Result<(u64, bool), TransportError> {
        let end = frame.offset + frame.data.len() as u64;
        // Offsets must stay below 2^62, the largest value a QUIC varint can express
        if end >= 2u64.pow(62) {
            return Err(TransportError::FLOW_CONTROL_ERROR(
                "maximum stream offset too large",
            ));
        }
        // Once a final size is known, data past it — or a FIN at a different offset — is fatal
        if let Some(final_offset) = self.final_offset() {
            if end > final_offset || (frame.fin && end != final_offset) {
                debug!(end, final_offset, "final size error");
                return Err(TransportError::FINAL_SIZE_ERROR(""));
            }
        }
        let new_bytes = self.credit_consumed_by(end, received, max_data)?;
        // Stopped streams don't need to wait for the actual data, they just need to know
        // how much there was.
        if frame.fin && !self.stopped {
            if let RecvState::Recv { ref mut size } = self.state {
                *size = Some(end);
            }
        }
        self.end = self.end.max(end);
        // Don't bother storing data or releasing stream-level flow control credit if the stream's
        // already stopped
        if !self.stopped {
            self.assembler.insert(frame.offset, frame.data, payload_len);
        }
        Ok((new_bytes, frame.fin && self.stopped))
    }
    /// Stop the stream, discarding buffered data
    ///
    /// Returns the connection-level flow control credit to release for unread data, and
    /// whether a STOP_SENDING frame should be transmitted
    pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), ClosedStream> {
        if self.stopped {
            return Err(ClosedStream { _private: () });
        }
        self.stopped = true;
        self.assembler.clear();
        // Issue flow control credit for unread data
        let read_credits = self.end - self.assembler.bytes_read();
        // This may send a spurious STOP_SENDING if we've already received all data, but it's a bit
        // fiddly to distinguish that from the case where we've received a FIN but are missing some
        // data that the peer might still be trying to retransmit, in which case a STOP_SENDING is
        // still useful.
        Ok((read_credits, ShouldTransmit(self.is_receiving())))
    }
    /// Returns the window that should be advertised in a `MAX_STREAM_DATA` frame
    ///
    /// The method returns a tuple which consists of the window that should be
    /// announced, as well as a boolean parameter which indicates if a new
    /// transmission of the value is recommended. If the boolean value is
    /// `false` the new window should only be transmitted if a previous transmission
    /// had failed.
    pub(super) fn max_stream_data(&mut self, stream_receive_window: u64) -> (u64, ShouldTransmit) {
        let max_stream_data = self.assembler.bytes_read() + stream_receive_window;
        // Only announce a window update if it's significant enough
        // to make it worthwhile sending a MAX_STREAM_DATA frame.
        // We use here a fraction of the configured stream receive window to make
        // the decision, and accommodate for streams using bigger windows requiring
        // less updates. A fixed size would also work - but it would need to be
        // smaller than `stream_receive_window` in order to make sure the stream
        // does not get stuck.
        let diff = max_stream_data - self.sent_max_stream_data;
        let transmit = self.can_send_flow_control() && diff >= (stream_receive_window / 8);
        (max_stream_data, ShouldTransmit(transmit))
    }
    /// Records that a `MAX_STREAM_DATA` announcing a certain window was sent
    ///
    /// This will suppress enqueuing further `MAX_STREAM_DATA` frames unless
    /// either the previous transmission was not acknowledged or the window
    /// further increased.
    pub(super) fn record_sent_max_stream_data(&mut self, sent_value: u64) {
        // Only ever ratchet upward; a stale retransmission must not shrink the record
        if sent_value > self.sent_max_stream_data {
            self.sent_max_stream_data = sent_value;
        }
    }
    /// Whether the total amount of data that the peer will send on this stream is unknown
    ///
    /// True until we've received either a reset or the final frame.
    ///
    /// Implies that the sender might benefit from stream-level flow control updates, and we might
    /// need to issue connection-level flow control updates due to flow control budget use by this
    /// stream in the future, even if it's been stopped.
    pub(super) fn final_offset_unknown(&self) -> bool {
        matches!(self.state, RecvState::Recv { size: None })
    }
    /// Whether stream-level flow control updates should be sent for this stream
    pub(super) fn can_send_flow_control(&self) -> bool {
        // Stream-level flow control is redundant if the sender has already sent the whole stream,
        // and moot if we no longer want data on this stream.
        self.final_offset_unknown() && !self.stopped
    }
    /// Whether data is still being accepted from the peer
    pub(super) fn is_receiving(&self) -> bool {
        matches!(self.state, RecvState::Recv { .. })
    }
    /// The stream's final size, if known (from a FIN or a reset)
    fn final_offset(&self) -> Option<u64> {
        match self.state {
            RecvState::Recv { size } => size,
            RecvState::ResetRecvd { size, .. } => Some(size),
        }
    }
    /// Returns `false` iff the reset was redundant
    pub(super) fn reset(
        &mut self,
        error_code: VarInt,
        final_offset: VarInt,
        received: u64,
        max_data: u64,
    ) -> Result<bool, TransportError> {
        // Validate final_offset
        if let Some(offset) = self.final_offset() {
            if offset != final_offset.into_inner() {
                return Err(TransportError::FINAL_SIZE_ERROR("inconsistent value"));
            }
        } else if self.end > u64::from(final_offset) {
            return Err(TransportError::FINAL_SIZE_ERROR(
                "lower than high water mark",
            ));
        }
        // A reset still consumes flow control credit up to the declared final offset
        self.credit_consumed_by(final_offset.into(), received, max_data)?;
        if matches!(self.state, RecvState::ResetRecvd { .. }) {
            return Ok(false);
        }
        self.state = RecvState::ResetRecvd {
            size: final_offset.into(),
            error_code,
        };
        // Nuke buffers so that future reads fail immediately, which ensures future reads don't
        // issue flow control credit redundant to that already issued. We could instead special-case
        // reset streams during read, but it's unclear if there's any benefit to retaining data for
        // reset streams.
        self.assembler.clear();
        Ok(true)
    }
    /// The application error code carried by a received RESET_STREAM, if any
    pub(super) fn reset_code(&self) -> Option<VarInt> {
        match self.state {
            RecvState::ResetRecvd { error_code, .. } => Some(error_code),
            _ => None,
        }
    }
    /// Compute the amount of flow control credit consumed, or return an error if more was consumed
    /// than issued
    fn credit_consumed_by(
        &self,
        offset: u64,
        received: u64,
        max_data: u64,
    ) -> Result<u64, TransportError> {
        let prev_end = self.end;
        // Only bytes beyond the previous high-water mark consume new credit, including any
        // not-yet-received gap below `offset`
        let new_bytes = offset.saturating_sub(prev_end);
        if offset > self.sent_max_stream_data || received + new_bytes > max_data {
            debug!(
                received,
                new_bytes,
                max_data,
                offset,
                stream_max_data = self.sent_max_stream_data,
                "flow control error"
            );
            return Err(TransportError::FLOW_CONTROL_ERROR(""));
        }
        Ok(new_bytes)
    }
}
/// Chunks returned from [`RecvStream::read()`][crate::RecvStream::read].
///
/// ### Note: Finalization Needed
/// Bytes read from the stream are not released from the congestion window until
/// either [`Self::finalize()`] is called, or this type is dropped.
///
/// It is recommended that you call [`Self::finalize()`] because it returns a flag
/// telling you whether reading from the stream has resulted in the need to transmit a packet.
///
/// If this type is leaked, the stream will remain blocked on the remote peer until
/// another read from the stream is done.
pub struct Chunks<'a> {
    id: StreamId,
    /// Whether reads must be delivered in offset order
    ordered: bool,
    streams: &'a mut StreamsState,
    pending: &'a mut Retransmits,
    /// Tracks whether the stream is still readable, reset, finished, or already finalized
    state: ChunksState,
    /// Bytes consumed via `next()` so far; released as flow control credit on finalize
    read: u64,
}
impl<'a> Chunks<'a> {
    /// Begin reading from a stream, taking ownership of its receive state for the duration
    ///
    /// Fails if the stream is unknown/stopped, or if an ordered read follows an unordered one.
    pub(super) fn new(
        id: StreamId,
        ordered: bool,
        streams: &'a mut StreamsState,
        pending: &'a mut Retransmits,
    ) -> Result<Self, ReadableError> {
        let mut entry = match streams.recv.entry(id) {
            Entry::Occupied(entry) => entry,
            Entry::Vacant(_) => return Err(ReadableError::ClosedStream),
        };
        // Temporarily remove the stream's receive state from the map; `finalize_inner()` puts
        // it back if the stream is still readable
        let mut recv =
            match get_or_insert_recv(streams.stream_receive_window)(entry.get_mut()).stopped {
                true => return Err(ReadableError::ClosedStream),
                false => entry.remove().unwrap().into_inner(), // this can't fail due to the previous get_or_insert_with
            };
        recv.assembler.ensure_ordering(ordered)?;
        Ok(Self {
            id,
            ordered,
            streams,
            pending,
            state: ChunksState::Readable(recv),
            read: 0,
        })
    }
    /// Next
    ///
    /// Should call finalize() when done calling this.
    pub fn next(&mut self, max_length: usize) -> Result<Option<Chunk>, ReadError> {
        let rs = match self.state {
            ChunksState::Readable(ref mut rs) => rs,
            ChunksState::Reset(error_code) => {
                return Err(ReadError::Reset(error_code));
            }
            ChunksState::Finished => {
                return Ok(None);
            }
            ChunksState::Finalized => panic!("must not call next() after finalize()"),
        };
        if let Some(chunk) = rs.assembler.read(max_length, self.ordered) {
            self.read += chunk.bytes.len() as u64;
            return Ok(Some(chunk));
        }
        // No data available: decide whether the stream is reset, finished, or merely blocked
        match rs.state {
            RecvState::ResetRecvd { error_code, .. } => {
                debug_assert_eq!(self.read, 0, "reset streams have empty buffers");
                let state = mem::replace(&mut self.state, ChunksState::Reset(error_code));
                // At this point if we have `rs` self.state must be `ChunksState::Readable`
                let recv = match state {
                    ChunksState::Readable(recv) => StreamRecv::Open(recv),
                    _ => unreachable!("state must be ChunkState::Readable"),
                };
                self.streams.stream_recv_freed(self.id, recv);
                Err(ReadError::Reset(error_code))
            }
            RecvState::Recv { size } => {
                // Finished only when the final size is known and every byte has been read
                if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end {
                    let state = mem::replace(&mut self.state, ChunksState::Finished);
                    // At this point if we have `rs` self.state must be `ChunksState::Readable`
                    let recv = match state {
                        ChunksState::Readable(recv) => StreamRecv::Open(recv),
                        _ => unreachable!("state must be ChunkState::Readable"),
                    };
                    self.streams.stream_recv_freed(self.id, recv);
                    Ok(None)
                } else {
                    // We don't need a distinct `ChunksState` variant for a blocked stream because
                    // retrying a read harmlessly re-traces our steps back to returning
                    // `Err(Blocked)` again. The buffers can't refill and the stream's own state
                    // can't change so long as this `Chunks` exists.
                    Err(ReadError::Blocked)
                }
            }
        }
    }
    /// Mark the read data as consumed from the stream.
    ///
    /// The number of read bytes will be released from the congestion window,
    /// allowing the remote peer to send more data if it was previously blocked.
    ///
    /// If [`ShouldTransmit::should_transmit()`] returns `true`,
    /// a packet needs to be sent to the peer informing them that the stream is unblocked.
    /// This means that you should call [`Connection::poll_transmit()`][crate::Connection::poll_transmit]
    /// and send the returned packet as soon as is reasonable, to unblock the remote peer.
    pub fn finalize(mut self) -> ShouldTransmit {
        self.finalize_inner()
    }
    /// Shared implementation for `finalize()` and `Drop`; idempotent
    fn finalize_inner(&mut self) -> ShouldTransmit {
        let state = mem::replace(&mut self.state, ChunksState::Finalized);
        if let ChunksState::Finalized = state {
            // Noop on repeated calls
            return ShouldTransmit(false);
        }
        // We issue additional stream ID credit after the application is notified that a previously
        // open stream has finished or been reset and we've therefore disposed of its state, as
        // recorded by `stream_freed` calls in `next`.
        let mut should_transmit = self.streams.queue_max_stream_id(self.pending);
        // If the stream hasn't finished, we may need to issue stream-level flow control credit
        if let ChunksState::Readable(mut rs) = state {
            let (_, max_stream_data) = rs.max_stream_data(self.streams.stream_receive_window);
            should_transmit |= max_stream_data.0;
            if max_stream_data.0 {
                self.pending.max_stream_data.insert(self.id);
            }
            // Return the stream to storage for future use
            self.streams
                .recv
                .insert(self.id, Some(StreamRecv::Open(rs)));
        }
        // Issue connection-level flow control credit for any data we read regardless of state
        let max_data = self.streams.add_read_credits(self.read);
        self.pending.max_data |= max_data.0;
        should_transmit |= max_data.0;
        ShouldTransmit(should_transmit)
    }
}
impl Drop for Chunks<'_> {
    fn drop(&mut self) {
        // Ensure flow control credit is released even if the caller never called `finalize()`;
        // the `ShouldTransmit` result is necessarily discarded here.
        let _ = self.finalize_inner();
    }
}
/// Internal state of a `Chunks` reader
enum ChunksState {
    /// Stream is open; holds the receive state while reads are in progress
    Readable(Box<Recv>),
    /// The peer reset the stream with the given error code
    Reset(VarInt),
    /// The final size is known and every byte has been read
    Finished,
    /// `finalize()` has run; no further reads are permitted
    Finalized,
}
/// Errors triggered when reading from a recv stream
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadError {
    /// No more data is currently available on this stream.
    ///
    /// If more data on this stream is received from the peer, an `Event::StreamReadable` will be
    /// generated for this stream, indicating that retrying the read might succeed.
    #[error("blocked")]
    Blocked,
    /// The peer abandoned transmitting data on this stream.
    ///
    /// Carries an application-defined error code.
    #[error("reset by peer: code {0}")]
    Reset(VarInt),
}
/// Errors triggered when opening a recv stream for reading
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadableError {
    /// The stream has not been opened or was already stopped, finished, or reset
    #[error("closed stream")]
    ClosedStream,
    /// Attempted an ordered read following an unordered read
    ///
    /// Performing an unordered read allows discontinuities to arise in the receive buffer of a
    /// stream which cannot be recovered, making further ordered reads impossible.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
}
impl From<IllegalOrderedRead> for ReadableError {
    fn from(_: IllegalOrderedRead) -> Self {
        Self::IllegalOrderedRead
    }
}
/// Whether the peer may still send data, or the stream's terminal condition
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum RecvState {
    /// Receiving data; `size` is the final size once a FIN has been seen
    Recv { size: Option<u64> },
    /// The peer reset the stream, declaring final size `size` and carrying `error_code`
    ResetRecvd { size: u64, error_code: VarInt },
}
impl Default for RecvState {
    /// New streams start receiving with no known final size
    fn default() -> Self {
        Self::Recv { size: None }
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use crate::{Dir, Side};
    use super::*;
    /// Exercises `Recv::ingest` flow control accounting for frames that arrive after `stop()`,
    /// including out-of-order frames filling an earlier gap.
    #[test]
    fn reordered_frames_while_stopped() {
        const INITIAL_BYTES: u64 = 3;
        const INITIAL_OFFSET: u64 = 3;
        // Also the initial value of `sent_max_stream_data`
        const RECV_WINDOW: u64 = 8;
        let mut s = Recv::new(RECV_WINDOW);
        let mut data_recvd = 0;
        // Receive bytes 3..6
        let (new_bytes, is_closed) = s
            .ingest(
                frame::Stream {
                    id: StreamId::new(Side::Client, Dir::Uni, 0),
                    offset: INITIAL_OFFSET,
                    fin: false,
                    data: Bytes::from_static(&[0; INITIAL_BYTES as usize]),
                },
                123,
                data_recvd,
                data_recvd + 1024,
            )
            .unwrap();
        data_recvd += new_bytes;
        // Credit is consumed up to the highest received offset, including the unreceived gap 0..3
        assert_eq!(new_bytes, INITIAL_OFFSET + INITIAL_BYTES);
        assert!(!is_closed);
        let (credits, transmit) = s.stop().unwrap();
        assert!(transmit.should_transmit());
        assert_eq!(
            credits,
            INITIAL_OFFSET + INITIAL_BYTES,
            "full connection flow control credit is issued by stop"
        );
        let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            max_stream_data, RECV_WINDOW,
            "stream flow control credit isn't issued by stop"
        );
        // Receive byte 7
        let (new_bytes, is_closed) = s
            .ingest(
                frame::Stream {
                    id: StreamId::new(Side::Client, Dir::Uni, 0),
                    offset: RECV_WINDOW - 1,
                    fin: false,
                    data: Bytes::from_static(&[0; 1]),
                },
                123,
                data_recvd,
                data_recvd + 1024,
            )
            .unwrap();
        data_recvd += new_bytes;
        assert_eq!(new_bytes, RECV_WINDOW - (INITIAL_OFFSET + INITIAL_BYTES));
        assert!(!is_closed);
        let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            max_stream_data, RECV_WINDOW,
            "stream flow control credit isn't issued after stop"
        );
        // Receive bytes 0..3, filling the gap below the high-water mark
        let (new_bytes, is_closed) = s
            .ingest(
                frame::Stream {
                    id: StreamId::new(Side::Client, Dir::Uni, 0),
                    offset: 0,
                    fin: false,
                    data: Bytes::from_static(&[0; INITIAL_OFFSET as usize]),
                },
                123,
                data_recvd,
                data_recvd + 1024,
            )
            .unwrap();
        assert_eq!(
            new_bytes, 0,
            "reordered frames don't issue connection-level flow control for stopped streams"
        );
        assert!(!is_closed);
        let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            max_stream_data, RECV_WINDOW,
            "stream flow control credit isn't issued after stop"
        );
    }
}

View File

@@ -0,0 +1,402 @@
use bytes::Bytes;
use thiserror::Error;
use crate::{VarInt, connection::send_buffer::SendBuffer, frame};
/// State for the send side of a single stream
#[derive(Debug)]
pub(super) struct Send {
    /// Highest offset we may buffer up to, per stream-level flow control (MAX_STREAM_DATA)
    pub(super) max_data: u64,
    /// Position in the send-stream state machine
    pub(super) state: SendState,
    /// Buffered outgoing data and its acknowledgement bookkeeping (see `SendBuffer`)
    pub(super) pending: SendBuffer,
    /// Relative send priority; not interpreted within this module — presumably
    /// consulted by the stream scheduler elsewhere (TODO confirm)
    pub(super) priority: i32,
    /// Whether a frame containing a FIN bit must be transmitted, even if we don't have any new data
    pub(super) fin_pending: bool,
    /// Whether this stream is in the `connection_blocked` list of `Streams`
    pub(super) connection_blocked: bool,
    /// The reason the peer wants us to stop, if `STOP_SENDING` was received
    pub(super) stop_reason: Option<VarInt>,
}
impl Send {
pub(super) fn new(max_data: VarInt) -> Box<Self> {
Box::new(Self {
max_data: max_data.into(),
state: SendState::Ready,
pending: SendBuffer::new(),
priority: 0,
fin_pending: false,
connection_blocked: false,
stop_reason: None,
})
}
/// Whether the stream has been reset
pub(super) fn is_reset(&self) -> bool {
matches!(self.state, SendState::ResetSent)
}
pub(super) fn finish(&mut self) -> Result<(), FinishError> {
if let Some(error_code) = self.stop_reason {
Err(FinishError::Stopped(error_code))
} else if self.state == SendState::Ready {
self.state = SendState::DataSent {
finish_acked: false,
};
self.fin_pending = true;
Ok(())
} else {
Err(FinishError::ClosedStream)
}
}
pub(super) fn write<S: BytesSource>(
&mut self,
source: &mut S,
limit: u64,
) -> Result<Written, WriteError> {
if !self.is_writable() {
return Err(WriteError::ClosedStream);
}
if let Some(error_code) = self.stop_reason {
return Err(WriteError::Stopped(error_code));
}
let budget = self.max_data - self.pending.offset();
if budget == 0 {
return Err(WriteError::Blocked);
}
let mut limit = limit.min(budget) as usize;
let mut result = Written::default();
loop {
let (chunk, chunks_consumed) = source.pop_chunk(limit);
result.chunks += chunks_consumed;
result.bytes += chunk.len();
if chunk.is_empty() {
break;
}
limit -= chunk.len();
self.pending.write(chunk);
}
Ok(result)
}
/// Update stream state due to a reset sent by the local application
pub(super) fn reset(&mut self) {
use SendState::*;
if let DataSent { .. } | Ready = self.state {
self.state = ResetSent;
}
}
/// Handle STOP_SENDING
///
/// Returns true if the stream was stopped due to this frame, and false
/// if it had been stopped before
pub(super) fn try_stop(&mut self, error_code: VarInt) -> bool {
if self.stop_reason.is_none() {
self.stop_reason = Some(error_code);
true
} else {
false
}
}
/// Returns whether the stream has been finished and all data has been acknowledged by the peer
pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool {
self.pending.ack(frame.offsets);
match self.state {
SendState::DataSent {
ref mut finish_acked,
} => {
*finish_acked |= frame.fin;
*finish_acked && self.pending.is_fully_acked()
}
_ => false,
}
}
/// Handle increase to stream-level flow control limit
///
/// Returns whether the stream was unblocked
pub(super) fn increase_max_data(&mut self, offset: u64) -> bool {
if offset <= self.max_data || self.state != SendState::Ready {
return false;
}
let was_blocked = self.pending.offset() == self.max_data;
self.max_data = offset;
was_blocked
}
pub(super) fn offset(&self) -> u64 {
self.pending.offset()
}
pub(super) fn is_pending(&self) -> bool {
self.pending.has_unsent_data() || self.fin_pending
}
pub(super) fn is_writable(&self) -> bool {
matches!(self.state, SendState::Ready)
}
}
/// A [`BytesSource`] implementation for `&'a mut [Bytes]`
///
/// The type allows to dequeue [`Bytes`] chunks from an array of chunks, up to
/// a configured limit.
pub(crate) struct BytesArray<'a> {
    /// The wrapped slice of `Bytes`
    chunks: &'a mut [Bytes],
    /// The amount of chunks consumed from this source
    consumed: usize,
}

impl<'a> BytesArray<'a> {
    /// Wraps `chunks` without copying; fully consumed chunks are later replaced
    /// in place by empty `Bytes` (see `pop_chunk`)
    pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
        Self {
            chunks,
            consumed: 0,
        }
    }
}
impl BytesSource for BytesArray<'_> {
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        // Zero-length chunks are skipped, but still counted as consumed
        let mut chunks_consumed = 0;
        while let Some(chunk) = self.chunks.get_mut(self.consumed) {
            if chunk.len() <= limit {
                // The whole chunk fits within the limit: take it out (leaving an
                // empty placeholder behind) and advance past it
                self.consumed += 1;
                chunks_consumed += 1;
                let chunk = std::mem::take(chunk);
                if !chunk.is_empty() {
                    return (chunk, chunks_consumed);
                }
                // Empty chunk: keep scanning
            } else if limit > 0 {
                // Only a prefix fits; the chunk is truncated, not consumed
                return (chunk.split_to(limit), chunks_consumed);
            } else {
                break;
            }
        }
        (Bytes::new(), chunks_consumed)
    }
}
/// A [`BytesSource`] implementation for `&[u8]`
///
/// The type allows to dequeue a single [`Bytes`] chunk, which will be lazily
/// created from a reference. This allows to defer the allocation until it is
/// known how much data needs to be copied.
pub(crate) struct ByteSlice<'a> {
    /// The wrapped byte slice; shrinks from the front as chunks are popped
    data: &'a [u8],
}

impl<'a> ByteSlice<'a> {
    /// Wraps `data` without copying; no allocation happens until `pop_chunk`
    pub(crate) fn from_slice(data: &'a [u8]) -> Self {
        Self { data }
    }
}
impl BytesSource for ByteSlice<'_> {
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        let take = limit.min(self.data.len());
        if take == 0 {
            return (Bytes::new(), 0);
        }
        // Copy out the prefix and keep the remainder for later calls
        let (head, tail) = self.data.split_at(take);
        let chunk = Bytes::from(head.to_vec());
        self.data = tail;
        // The single underlying chunk counts as consumed only once fully drained
        (chunk, usize::from(tail.is_empty()))
    }
}
/// A source of one or more buffers which can be converted into `Bytes` buffers on demand
///
/// The purpose of this data type is to defer conversion as long as possible,
/// so that no heap allocation is required in case no data is writable.
pub(super) trait BytesSource {
    /// Returns the next chunk from the source of owned chunks.
    ///
    /// This method will consume parts of the source.
    /// Calling it will yield `Bytes` elements up to the configured `limit`.
    ///
    /// The method returns a tuple:
    /// - The first item is the yielded `Bytes` element. The element will be
    ///   empty if the limit is zero or no more data is available.
    /// - The second item returns how many complete chunks inside the source
    ///   had been consumed. This can be less than 1 (i.e. zero), if a chunk
    ///   inside the source had been truncated in order to adhere to the limit.
    ///   It can also be more than 1, if zero-length chunks had been skipped.
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
}
/// Indicates how many bytes and chunks had been transferred in a write operation
///
/// Returned by `Send::write`.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Written {
    /// The amount of bytes which had been written
    pub bytes: usize,
    /// The amount of full chunks which had been written
    ///
    /// If a chunk was only partially written, it will not be counted by this field.
    pub chunks: usize,
}
/// Errors triggered while writing to a send stream
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum WriteError {
    /// The peer is not able to accept additional data, or the connection is congested.
    ///
    /// If the peer issues additional flow control credit, a [`StreamEvent::Writable`] event will
    /// be generated, indicating that retrying the write might succeed.
    ///
    /// [`StreamEvent::Writable`]: crate::StreamEvent::Writable
    #[error("unable to accept further writes")]
    Blocked,
    /// The peer is no longer accepting data on this stream, and it has been implicitly reset. The
    /// stream cannot be finished or further written to.
    ///
    /// Carries an application-defined error code. Retrying the write will not succeed.
    #[error("stopped by peer: code {0}")]
    Stopped(VarInt),
    /// The stream has not been opened or has already been finished or reset
    #[error("closed stream")]
    ClosedStream,
}
/// State machine for the send side of a stream
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum SendState {
    /// Sending new data
    Ready,
    /// Stream was finished; now sending retransmits only
    DataSent {
        /// Whether the frame carrying our FIN bit has been acknowledged
        finish_acked: bool,
    },
    /// Sent RESET
    ResetSent,
}
/// Reasons why attempting to finish a stream might fail
///
/// Returned by `Send::finish`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum FinishError {
    /// The peer is no longer accepting data on this stream. No
    /// [`StreamEvent::Finished`] event will be emitted for this stream.
    ///
    /// Carries an application-defined error code.
    ///
    /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished
    #[error("stopped by peer: code {0}")]
    Stopped(VarInt),
    /// The stream has not been opened or was already finished or reset
    #[error("closed stream")]
    ClosedStream,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises `BytesArray::pop_chunk` for every limit from zero up to AND
    // INCLUDING the full input length. The inclusive upper bound matters: the
    // `limit == full.len()` branch below (full consumption of the last chunk)
    // was dead code with an exclusive range, since `limit` never reached
    // `full.len()`.
    #[test]
    fn bytes_array() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        for limit in 0..=full.len() {
            let mut chunks = [
                Bytes::from_static(b""),
                Bytes::from_static(b"Hello "),
                Bytes::from_static(b"Wo"),
                Bytes::from_static(b""),
                Bytes::from_static(b"r"),
                Bytes::from_static(b"ld"),
                Bytes::from_static(b""),
                Bytes::from_static(b" 12345678"),
                Bytes::from_static(b"9 ABCDE"),
                Bytes::from_static(b"F"),
                Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
            ];
            let num_chunks = chunks.len();
            let last_chunk_len = chunks[chunks.len() - 1].len();
            let mut array = BytesArray::from_chunks(&mut chunks);
            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;
                if !chunk.is_empty() {
                    buf.extend_from_slice(&chunk);
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }
            assert_eq!(&buf[..], &full[..limit]);
            if limit == full.len() {
                // Full consumption of the last chunk
                assert_eq!(chunks_consumed, num_chunks);
                // Since there are empty chunks, we consume more than there are popped
                assert_eq!(chunks_consumed, chunks_popped + 3);
            } else if limit > full.len() - last_chunk_len {
                // Partial consumption of the last chunk
                assert_eq!(chunks_consumed, num_chunks - 1);
                assert_eq!(chunks_consumed, chunks_popped + 2);
            }
        }
    }

    // `ByteSlice` yields exactly one chunk per call and only counts its single
    // underlying chunk as consumed once fully drained. As above, the range is
    // inclusive so the `limit == full.len()` branch is actually reached.
    #[test]
    fn byte_slice() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        for limit in 0..=full.len() {
            let mut array = ByteSlice::from_slice(&full[..]);
            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;
                if !chunk.is_empty() {
                    buf.extend_from_slice(&chunk);
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }
            assert_eq!(&buf[..], &full[..limit]);
            if limit != 0 {
                assert_eq!(chunks_popped, 1);
            } else {
                assert_eq!(chunks_popped, 0);
            }
            if limit == full.len() {
                assert_eq!(chunks_consumed, 1);
            } else {
                assert_eq!(chunks_consumed, 0);
            }
        }
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
use crate::Instant;
/// The different kinds of connection timers
///
/// Discriminant values are used directly as indices into `TimerTable`'s slots.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub(crate) enum Timer {
    /// When to send an ack-eliciting probe packet or declare unacked packets lost
    LossDetection = 0,
    /// When to close the connection after no activity
    Idle = 1,
    /// When the close timer expires, the connection has been gracefully terminated.
    Close = 2,
    /// When keys are discarded because they should not be needed anymore
    KeyDiscard = 3,
    /// When to give up on validating a new path to the peer
    PathValidation = 4,
    /// When to send a `PING` frame to keep the connection alive
    KeepAlive = 5,
    /// When pacing will allow us to send a packet
    Pacing = 6,
    /// When to invalidate old CID and proactively push new one via NEW_CONNECTION_ID frame
    PushNewCid = 7,
    /// When to send an immediate ACK if there are unacked ack-eliciting packets of the peer
    MaxAckDelay = 8,
}
impl Timer {
    /// All timer kinds, in discriminant order; must stay in sync with the
    /// enum's variants
    pub(crate) const VALUES: [Self; 9] = [
        Self::LossDetection,
        Self::Idle,
        Self::Close,
        Self::KeyDiscard,
        Self::PathValidation,
        Self::KeepAlive,
        Self::Pacing,
        Self::PushNewCid,
        Self::MaxAckDelay,
    ];
}
/// A table of data associated with each distinct kind of `Timer`
///
/// Stores at most one pending deadline per timer kind; `None` means the timer
/// is not armed.
#[derive(Debug, Copy, Clone, Default)]
pub(crate) struct TimerTable {
    // One slot per `Timer` variant, indexed by discriminant. Sized to match
    // `Timer::VALUES` exactly (discriminants 0..=8); the previous size of 10
    // carried one permanently unused slot.
    data: [Option<Instant>; 9],
}
impl TimerTable {
    /// Arm `timer` to fire at `time`, replacing any previously set deadline
    pub(super) fn set(&mut self, timer: Timer, time: Instant) {
        self.data[timer as usize] = Some(time);
    }

    /// The deadline currently set for `timer`, if any
    pub(super) fn get(&self, timer: Timer) -> Option<Instant> {
        self.data[timer as usize]
    }

    /// Disarm `timer`
    pub(super) fn stop(&mut self, timer: Timer) {
        self.data[timer as usize] = None;
    }

    /// Earliest deadline across all armed timers, if any is armed
    pub(super) fn next_timeout(&self) -> Option<Instant> {
        self.data.iter().copied().flatten().min()
    }

    /// Whether `timer` is armed with a deadline at or before `after`
    pub(super) fn is_expired(&self, timer: Timer, after: Instant) -> bool {
        matches!(self.data[timer as usize], Some(x) if x <= after)
    }
}

22
vendor/quinn-proto/src/constant_time.rs vendored Normal file
View File

@@ -0,0 +1,22 @@
// Returns zero iff `a` and `b` are equal, taking time independent of where (or
// whether) they differ. Both inputs must have the same length.
//
// This function is non-inline to prevent the optimizer from looking inside it.
#[inline(never)]
fn constant_time_ne(a: &[u8], b: &[u8]) -> u8 {
    assert!(a.len() == b.len());
    // These useless slices make the optimizer elide the bounds checks.
    // See the comment in clone_from_slice() added on Rust commit 6a7bc47.
    let len = a.len();
    let a = &a[..len];
    let b = &b[..len];
    // Accumulate the XOR of every byte pair; visiting every byte
    // unconditionally is what keeps the running time data-independent.
    let mut tmp = 0;
    for i in 0..len {
        tmp |= a[i] ^ b[i];
    }
    tmp // The compare with 0 must happen outside this function.
}
/// Compares byte strings in constant time.
///
/// Only the comparison of equal-length inputs is constant-time; a length
/// mismatch returns early, as in the original.
pub(crate) fn eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    constant_time_ne(a, b) == 0
}

223
vendor/quinn-proto/src/crypto.rs vendored Normal file
View File

@@ -0,0 +1,223 @@
//! Traits and implementations for the QUIC cryptography protocol
//!
//! The protocol logic in Quinn is contained in types that abstract over the actual
//! cryptographic protocol used. This module contains the traits used for this
//! abstraction layer as well as a single implementation of these traits that uses
//! *ring* and rustls to implement the TLS protocol support.
//!
//! Note that usage of any protocol (version) other than TLS 1.3 does not conform to any
//! published versions of the specification, and will not be supported in QUIC v1.
use std::{any::Any, str, sync::Arc};
use bytes::BytesMut;
use crate::{
ConnectError, Side, TransportError, shared::ConnectionId,
transport_parameters::TransportParameters,
};
/// Cryptography interface based on *ring*
#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
pub(crate) mod ring_like;
/// TLS interface based on rustls
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
pub mod rustls;
/// A cryptographic session (commonly TLS)
pub trait Session: Send + Sync + 'static {
    /// Create the initial set of keys given the client's initial destination ConnectionId
    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys;
    /// Get data negotiated during the handshake, if available
    ///
    /// Returns `None` until the connection emits `HandshakeDataReady`.
    fn handshake_data(&self) -> Option<Box<dyn Any>>;
    /// Get the peer's identity, if available
    ///
    /// The concrete type behind `Any` is implementation-defined; see the
    /// implementing type's documentation.
    fn peer_identity(&self) -> Option<Box<dyn Any>>;
    /// Get the 0-RTT keys if available (clients only)
    ///
    /// On the client side, this method can be used to see if 0-RTT key material is available
    /// to start sending data before the protocol handshake has completed.
    ///
    /// Returns `None` if the key material is not available. This might happen if you have
    /// not connected to this server before.
    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn PacketKey>)>;
    /// If the 0-RTT-encrypted data has been accepted by the peer
    fn early_data_accepted(&self) -> Option<bool>;
    /// Returns `true` until the connection is fully established.
    fn is_handshaking(&self) -> bool;
    /// Read bytes of handshake data
    ///
    /// This should be called with the contents of `CRYPTO` frames. If it returns `Ok`, the
    /// caller should call `write_handshake()` to check if the crypto protocol has anything
    /// to send to the peer. This method will only return `true` the first time that
    /// handshake data is available. Future calls will always return false.
    ///
    /// On success, returns `true` iff `self.handshake_data()` has been populated.
    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError>;
    /// The peer's QUIC transport parameters
    ///
    /// These are only available after the first flight from the peer has been received.
    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError>;
    /// Writes handshake bytes into the given buffer and optionally returns the negotiated keys
    ///
    /// When the handshake proceeds to the next phase, this method will return a new set of
    /// keys to encrypt data with.
    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys>;
    /// Compute keys for the next key update
    ///
    /// Returns `None` if 1-RTT keys are not yet available.
    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn PacketKey>>>;
    /// Verify the integrity of a retry packet
    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool;
    /// Fill `output` with `output.len()` bytes of keying material derived
    /// from the [Session]'s secrets, using `label` and `context` for domain
    /// separation.
    ///
    /// This function will fail, returning [ExportKeyingMaterialError],
    /// if the requested output length is too large.
    fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: &[u8],
    ) -> Result<(), ExportKeyingMaterialError>;
}
/// A pair of keys for bidirectional communication
///
/// `local` protects data we send; `remote` unprotects data we receive.
pub struct KeyPair<T> {
    /// Key for encrypting data
    pub local: T,
    /// Key for decrypting data
    pub remote: T,
}

/// A complete set of keys for a certain packet space
pub struct Keys {
    /// Header protection keys
    pub header: KeyPair<Box<dyn HeaderKey>>,
    /// Packet protection keys
    pub packet: KeyPair<Box<dyn PacketKey>>,
}
/// Client-side configuration for the crypto protocol
pub trait ClientConfig: Send + Sync {
    /// Start a client session with this configuration
    ///
    /// `params` are the local transport parameters to offer to the server.
    fn start_session(
        self: Arc<Self>,
        version: u32,
        server_name: &str,
        params: &TransportParameters,
    ) -> Result<Box<dyn Session>, ConnectError>;
}

/// Server-side configuration for the crypto protocol
pub trait ServerConfig: Send + Sync {
    /// Create the initial set of keys given the client's initial destination ConnectionId
    fn initial_keys(
        &self,
        version: u32,
        dst_cid: &ConnectionId,
    ) -> Result<Keys, UnsupportedVersion>;
    /// Generate the 16-byte integrity tag for a retry packet
    ///
    /// Never called if `initial_keys` rejected `version`.
    fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16];
    /// Start a server session with this configuration
    ///
    /// Never called if `initial_keys` rejected `version`.
    fn start_session(
        self: Arc<Self>,
        version: u32,
        params: &TransportParameters,
    ) -> Box<dyn Session>;
}
/// Keys used to protect packet payloads
pub trait PacketKey: Send + Sync {
    /// Encrypt the packet payload with the given packet number
    fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize);
    /// Decrypt the packet payload with the given packet number
    fn decrypt(
        &self,
        packet: u64,
        header: &[u8],
        payload: &mut BytesMut,
    ) -> Result<(), CryptoError>;
    /// The length of the AEAD tag appended to packets on encryption
    fn tag_len(&self) -> usize;
    /// Maximum number of packets that may be sent using a single key
    fn confidentiality_limit(&self) -> u64;
    /// Maximum number of incoming packets that may fail decryption before the connection must be
    /// abandoned
    fn integrity_limit(&self) -> u64;
}

/// Keys used to protect packet headers
pub trait HeaderKey: Send + Sync {
    /// Decrypt the given packet's header
    ///
    /// `pn_offset` is the offset of the packet number field within `packet`.
    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]);
    /// Encrypt the given packet's header
    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]);
    /// The sample size used for this key's algorithm
    fn sample_size(&self) -> usize;
}

/// A key for signing with HMAC-based algorithms
pub trait HmacKey: Send + Sync {
    /// Method for signing a message
    ///
    /// `signature_out` must hold at least `signature_len()` bytes — NOTE(review):
    /// confirm against implementations; the ring-based one panics on mismatch.
    fn sign(&self, data: &[u8], signature_out: &mut [u8]);
    /// Length of `sign`'s output
    fn signature_len(&self) -> usize;
    /// Method for verifying a message
    fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError>;
}
/// Error returned by [Session::export_keying_material].
///
/// This error occurs if the requested output length is too large.
#[derive(Debug, PartialEq, Eq)]
pub struct ExportKeyingMaterialError;

/// A pseudo random key for HKDF
pub trait HandshakeTokenKey: Send + Sync {
    /// Derive AEAD using hkdf
    fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box<dyn AeadKey>;
}

/// A key for sealing data with AEAD-based algorithms
pub trait AeadKey {
    /// Method for sealing message `data`
    ///
    /// On success, `data` is encrypted in place with the authentication tag appended.
    fn seal(&self, data: &mut Vec<u8>, additional_data: &[u8]) -> Result<(), CryptoError>;
    /// Method for opening a sealed message `data`
    ///
    /// Returns the plaintext portion of `data` on success.
    fn open<'a>(
        &self,
        data: &'a mut [u8],
        additional_data: &[u8],
    ) -> Result<&'a mut [u8], CryptoError>;
}

/// Generic crypto errors
#[derive(Debug)]
pub struct CryptoError;

/// Error indicating that the specified QUIC version is not supported
#[derive(Debug)]
pub struct UnsupportedVersion;

// Allows `?` on initial key derivation when the handshake rejects a version
impl From<UnsupportedVersion> for ConnectError {
    fn from(_: UnsupportedVersion) -> Self {
        Self::UnsupportedVersion
    }
}

View File

@@ -0,0 +1,57 @@
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::{aead, error, hkdf, hmac};
#[cfg(feature = "ring")]
use ring::{aead, error, hkdf, hmac};
use crate::crypto::{self, CryptoError};
impl crypto::HmacKey for hmac::Key {
    fn sign(&self, data: &[u8], out: &mut [u8]) {
        // Panics if `out`'s length doesn't match the tag length; callers must
        // size it via `signature_len()`
        out.copy_from_slice(hmac::sign(self, data).as_ref());
    }

    fn signature_len(&self) -> usize {
        // HMAC-SHA-256 tag length — assumes keys are always created for
        // SHA-256; TODO confirm at the key construction sites
        32
    }

    fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError> {
        Ok(hmac::verify(self, data, signature)?)
    }
}

impl crypto::HandshakeTokenKey for hkdf::Prk {
    fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box<dyn crypto::AeadKey> {
        // Expand this PRK with `random_bytes` as HKDF info into a fresh
        // AES-256-GCM key
        let mut key_buffer = [0u8; 32];
        let info = [random_bytes];
        let okm = self.expand(&info, hkdf::HKDF_SHA256).unwrap();
        okm.fill(&mut key_buffer).unwrap();
        let key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_buffer).unwrap();
        Box::new(aead::LessSafeKey::new(key))
    }
}

impl crypto::AeadKey for aead::LessSafeKey {
    fn seal(&self, data: &mut Vec<u8>, additional_data: &[u8]) -> Result<(), CryptoError> {
        let aad = aead::Aad::from(additional_data);
        // NOTE(review): a fixed all-zero nonce is only sound if each key seals
        // at most one message; `aead_from_hkdf` above derives a fresh key per
        // call, but confirm no key is reused at the call sites
        let zero_nonce = aead::Nonce::assume_unique_for_key([0u8; 12]);
        Ok(self.seal_in_place_append_tag(zero_nonce, aad, data)?)
    }

    fn open<'a>(
        &self,
        data: &'a mut [u8],
        additional_data: &[u8],
    ) -> Result<&'a mut [u8], CryptoError> {
        let aad = aead::Aad::from(additional_data);
        let zero_nonce = aead::Nonce::assume_unique_for_key([0u8; 12]);
        Ok(self.open_in_place(zero_nonce, aad, data)?)
    }
}

// Collapse ring/aws-lc-rs's detail-free error into ours
impl From<error::Unspecified> for CryptoError {
    fn from(_: error::Unspecified) -> Self {
        Self
    }
}

656
vendor/quinn-proto/src/crypto/rustls.rs vendored Normal file
View File

@@ -0,0 +1,656 @@
use std::{any::Any, io, str, sync::Arc};
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::aead;
use bytes::BytesMut;
#[cfg(feature = "ring")]
use ring::aead;
pub use rustls::Error;
#[cfg(feature = "__rustls-post-quantum-test")]
use rustls::NamedGroup;
use rustls::{
self, CipherSuite,
client::danger::ServerCertVerifier,
pki_types::{CertificateDer, PrivateKeyDer, ServerName},
quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
};
#[cfg(feature = "platform-verifier")]
use rustls_platform_verifier::BuilderVerifierExt;
use crate::{
ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
crypto::{
self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
},
transport_parameters::TransportParameters,
};
// Map Quinn's connection side onto rustls' equivalent
impl From<Side> for rustls::Side {
    fn from(s: Side) -> Self {
        match s {
            Side::Client => Self::Client,
            Side::Server => Self::Server,
        }
    }
}
/// A rustls TLS session
pub struct TlsSession {
    /// QUIC version in use; consulted for initial keys and retry validation
    version: Version,
    /// Whether `handshake_data()` has become available (see `read_handshake`)
    got_handshake_data: bool,
    /// Secrets for deriving the next 1-RTT key update; set at the 1-RTT key change
    next_secrets: Option<Secrets>,
    /// The underlying rustls client or server connection
    inner: Connection,
    /// Cipher suite used when deriving initial keys
    suite: Suite,
}

impl TlsSession {
    /// Which side of the connection we are, derived from the connection variant
    fn side(&self) -> Side {
        match self.inner {
            Connection::Client(_) => Side::Client,
            Connection::Server(_) => Side::Server,
        }
    }
}
impl crypto::Session for TlsSession {
    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
        initial_keys(self.version, *dst_cid, side, &self.suite)
    }

    fn handshake_data(&self) -> Option<Box<dyn Any>> {
        if !self.got_handshake_data {
            return None;
        }
        Some(Box::new(HandshakeData {
            protocol: self.inner.alpn_protocol().map(|x| x.into()),
            // Only servers learn an SNI server name
            server_name: match self.inner {
                Connection::Client(_) => None,
                Connection::Server(ref session) => session.server_name().map(|x| x.into()),
            },
            #[cfg(feature = "__rustls-post-quantum-test")]
            negotiated_key_exchange_group: self
                .inner
                .negotiated_key_exchange_group()
                .expect("key exchange group is negotiated")
                .name(),
        }))
    }

    /// For the rustls `TlsSession`, the `Any` type is `Vec<rustls::pki_types::CertificateDer>`
    fn peer_identity(&self) -> Option<Box<dyn Any>> {
        self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
            Box::new(
                v.iter()
                    .map(|v| v.clone().into_owned())
                    .collect::<Vec<CertificateDer<'static>>>(),
            )
        })
    }

    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
        let keys = self.inner.zero_rtt_keys()?;
        Some((Box::new(keys.header), Box::new(keys.packet)))
    }

    fn early_data_accepted(&self) -> Option<bool> {
        // Only meaningful on the client side
        match self.inner {
            Connection::Client(ref session) => Some(session.is_early_data_accepted()),
            _ => None,
        }
    }

    fn is_handshaking(&self) -> bool {
        self.inner.is_handshaking()
    }

    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
        self.inner.read_hs(buf).map_err(|e| {
            // Surface a TLS alert as a CRYPTO_ERROR with the alert's code,
            // falling back to PROTOCOL_VIOLATION otherwise
            if let Some(alert) = self.inner.alert() {
                TransportError {
                    code: TransportErrorCode::crypto(alert.into()),
                    frame: None,
                    reason: e.to_string(),
                }
            } else {
                TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
            }
        })?;
        if !self.got_handshake_data {
            // Hack around the lack of an explicit signal from rustls to reflect ClientHello being
            // ready on incoming connections, or ALPN negotiation completing on outgoing
            // connections.
            let have_server_name = match self.inner {
                Connection::Client(_) => false,
                Connection::Server(ref session) => session.server_name().is_some(),
            };
            if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
                self.got_handshake_data = true;
                return Ok(true);
            }
        }
        Ok(false)
    }

    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
        match self.inner.quic_transport_parameters() {
            None => Ok(None),
            Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
                Ok(params) => Ok(Some(params)),
                Err(e) => Err(e.into()),
            },
        }
    }

    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
        let keys = match self.inner.write_hs(buf)? {
            KeyChange::Handshake { keys } => keys,
            // At the 1-RTT transition, stash the secrets for later key updates
            KeyChange::OneRtt { keys, next } => {
                self.next_secrets = Some(next);
                keys
            }
        };
        Some(Keys {
            header: KeyPair {
                local: Box::new(keys.local.header),
                remote: Box::new(keys.remote.header),
            },
            packet: KeyPair {
                local: Box::new(keys.local.packet),
                remote: Box::new(keys.remote.packet),
            },
        })
    }

    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
        // `None` until `write_handshake` has observed the 1-RTT key change
        let secrets = self.next_secrets.as_mut()?;
        let keys = secrets.next_packet_keys();
        Some(KeyPair {
            local: Box::new(keys.local),
            remote: Box::new(keys.remote),
        })
    }

    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
        // The AEAD tag occupies the final 16 bytes of the payload
        let tag_start = match payload.len().checked_sub(16) {
            Some(x) => x,
            None => return false,
        };
        // Build the retry pseudo-packet: ODCID length, ODCID, then the packet
        // itself; the tag is verified over everything preceding it
        let mut pseudo_packet =
            Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
        pseudo_packet.push(orig_dst_cid.len() as u8);
        pseudo_packet.extend_from_slice(orig_dst_cid);
        pseudo_packet.extend_from_slice(header);
        let tag_start = tag_start + pseudo_packet.len();
        pseudo_packet.extend_from_slice(payload);
        // Fixed key/nonce per QUIC version (see constants below)
        let (nonce, key) = match self.version {
            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
            _ => unreachable!(),
        };
        let nonce = aead::Nonce::assume_unique_for_key(nonce);
        let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
        let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
        key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
    }

    fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: &[u8],
    ) -> Result<(), ExportKeyingMaterialError> {
        self.inner
            .export_keying_material(output, label, Some(context))
            .map_err(|_| ExportKeyingMaterialError)?;
        Ok(())
    }
}
// Retry integrity key for QUIC draft versions — presumably from
// draft-ietf-quic-tls-29 §5.8; TODO confirm against the draft text
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
// Retry integrity nonce for QUIC draft versions (same provenance as above)
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];
// Retry integrity key for QUIC v1, as specified in RFC 9001 §5.8
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
// Retry integrity nonce for QUIC v1, as specified in RFC 9001 §5.8
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        // The sample starts right after the longest possible (4-byte) packet
        // number. `rest` begins at packet index 1, so `pn_offset - 1` below
        // addresses packet[pn_offset], and `pn_end` caps the protected range
        // at up to 4 packet-number bytes without overrunning `rest`.
        let (header, sample) = packet.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let pn_end = Ord::min(pn_offset + 3, rest.len());
        self.decrypt_in_place(
            &sample[..self.sample_size()],
            &mut first[0],
            &mut rest[pn_offset - 1..pn_end],
        )
        .unwrap();
    }

    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        // Mirrors `decrypt` above: same split, applying protection instead
        let (header, sample) = packet.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let pn_end = Ord::min(pn_offset + 3, rest.len());
        self.encrypt_in_place(
            &sample[..self.sample_size()],
            &mut first[0],
            &mut rest[pn_offset - 1..pn_end],
        )
        .unwrap();
    }

    fn sample_size(&self) -> usize {
        self.sample_len()
    }
}
/// Authentication data for (rustls) TLS session
///
/// Returned (boxed as `Any`) by `TlsSession::handshake_data`.
pub struct HandshakeData {
    /// The negotiated application protocol, if ALPN is in use
    ///
    /// Guaranteed to be set if a nonempty list of protocols was specified for this connection.
    pub protocol: Option<Vec<u8>>,
    /// The server name specified by the client, if any
    ///
    /// Always `None` for outgoing connections
    pub server_name: Option<String>,
    /// The key exchange group negotiated with the peer
    #[cfg(feature = "__rustls-post-quantum-test")]
    pub negotiated_key_exchange_group: NamedGroup,
}
/// A QUIC-compatible TLS client configuration
///
/// Quinn implicitly constructs a `QuicClientConfig` with reasonable defaults within
/// [`ClientConfig::with_root_certificates()`][root_certs] and [`ClientConfig::with_platform_verifier()`][platform].
/// Alternatively, `QuicClientConfig`'s [`TryFrom`] implementation can be used to wrap around a
/// custom [`rustls::ClientConfig`], in which case care should be taken around certain points:
///
/// - If `enable_early_data` is not set to true, then sending 0-RTT data will not be possible on
///   outgoing connections.
/// - The [`rustls::ClientConfig`] must have TLS 1.3 support enabled for conversion to succeed.
///
/// The object in the `resumption` field of the inner [`rustls::ClientConfig`] determines whether
/// calling `into_0rtt` on outgoing connections returns `Ok` or `Err`. It typically allows
/// `into_0rtt` to proceed if it recognizes the server name, and defaults to an in-memory cache of
/// 256 server names.
///
/// [root_certs]: crate::config::ClientConfig::with_root_certificates()
/// [platform]: crate::config::ClientConfig::with_platform_verifier()
pub struct QuicClientConfig {
    /// The wrapped rustls configuration
    pub(crate) inner: Arc<rustls::ClientConfig>,
    /// Cipher suite used to derive initial keys (`TLS13_AES_128_GCM_SHA256`,
    /// enforced by `with_initial`)
    initial: Suite,
}
impl QuicClientConfig {
    /// Initialize using the platform's native certificate verifier
    #[cfg(feature = "platform-verifier")]
    pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
        // Keep in sync with `inner()` below
        let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
            .with_protocol_versions(&[&rustls::version::TLS13])
            .unwrap() // The default providers support TLS 1.3
            .with_platform_verifier()?
            .with_no_client_auth();
        inner.enable_early_data = true;
        Ok(Self {
            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
            initial: initial_suite_from_provider(inner.crypto_provider())
                .expect("no initial cipher suite found"),
            inner: Arc::new(inner),
        })
    }

    /// Initialize a sane QUIC-compatible TLS client configuration
    ///
    /// QUIC requires that TLS 1.3 be enabled. Advanced users can use any [`rustls::ClientConfig`] that
    /// satisfies this requirement.
    pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
        let inner = Self::inner(verifier);
        Self {
            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
            initial: initial_suite_from_provider(inner.crypto_provider())
                .expect("no initial cipher suite found"),
            inner: Arc::new(inner),
        }
    }

    /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite
    ///
    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
    ///
    /// Fails unless `initial` is `TLS13_AES_128_GCM_SHA256`, the suite QUIC
    /// mandates for Initial packets.
    pub fn with_initial(
        inner: Arc<rustls::ClientConfig>,
        initial: Suite,
    ) -> Result<Self, NoInitialCipherSuite> {
        match initial.suite.common.suite {
            CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
            _ => Err(NoInitialCipherSuite { specific: true }),
        }
    }

    /// Build the default rustls client configuration with a custom certificate verifier
    pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
        // Keep in sync with `with_platform_verifier()` above
        let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
            .with_protocol_versions(&[&rustls::version::TLS13])
            .unwrap() // The default providers support TLS 1.3
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();
        config.enable_early_data = true;
        config
    }
}
impl crypto::ClientConfig for QuicClientConfig {
    /// Begin a TLS handshake for an outgoing connection
    fn start_session(
        self: Arc<Self>,
        version: u32,
        server_name: &str,
        params: &TransportParameters,
    ) -> Result<Box<dyn crypto::Session>, ConnectError> {
        let version = interpret_version(version)?;
        // Reject names rustls cannot represent before touching the connection state
        let name = ServerName::try_from(server_name)
            .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
            .to_owned();
        let conn =
            rustls::quic::ClientConnection::new(self.inner.clone(), version, name, to_vec(params))
                .unwrap();
        Ok(Box::new(TlsSession {
            version,
            got_handshake_data: false,
            next_secrets: None,
            inner: rustls::quic::Connection::Client(conn),
            suite: self.initial,
        }))
    }
}
impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
    type Error = NoInitialCipherSuite;
    /// Wrap a bare config by deferring to the `Arc` conversion
    fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
        Self::try_from(Arc::new(inner))
    }
}
impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
    type Error = NoInitialCipherSuite;
    /// Fails if the config's provider lacks the mandatory Initial cipher suite
    fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
        let initial = initial_suite_from_provider(inner.crypto_provider())
            .ok_or(NoInitialCipherSuite { specific: false })?;
        Ok(Self { initial, inner })
    }
}
/// The initial cipher suite (AES-128-GCM-SHA256) is not available
///
/// When the cipher suite is supplied `with_initial()`, it must be
/// [`CipherSuite::TLS13_AES_128_GCM_SHA256`]. When the cipher suite is derived from a config's
/// [`CryptoProvider`][provider], that provider must reference a cipher suite with the same ID.
///
/// [provider]: rustls::crypto::CryptoProvider
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// Whether the initial cipher suite was supplied by the caller
    // Selects between the two `Display` messages below
    specific: bool,
}
impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `specific` distinguishes a caller-supplied bad suite from a provider
        // that simply lacks the mandatory one
        let msg = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(msg)
    }
}
impl std::error::Error for NoInitialCipherSuite {}
/// A QUIC-compatible TLS server configuration
///
/// Quinn implicitly constructs a `QuicServerConfig` with reasonable defaults within
/// [`ServerConfig::with_single_cert()`][single]. Alternatively, `QuicServerConfig`'s [`TryFrom`]
/// implementation or `with_initial` method can be used to wrap around a custom
/// [`rustls::ServerConfig`], in which case care should be taken around certain points:
///
/// - If `max_early_data_size` is not set to `u32::MAX`, the server will not be able to accept
///   incoming 0-RTT data. QUIC prohibits `max_early_data_size` values other than 0 or `u32::MAX`.
/// - The `rustls::ServerConfig` must have TLS 1.3 support enabled for conversion to succeed.
///
/// [single]: crate::config::ServerConfig::with_single_cert()
pub struct QuicServerConfig {
    // The underlying TLS configuration, shared with each connection spawned from it
    inner: Arc<rustls::ServerConfig>,
    // Cipher suite used for Initial packets; constructors guarantee this is
    // TLS13_AES_128_GCM_SHA256 (see `with_initial`)
    initial: Suite,
}
impl QuicServerConfig {
    /// Build a server configuration from a certificate chain and private key
    pub(crate) fn new(
        cert_chain: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Result<Self, rustls::Error> {
        let cfg = Self::inner(cert_chain, key)?;
        // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
        let initial = initial_suite_from_provider(cfg.crypto_provider())
            .expect("no initial cipher suite found");
        Ok(Self {
            initial,
            inner: Arc::new(cfg),
        })
    }
    /// Initialize a QUIC-compatible TLS server configuration with a separate initial cipher suite
    ///
    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
    pub fn with_initial(
        inner: Arc<rustls::ServerConfig>,
        initial: Suite,
    ) -> Result<Self, NoInitialCipherSuite> {
        // QUIC Initial packets are always protected with AES-128-GCM-SHA256
        if initial.suite.common.suite == CipherSuite::TLS13_AES_128_GCM_SHA256 {
            Ok(Self { inner, initial })
        } else {
            Err(NoInitialCipherSuite { specific: true })
        }
    }
    /// Initialize a sane QUIC-compatible TLS server configuration
    ///
    /// QUIC requires that TLS 1.3 be enabled, and that the maximum early data size is either 0 or
    /// `u32::MAX`. Advanced users can use any [`rustls::ServerConfig`] that satisfies these
    /// requirements.
    pub(crate) fn inner(
        cert_chain: Vec<CertificateDer<'static>>,
        key: PrivateKeyDer<'static>,
    ) -> Result<rustls::ServerConfig, rustls::Error> {
        let mut cfg = rustls::ServerConfig::builder_with_provider(configured_provider())
            .with_protocol_versions(&[&rustls::version::TLS13])
            .unwrap() // The *ring* default provider supports TLS 1.3
            .with_no_client_auth()
            .with_single_cert(cert_chain, key)?;
        // Enable acceptance of 0-RTT data; QUIC allows only 0 or `u32::MAX` here
        cfg.max_early_data_size = u32::MAX;
        Ok(cfg)
    }
}
impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
    type Error = NoInitialCipherSuite;
    /// Wrap a bare config by deferring to the `Arc` conversion
    fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
        Self::try_from(Arc::new(inner))
    }
}
impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
    type Error = NoInitialCipherSuite;
    /// Fails if the config's provider lacks the mandatory Initial cipher suite
    fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
        let initial = initial_suite_from_provider(inner.crypto_provider())
            .ok_or(NoInitialCipherSuite { specific: false })?;
        Ok(Self { initial, inner })
    }
}
impl crypto::ServerConfig for QuicServerConfig {
    /// Begin a TLS handshake for an incoming connection
    fn start_session(
        self: Arc<Self>,
        version: u32,
        params: &TransportParameters,
    ) -> Box<dyn crypto::Session> {
        // Safe: `start_session()` is never called if `initial_keys()` rejected `version`
        let version = interpret_version(version).unwrap();
        let conn =
            rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
                .unwrap();
        Box::new(TlsSession {
            version,
            got_handshake_data: false,
            next_secrets: None,
            inner: rustls::quic::Connection::Server(conn),
            suite: self.initial,
        })
    }
    /// Derive Initial packet protection keys for `dst_cid`
    fn initial_keys(
        &self,
        version: u32,
        dst_cid: &ConnectionId,
    ) -> Result<Keys, UnsupportedVersion> {
        let version = interpret_version(version)?;
        Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
    }
    /// Compute the 16-byte integrity tag for a Retry packet
    fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
        // Safe: this is never called if `initial_keys()` rejected `version`
        let version = interpret_version(version).unwrap();
        // Version-specific fixed nonce/key pairs
        let (nonce_bytes, key_bytes) = match version {
            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
            _ => unreachable!(),
        };
        // The tag covers a pseudo-packet: the length-prefixed original
        // destination CID followed by the Retry packet itself
        let mut pseudo_packet = Vec::with_capacity(1 + orig_dst_cid.len() + packet.len());
        pseudo_packet.push(orig_dst_cid.len() as u8);
        pseudo_packet.extend_from_slice(orig_dst_cid);
        pseudo_packet.extend_from_slice(packet);
        let key = aead::LessSafeKey::new(
            aead::UnboundKey::new(&aead::AES_128_GCM, &key_bytes).unwrap(),
        );
        // The pseudo-packet is authenticated as AAD; there is no payload to seal
        let tag = key
            .seal_in_place_separate_tag(
                aead::Nonce::assume_unique_for_key(nonce_bytes),
                aead::Aad::from(pseudo_packet),
                &mut [],
            )
            .unwrap();
        let mut out = [0; 16];
        out.copy_from_slice(tag.as_ref());
        out
    }
}
/// Look up the mandatory Initial cipher suite (TLS13_AES_128_GCM_SHA256) in `provider`
///
/// Returns `None` if the provider does not offer it, or if the matching suite
/// cannot be used for QUIC.
pub(crate) fn initial_suite_from_provider(
    provider: &Arc<rustls::crypto::CryptoProvider>,
) -> Option<Suite> {
    for cs in provider.cipher_suites.iter() {
        if cs.suite() != rustls::CipherSuite::TLS13_AES_128_GCM_SHA256 {
            continue;
        }
        if let Some(suite) = cs.tls13() {
            // First suite with the right ID wins, whether or not it is QUIC-capable
            return suite.quic_suite();
        }
    }
    None
}
/// The default [`CryptoProvider`](rustls::crypto::CryptoProvider) selected by cargo features
///
/// `rustls-ring` takes precedence: aws-lc-rs is only used when `rustls-ring` is disabled.
pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
    #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
    let provider = rustls::crypto::aws_lc_rs::default_provider();
    #[cfg(feature = "rustls-ring")]
    let provider = rustls::crypto::ring::default_provider();
    Arc::new(provider)
}
/// Serialize transport parameters into a freshly allocated byte buffer
fn to_vec(params: &TransportParameters) -> Vec<u8> {
    let mut buf = Vec::new();
    params.write(&mut buf);
    buf
}
/// Derive packet- and header-protection key pairs for the Initial encryption level
///
/// Keys for both directions are derived from `dst_cid` and `version` by the
/// rustls `suite`; `side` determines which direction is local vs. remote.
pub(crate) fn initial_keys(
    version: Version,
    dst_cid: ConnectionId,
    side: Side,
    suite: &Suite,
) -> Keys {
    let derived = suite.keys(&dst_cid, side.into(), version);
    Keys {
        header: KeyPair {
            local: Box::new(derived.local.header),
            remote: Box::new(derived.remote.header),
        },
        packet: KeyPair {
            local: Box::new(derived.local.packet),
            remote: Box::new(derived.remote.packet),
        },
    }
}
impl crypto::PacketKey for Box<dyn PacketKey> {
    /// Encrypt the payload in place; the AEAD tag fills the trailing `tag_len()` bytes
    fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
        let (header, rest) = buf.split_at_mut(header_len);
        // The caller reserves `tag_len()` bytes at the end of `buf` for the tag
        let payload_len = rest.len() - self.tag_len();
        let (payload, tag_storage) = rest.split_at_mut(payload_len);
        let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
        tag_storage.copy_from_slice(tag.as_ref());
    }
    /// Decrypt `payload` in place, truncating it to the plaintext length on success
    fn decrypt(
        &self,
        packet: u64,
        header: &[u8],
        payload: &mut BytesMut,
    ) -> Result<(), CryptoError> {
        let plaintext_len = self
            .decrypt_in_place(packet, header, payload.as_mut())
            .map_err(|_| CryptoError)?
            .len();
        payload.truncate(plaintext_len);
        Ok(())
    }
    fn tag_len(&self) -> usize {
        (**self).tag_len()
    }
    fn confidentiality_limit(&self) -> u64 {
        (**self).confidentiality_limit()
    }
    fn integrity_limit(&self) -> u64 {
        (**self).integrity_limit()
    }
}
/// Map a wire version number onto the internal `Version` enum
///
/// v1 (0x00000001) and drafts 33-34 use the final key derivation; drafts 29-32
/// use the older draft derivation.
fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
    if version == 0x0000_0001 || (0xff00_0021..=0xff00_0022).contains(&version) {
        Ok(Version::V1)
    } else if (0xff00_001d..=0xff00_0020).contains(&version) {
        Ok(Version::V1Draft)
    } else {
        Err(UnsupportedVersion)
    }
}

1331
vendor/quinn-proto/src/endpoint.rs vendored Normal file

File diff suppressed because it is too large Load Diff

1008
vendor/quinn-proto/src/frame.rs vendored Normal file

File diff suppressed because it is too large Load Diff

336
vendor/quinn-proto/src/lib.rs vendored Normal file
View File

@@ -0,0 +1,336 @@
//! Low-level protocol logic for the QUIC protocol
//!
//! quinn-proto contains a fully deterministic implementation of QUIC protocol logic. It contains
//! no networking code and does not get any relevant timestamps from the operating system. Most
//! users may want to use the futures-based quinn API instead.
//!
//! The quinn-proto API might be of interest if you want to use it from a C or C++ project
//! through C bindings or if you want to use a different event loop than the one tokio provides.
//!
//! The most important types are `Endpoint`, which conceptually represents the protocol state for
//! a single socket and mostly manages configuration and dispatches incoming datagrams to the
//! related `Connection`. `Connection` types contain the bulk of the protocol logic related to
//! managing a single connection and all the related state (such as streams).
#![cfg_attr(not(fuzzing), warn(missing_docs))]
#![cfg_attr(test, allow(dead_code))]
// Fixes welcome:
#![warn(unreachable_pub)]
#![allow(clippy::cognitive_complexity)]
#![allow(clippy::too_many_arguments)]
#![warn(clippy::use_self)]
use std::{
fmt,
net::{IpAddr, SocketAddr},
ops,
};
mod cid_queue;
pub mod coding;
mod constant_time;
mod range_set;
#[cfg(all(test, any(feature = "rustls-aws-lc-rs", feature = "rustls-ring")))]
mod tests;
pub mod transport_parameters;
mod varint;
pub use varint::{VarInt, VarIntBoundsExceeded};
#[cfg(feature = "bloom")]
mod bloom_token_log;
#[cfg(feature = "bloom")]
pub use bloom_token_log::BloomTokenLog;
mod connection;
pub use crate::connection::{
Chunk, Chunks, ClosedStream, Connection, ConnectionError, ConnectionStats, Datagrams, Event,
FinishError, FrameStats, PathStats, ReadError, ReadableError, RecvStream, RttEstimator,
SendDatagramError, SendStream, ShouldTransmit, StreamEvent, Streams, UdpStats, WriteError,
Written,
};
#[cfg(feature = "qlog")]
pub use connection::qlog::QlogStream;
#[cfg(feature = "rustls")]
pub use rustls;
mod config;
#[cfg(feature = "qlog")]
pub use config::QlogConfig;
pub use config::{
AckFrequencyConfig, ClientConfig, ConfigError, EndpointConfig, IdleTimeout, MtuDiscoveryConfig,
ServerConfig, StdSystemTime, TimeSource, TransportConfig, ValidationTokenConfig,
};
pub mod crypto;
mod frame;
use crate::frame::Frame;
pub use crate::frame::{ApplicationClose, ConnectionClose, Datagram, FrameType};
mod endpoint;
pub use crate::endpoint::{
AcceptError, ConnectError, ConnectionHandle, DatagramEvent, Endpoint, Incoming, RetryError,
};
mod packet;
pub use packet::{
ConnectionIdParser, FixedLengthConnectionIdParser, LongType, PacketDecodeError, PartialDecode,
ProtectedHeader, ProtectedInitialHeader,
};
mod shared;
pub use crate::shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent};
mod transport_error;
pub use crate::transport_error::{Code as TransportErrorCode, Error as TransportError};
pub mod congestion;
mod cid_generator;
pub use crate::cid_generator::{
ConnectionIdGenerator, HashedConnectionIdGenerator, InvalidCid, RandomConnectionIdGenerator,
};
mod token;
use token::ResetToken;
pub use token::{NoneTokenLog, NoneTokenStore, TokenLog, TokenReuseError, TokenStore};
mod token_memory_cache;
pub use token_memory_cache::TokenMemoryCache;
#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;
// Deal with time
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
pub(crate) use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
pub(crate) use web_time::{Duration, Instant, SystemTime, UNIX_EPOCH};
#[cfg(fuzzing)]
pub mod fuzzing {
    //! Internal types re-exported for use by the fuzz targets
    pub use crate::connection::{Retransmits, State as ConnectionState, StreamsState};
    pub use crate::frame::ResetStream;
    pub use crate::packet::PartialDecode;
    pub use crate::transport_parameters::TransportParameters;
    pub use bytes::{BufMut, BytesMut};
    #[cfg(feature = "arbitrary")]
    use arbitrary::{Arbitrary, Result, Unstructured};
    #[cfg(feature = "arbitrary")]
    impl<'arbitrary> Arbitrary<'arbitrary> for TransportParameters {
        // Randomize only a handful of fields; all others keep their defaults
        fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result<Self> {
            Ok(Self {
                initial_max_streams_bidi: u.arbitrary()?,
                initial_max_streams_uni: u.arbitrary()?,
                ack_delay_exponent: u.arbitrary()?,
                max_udp_payload_size: u.arbitrary()?,
                ..Self::default()
            })
        }
    }
    /// Inputs for packet-decoding fuzz targets
    #[derive(Debug)]
    pub struct PacketParams {
        // Connection ID length to parse with; generated in 0..=MAX_CID_SIZE
        pub local_cid_len: usize,
        // Raw bytes handed to the decoder
        pub buf: BytesMut,
        // Forwarded to the decoder; presumably toggles QUIC-bit greasing -- verify against PartialDecode
        pub grease_quic_bit: bool,
    }
    #[cfg(feature = "arbitrary")]
    impl<'arbitrary> Arbitrary<'arbitrary> for PacketParams {
        fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result<Self> {
            // Keep the CID length within the protocol maximum
            let local_cid_len: usize = u.int_in_range(0..=crate::MAX_CID_SIZE)?;
            let bytes: Vec<u8> = Vec::arbitrary(u)?;
            let mut buf = BytesMut::new();
            buf.put_slice(&bytes[..]);
            Ok(Self {
                local_cid_len,
                buf,
                grease_quic_bit: bool::arbitrary(u)?,
            })
        }
    }
}
/// The QUIC protocol versions implemented.
///
/// QUIC version 1 (0x00000001) plus the pre-RFC draft version numbers
/// 0xff00_001d through 0xff00_0022.
pub const DEFAULT_SUPPORTED_VERSIONS: &[u32] = &[
    0x00000001,
    0xff00_001d,
    0xff00_001e,
    0xff00_001f,
    0xff00_0020,
    0xff00_0021,
    0xff00_0022,
];
/// Whether an endpoint was the initiator of a connection
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Side {
    /// The initiator of a connection
    Client = 0,
    /// The acceptor of a connection
    Server = 1,
}
impl Side {
    /// Shorthand for `self == Side::Client`
    #[inline]
    pub fn is_client(self) -> bool {
        matches!(self, Self::Client)
    }
    /// Shorthand for `self == Side::Server`
    #[inline]
    pub fn is_server(self) -> bool {
        matches!(self, Self::Server)
    }
}
impl ops::Not for Side {
    type Output = Self;
    /// The peer's role: a client's peer is the server, and vice versa
    fn not(self) -> Self {
        if self.is_client() {
            Self::Server
        } else {
            Self::Client
        }
    }
}
/// Whether a stream communicates data in both directions or only from the initiator
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Dir {
    /// Data flows in both directions
    Bi = 0,
    /// Data flows only from the stream's initiator
    Uni = 1,
}
impl Dir {
    /// Iterate over both directionalities, bidirectional first
    fn iter() -> impl Iterator<Item = Self> {
        [Self::Bi, Self::Uni].iter().copied()
    }
}
impl fmt::Display for Dir {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Bi => "bidirectional",
            Self::Uni => "unidirectional",
        };
        // `pad` honors any width/alignment given in the format string
        f.pad(label)
    }
}
/// Identifier for a stream within a particular connection
///
/// Encodes the initiating side in bit 0, the directionality in bit 1, and the
/// stream index in the remaining high bits (see the accessors below).
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(u64);
impl fmt::Display for StreamId {
    /// Renders e.g. "client bidirectional stream 3"
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let initiator = match self.initiator() {
            Side::Client => "client",
            Side::Server => "server",
        };
        let dir = match self.dir() {
            Dir::Bi => "bi",
            Dir::Uni => "uni",
        };
        write!(f, "{initiator} {dir}directional stream {}", self.index())
    }
}
impl StreamId {
    /// Create a new StreamId from its three components
    pub fn new(initiator: Side, dir: Dir, index: u64) -> Self {
        // Bit 0: initiator, bit 1: directionality, remaining bits: index
        Self((index << 2) | ((dir as u64) << 1) | initiator as u64)
    }
    /// Which side of a connection initiated the stream
    pub fn initiator(self) -> Side {
        match self.0 & 0x1 {
            0 => Side::Client,
            _ => Side::Server,
        }
    }
    /// Which directions data flows in
    pub fn dir(self) -> Dir {
        match self.0 & 0x2 {
            0 => Dir::Bi,
            _ => Dir::Uni,
        }
    }
    /// Distinguishes streams of the same initiator and directionality
    pub fn index(self) -> u64 {
        self.0 >> 2
    }
}
impl From<StreamId> for VarInt {
    fn from(x: StreamId) -> Self {
        // Stream IDs originate from `VarInt`s or `StreamId::new`, so the value
        // is assumed to fit in a varint -- NOTE(review): relies on callers
        // keeping the index below 2^60 (cf. MAX_STREAM_COUNT)
        unsafe { Self::from_u64_unchecked(x.0) }
    }
}
impl From<VarInt> for StreamId {
    // Reinterpret a varint directly as a stream ID
    fn from(v: VarInt) -> Self {
        Self(v.0)
    }
}
impl From<StreamId> for u64 {
    // Expose the raw encoded value
    fn from(x: StreamId) -> Self {
        x.0
    }
}
impl coding::Codec for StreamId {
    /// Decode a stream ID from its variable-length integer wire encoding
    fn decode<B: bytes::Buf>(buf: &mut B) -> coding::Result<Self> {
        VarInt::decode(buf).map(|x| Self(x.into_inner()))
    }
    /// Encode a stream ID as a variable-length integer
    fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
        VarInt::from_u64(self.0).unwrap().encode(buf);
    }
}
/// An outgoing packet
///
/// Describes a datagram (or several, when `segment_size` is set) that has been
/// written into a caller-supplied buffer.
#[derive(Debug)]
#[must_use]
pub struct Transmit {
    /// The socket this datagram should be sent to
    pub destination: SocketAddr,
    /// Explicit congestion notification bits to set on the packet
    pub ecn: Option<EcnCodepoint>,
    /// Amount of data written to the caller-supplied buffer
    pub size: usize,
    /// The segment size if this transmission contains multiple datagrams.
    /// This is `None` if the transmit only contains a single datagram
    pub segment_size: Option<usize>,
    /// Optional source IP address for the datagram
    pub src_ip: Option<IpAddr>,
}
//
// Useful internal constants
//
/// The maximum number of CIDs we bother to issue per connection
const LOC_CID_COUNT: u64 = 8;
/// Size in bytes of a stateless reset token
const RESET_TOKEN_SIZE: usize = 16;
/// Largest connection ID length we accept or generate, in bytes
const MAX_CID_SIZE: usize = 20;
/// Minimum datagram size for packets during connection establishment
/// -- NOTE(review): presumably the 1200-byte Initial padding requirement; same value as INITIAL_MTU
const MIN_INITIAL_SIZE: u16 = 1200;
/// <https://www.rfc-editor.org/rfc/rfc9000.html#name-datagram-size>
const INITIAL_MTU: u16 = 1200;
/// Largest possible UDP payload: 65535 bytes minus the 8-byte UDP header
const MAX_UDP_PAYLOAD: u16 = 65527;
/// Granularity below which timer deadlines are treated as equal
const TIMER_GRANULARITY: Duration = Duration::from_millis(1);
/// Maximum number of streams that can be uniquely identified by a stream ID
const MAX_STREAM_COUNT: u64 = 1 << 60;

1014
vendor/quinn-proto/src/packet.rs vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,209 @@
use std::ops::Range;
use tinyvec::TinyVec;
/// A set of u64 values optimized for long runs and random insert/delete/contains
///
/// `ArrayRangeSet` uses an array representation, where each array entry represents
/// a range. Entries are kept sorted, disjoint, and non-empty.
///
/// The array-based RangeSet provides 2 benefits:
/// - There exists an inline representation, which avoids the need of heap
///   allocating ACK ranges for SentFrames for small ranges.
/// - Iterating over ranges should usually be faster since there is only
///   a single cache-friendly contiguous range.
///
/// `ArrayRangeSet` is especially useful for tracking ACK ranges where the amount
/// of ranges is usually very low (since ACK numbers are in consecutive fashion
/// unless reordering or packet loss occur).
#[derive(Debug, Default)]
pub struct ArrayRangeSet(TinyVec<[Range<u64>; ARRAY_RANGE_SET_INLINE_CAPACITY]>);
/// The capacity of elements directly stored in [`ArrayRangeSet`]
///
/// An inline capacity of 2 is chosen to keep `SentFrame` below 128 bytes.
const ARRAY_RANGE_SET_INLINE_CAPACITY: usize = 2;
impl Clone for ArrayRangeSet {
    fn clone(&self) -> Self {
        // tinyvec keeps the heap representation after clones. We rather prefer
        // the inline representation for clones if possible, since clones (e.g.
        // for storage in `SentFrames`) are rarely mutated
        if !self.0.is_inline() && self.0.len() <= ARRAY_RANGE_SET_INLINE_CAPACITY {
            // Heap-backed but small enough to fit inline: rebuild from scratch
            // so the copy starts out in the inline representation
            let mut inline = TinyVec::new();
            inline.extend_from_slice(self.0.as_slice());
            return Self(inline);
        }
        Self(self.0.clone())
    }
}
impl ArrayRangeSet {
    /// Creates an empty set
    pub fn new() -> Self {
        Default::default()
    }
    /// Iterates over the stored ranges in ascending order
    pub fn iter(&self) -> impl DoubleEndedIterator<Item = Range<u64>> + '_ {
        self.0.iter().cloned()
    }
    /// Iterates over every individual value in the set, in ascending order
    pub fn elts(&self) -> impl Iterator<Item = u64> + '_ {
        self.iter().flatten()
    }
    /// Returns the number of stored ranges (not the number of contained values)
    pub fn len(&self) -> usize {
        self.0.len()
    }
    /// Returns whether `x` is contained in any stored range
    pub fn contains(&self, x: u64) -> bool {
        for range in self.0.iter() {
            if range.start > x {
                // We only get here if there was no prior range that contained x
                return false;
            } else if range.contains(&x) {
                return true;
            }
        }
        false
    }
    /// Removes every value in `other` from `self`
    pub fn subtract(&mut self, other: &Self) {
        // TODO: This can potentially be made more efficient, since we know
        // individual ranges are not overlapping, and the next range must start
        // after the last one finished
        for range in &other.0 {
            self.remove(range.clone());
        }
    }
    /// Inserts the single value `x`; returns whether the set changed
    pub fn insert_one(&mut self, x: u64) -> bool {
        self.insert(x..x + 1)
    }
    /// Inserts the range `x`; returns whether any new values were added
    pub fn insert(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        if x.is_empty() {
            // Don't try to deal with ranges where x.end <= x.start
            return false;
        }
        let mut idx = 0;
        while idx != self.0.len() {
            let range = &mut self.0[idx];
            if range.start > x.end {
                // The range is fully before this range and therefore not extensible.
                // Add a new range to the left
                self.0.insert(idx, x);
                return true;
            } else if range.start > x.start {
                // The new range starts before this range but overlaps.
                // Extend the current range to the left
                // Note that we don't have to merge a potential left range, since
                // this case would have been captured by merging the right range
                // in the previous loop iteration
                result = true;
                range.start = x.start;
            }
            // At this point we have handled all parts of the new range which
            // are in front of the current range. Now we handle everything from
            // the start of the current range
            if x.end <= range.end {
                // Fully contained
                return result;
            } else if x.start <= range.end {
                // Extend the current range to the end of the new range.
                // Since it's not contained it must be bigger
                range.end = x.end;
                // Merge all follow-up ranges which overlap
                while idx != self.0.len() - 1 {
                    let curr = self.0[idx].clone();
                    let next = self.0[idx + 1].clone();
                    if curr.end >= next.start {
                        self.0[idx].end = next.end.max(curr.end);
                        self.0.remove(idx + 1);
                    } else {
                        break;
                    }
                }
                return true;
            }
            idx += 1;
        }
        // Insert a range at the end
        self.0.push(x);
        true
    }
    /// Removes the range `x`; returns whether any values were actually removed
    pub fn remove(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        if x.is_empty() {
            // Don't try to deal with ranges where x.end <= x.start
            return false;
        }
        let mut idx = 0;
        while idx != self.0.len() && x.start != x.end {
            let range = self.0[idx].clone();
            if x.end <= range.start {
                // The range is fully before this range
                return result;
            } else if x.start >= range.end {
                // The range is fully after this range
                idx += 1;
                continue;
            }
            // The range overlaps with this range; whatever survives on either
            // side of the removed span is kept
            result = true;
            let left = range.start..x.start;
            let right = x.end..range.end;
            if left.is_empty() && right.is_empty() {
                self.0.remove(idx);
            } else if left.is_empty() {
                self.0[idx] = right;
                idx += 1;
            } else if right.is_empty() {
                self.0[idx] = left;
                idx += 1;
            } else {
                self.0[idx] = right;
                self.0.insert(idx, left);
                idx += 2;
            }
        }
        result
    }
    /// Returns whether the set contains no values
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
    /// Removes and returns the lowest stored range, if any
    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        if !self.0.is_empty() {
            Some(self.0.remove(0))
        } else {
            None
        }
    }
    /// Returns the smallest value in the set, if any
    pub fn min(&self) -> Option<u64> {
        self.iter().next().map(|x| x.start)
    }
    /// Returns the largest value in the set, if any
    pub fn max(&self) -> Option<u64> {
        self.iter().next_back().map(|x| x.end - 1)
    }
}

View File

@@ -0,0 +1,381 @@
use std::{
cmp,
cmp::Ordering,
collections::{BTreeMap, btree_map},
ops::{
Bound::{Excluded, Included},
Range,
},
};
/// A set of u64 values optimized for long runs and random insert/delete/contains
///
/// Maps each stored range's start to its (exclusive) end. Invariant: stored
/// ranges are disjoint and never adjacent, i.e. no two entries could be merged.
#[derive(Debug, Default, Clone)]
pub struct RangeSet(BTreeMap<u64, u64>);
impl RangeSet {
    /// Creates an empty set
    pub fn new() -> Self {
        Default::default()
    }
    /// Returns whether `x` is contained in the set
    pub fn contains(&self, x: u64) -> bool {
        self.pred(x).is_some_and(|(_, end)| end > x)
    }
    /// Inserts the single value `x`; returns whether the set changed
    pub fn insert_one(&mut self, x: u64) -> bool {
        if let Some((start, end)) = self.pred(x) {
            match end.cmp(&x) {
                // Wholly contained
                Ordering::Greater => {
                    return false;
                }
                Ordering::Equal => {
                    // Extend existing
                    self.0.remove(&start);
                    let mut new_end = x + 1;
                    if let Some((next_start, next_end)) = self.succ(x) {
                        if next_start == new_end {
                            self.0.remove(&next_start);
                            new_end = next_end;
                        }
                    }
                    self.0.insert(start, new_end);
                    return true;
                }
                _ => {}
            }
        }
        // No predecessor reaches x: start a fresh range, merging with an
        // immediately adjacent successor if there is one
        let mut new_end = x + 1;
        if let Some((next_start, next_end)) = self.succ(x) {
            if next_start == new_end {
                self.0.remove(&next_start);
                new_end = next_end;
            }
        }
        self.0.insert(x, new_end);
        true
    }
    /// Inserts the range `x`; returns whether any new values were added
    pub fn insert(&mut self, mut x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }
        if let Some((start, end)) = self.pred(x.start) {
            if end >= x.end {
                // Wholly contained
                return false;
            } else if end >= x.start {
                // Extend overlapping predecessor
                self.0.remove(&start);
                x.start = start;
            }
        }
        while let Some((next_start, next_end)) = self.succ(x.start) {
            if next_start > x.end {
                break;
            }
            // Overlaps with successor
            self.0.remove(&next_start);
            x.end = cmp::max(next_end, x.end);
        }
        self.0.insert(x.start, x.end);
        true
    }
    /// Find closest range to `x` that begins at or before it
    fn pred(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Included(0), Included(x)))
            .next_back()
            .map(|(&x, &y)| (x, y))
    }
    /// Find the closest range to `x` that begins after it
    fn succ(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Excluded(x), Included(u64::MAX)))
            .next()
            .map(|(&x, &y)| (x, y))
    }
    /// Removes the range `x`; returns whether any values were actually removed
    pub fn remove(&mut self, x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }
        // First, trim (and possibly split) a range that starts before `x`
        let before = match self.pred(x.start) {
            Some((start, end)) if end > x.start => {
                self.0.remove(&start);
                if start < x.start {
                    self.0.insert(start, x.start);
                }
                if end > x.end {
                    self.0.insert(x.end, end);
                }
                // Short-circuit if we cannot possibly overlap with another range
                if end >= x.end {
                    return true;
                }
                true
            }
            Some(_) | None => false,
        };
        // Then drop or trim ranges that start inside `x`
        let mut after = false;
        while let Some((start, end)) = self.succ(x.start) {
            if start >= x.end {
                break;
            }
            after = true;
            self.0.remove(&start);
            if end > x.end {
                self.0.insert(x.end, end);
                break;
            }
        }
        before || after
    }
    /// Add a range to the set, returning the intersection of current ranges with the new one
    pub fn replace(&mut self, mut range: Range<u64>) -> Replace<'_> {
        // Handle a range starting at or before `range` here; overlaps with
        // later ranges are consumed lazily by the returned iterator
        let pred = if let Some((prev_start, prev_end)) = self
            .pred(range.start)
            .filter(|&(_, end)| end >= range.start)
        {
            self.0.remove(&prev_start);
            let replaced_start = range.start;
            range.start = range.start.min(prev_start);
            let replaced_end = range.end.min(prev_end);
            range.end = range.end.max(prev_end);
            if replaced_start != replaced_end {
                Some(replaced_start..replaced_end)
            } else {
                None
            }
        } else {
            None
        };
        Replace {
            set: self,
            range,
            pred,
        }
    }
    /// Inserts every range in `other` into `self`
    pub fn add(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.insert(start..end);
        }
    }
    /// Removes every range in `other` from `self`
    pub fn subtract(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.remove(start..end);
        }
    }
    /// Returns whether the set contains no values
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
    /// Returns the smallest value in the set, if any
    pub fn min(&self) -> Option<u64> {
        self.0.first_key_value().map(|(&start, _)| start)
    }
    /// Returns the largest value in the set, if any
    pub fn max(&self) -> Option<u64> {
        self.0.last_key_value().map(|(_, &end)| end - 1)
    }
    /// Returns the number of stored ranges (not the number of contained values)
    pub fn len(&self) -> usize {
        self.0.len()
    }
    /// Iterates over the stored ranges in ascending order
    pub fn iter(&self) -> Iter<'_> {
        Iter(self.0.iter())
    }
    /// Iterates over every individual value in the set, in ascending order
    pub fn elts(&self) -> EltIter<'_> {
        EltIter {
            inner: self.0.iter(),
            next: 0,
            end: 0,
        }
    }
    /// Returns the lowest stored range without removing it, if any
    pub fn peek_min(&self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.iter().next()?;
        Some(start..end)
    }
    /// Removes and returns the lowest stored range, if any
    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        let result = self.peek_min()?;
        self.0.remove(&result.start);
        Some(result)
    }
}
/// Double-ended iterator over the ranges stored in a `RangeSet`
pub struct Iter<'a>(btree_map::Iter<'a, u64, u64>);
impl Iterator for Iter<'_> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        // Each map entry is (start, exclusive end)
        self.0.next().map(|(&start, &end)| start..end)
    }
}
impl DoubleEndedIterator for Iter<'_> {
    fn next_back(&mut self) -> Option<Range<u64>> {
        self.0.next_back().map(|(&start, &end)| start..end)
    }
}
impl<'a> IntoIterator for &'a RangeSet {
    type Item = Range<u64>;
    type IntoIter = Iter<'a>;
    // Enables `for range in &set`
    fn into_iter(self) -> Iter<'a> {
        self.iter()
    }
}
/// Iterator over each individual `u64` contained in a `RangeSet`
pub struct EltIter<'a> {
    // Remaining (start, exclusive end) entries; field names are part of the
    // construction interface used by `RangeSet::elts`
    inner: btree_map::Iter<'a, u64, u64>,
    next: u64,
    end: u64,
}
impl Iterator for EltIter<'_> {
    type Item = u64;
    fn next(&mut self) -> Option<u64> {
        // Refill from the next stored range once the current one is exhausted
        if self.next == self.end {
            let (&start, &end) = self.inner.next()?;
            self.next = start;
            self.end = end;
        }
        self.next += 1;
        Some(self.next - 1)
    }
}
impl DoubleEndedIterator for EltIter<'_> {
    fn next_back(&mut self) -> Option<u64> {
        // Refill from the last stored range once the current one is exhausted
        if self.next == self.end {
            let (&start, &end) = self.inner.next_back()?;
            self.next = start;
            self.end = end;
        }
        self.end -= 1;
        Some(self.end)
    }
}
/// Iterator returned by `RangeSet::replace`
///
/// Yields the overlaps between the newly inserted range and pre-existing
/// ranges, removing those ranges from the set as it goes; on drop, the union of
/// the new range and everything it touched is inserted as one range.
pub struct Replace<'a> {
    set: &'a mut RangeSet,
    /// Portion of the intersection arising from a range beginning at or before the newly inserted
    /// range
    pred: Option<Range<u64>>,
    /// Union of the input range and all ranges that have been visited by the iterator so far
    range: Range<u64>,
}
impl Iterator for Replace<'_> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        if let Some(pred) = self.pred.take() {
            // If a range starting before the inserted range overlapped with it, return the
            // corresponding overlap first
            return Some(pred);
        }
        let (next_start, next_end) = self.set.succ(self.range.start)?;
        if next_start > self.range.end {
            // If the next successor range starts after the current range ends, there can be no more
            // overlaps. This is sound even when `self.range.end` is increased because `RangeSet` is
            // guaranteed not to contain pairs of ranges that could be simplified.
            return None;
        }
        // Remove the redundant range...
        self.set.0.remove(&next_start);
        // ...and handle the case where the redundant range ends later than the new range.
        let replaced_end = self.range.end.min(next_end);
        self.range.end = self.range.end.max(next_end);
        if next_start == replaced_end {
            // If the redundant range started exactly where the new range ended, there was no
            // overlap with it or any later range.
            None
        } else {
            Some(next_start..replaced_end)
        }
    }
}
impl Drop for Replace<'_> {
    fn drop(&mut self) {
        // Ensure we drain all remaining overlapping ranges
        for _ in &mut *self {}
        // Insert the final aggregate range
        self.set.0.insert(self.range.start, self.range.end);
    }
}
/// This module contains tests which only apply for this `RangeSet` implementation
///
/// Tests which apply for all implementations can be found in the `tests.rs` module
#[cfg(test)]
mod tests {
    #![allow(clippy::single_range_in_vec_init)] // https://github.com/rust-lang/rust-clippy/issues/11086
    use super::*;
    // New range strictly contains the existing one; the overlap is the old range
    #[test]
    fn replace_contained() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..5).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }
    // Existing range strictly contains the new one; the overlap is the new range
    #[test]
    fn replace_contains() {
        let mut set = RangeSet::new();
        set.insert(1..5);
        assert_eq!(set.replace(2..4).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }
    // New range overlaps the tail of an existing (predecessor) range
    #[test]
    fn replace_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(3..5).collect::<Vec<_>>(), &[3..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..5);
    }
    // New range overlaps the head of an existing (successor) range
    #[test]
    fn replace_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..3).collect::<Vec<_>>(), &[2..3]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..4);
    }
    // Exactly adjacent ranges merge but report no overlap
    #[test]
    fn replace_exact_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(4..6).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..6);
    }
    #[test]
    fn replace_exact_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(0..2).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 0..4);
    }
}

View File

@@ -0,0 +1,7 @@
mod array_range_set;
mod btree_range_set;
#[cfg(test)]
mod tests;
pub(crate) use array_range_set::ArrayRangeSet;
pub(crate) use btree_range_set::RangeSet;

View File

@@ -0,0 +1,263 @@
use std::ops::Range;
use super::*;
/// Generates the test suite shared by every `RangeSet` implementation
///
/// `$set_name` names the generated test module; `$set_type` is the set type under test.
/// Both implementations are also cross-checked against the bool-vector `RefRangeSet`.
macro_rules! common_set_tests {
    ($set_name:ident, $set_type:ident) => {
        mod $set_name {
            use super::*;

            #[test]
            fn merge_and_split() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(1..3));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3]);
                assert!(!set.contains(4));
                assert!(set.remove(2..3));
                assert_eq!(set.len(), 2);
                assert!(!set.contains(2));
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3]);
            }

            #[test]
            fn double_merge_exact() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..4));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn single_merge_low() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..3));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 4, 5]);
            }

            #[test]
            fn single_merge_high() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(3..4));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3, 4, 5]);
            }

            #[test]
            fn double_merge_wide() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(1..5));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn double_remove() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert!(set.remove(1..5));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 5]);
            }

            #[test]
            fn insert_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.insert(0..5));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn remove_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.remove(0..5));
                assert!(set.is_empty());
            }

            #[test]
            fn double_insert() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(!set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(2..4));
                assert!(!set.insert(0..4));
                assert!(!set.insert(1..2));
                assert!(!set.insert(1..3));
                assert!(!set.insert(1..4));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn skip_empty_ranges() {
                // Inserting empty ranges is a no-op and must report "nothing added"
                let mut set = $set_type::new();
                assert!(!set.insert(2..2));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(4..4));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(0..0));
                assert_eq!(set.len(), 0);
            }

            #[test]
            fn compare_insert_to_reference() {
                // Exhaustively compare `insert` against the reference implementation
                // for every (start, end) pair in range
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("insert({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.insert(start..end), reference.insert(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }

            #[test]
            fn compare_remove_to_reference() {
                // Exhaustively compare `remove` against the reference implementation
                // for every (start, end) pair in range
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("remove({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.remove(start..end), reference.remove(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }

            #[test]
            fn min_max() {
                let mut set = $set_type::new();
                set.insert(1..3);
                set.insert(4..5);
                set.insert(6..10);
                assert_eq!(set.min(), Some(1));
                assert_eq!(set.max(), Some(9));
            }

            /// Build a set-under-test and reference pre-populated with identical contents
            fn create_initial_sets(max_range: u64) -> ($set_type, RefRangeSet) {
                let mut set = $set_type::new();
                let mut reference = RefRangeSet::new(max_range as usize);
                assert_sets_equal(&set, &reference);
                assert_eq!(set.insert(2..6), reference.insert(2..6));
                assert_eq!(set.insert(10..14), reference.insert(10..14));
                assert_eq!(set.insert(14..14), reference.insert(14..14));
                assert_eq!(set.insert(18..19), reference.insert(18..19));
                assert_eq!(set.insert(20..21), reference.insert(20..21));
                assert_eq!(set.insert(22..24), reference.insert(22..24));
                assert_eq!(set.insert(26..30), reference.insert(26..30));
                assert_eq!(set.insert(34..38), reference.insert(34..38));
                assert_eq!(set.insert(42..44), reference.insert(42..44));
                assert_sets_equal(&set, &reference);
                (set, reference)
            }

            /// Assert both sets agree on length, emptiness, and element listing
            fn assert_sets_equal(set: &$set_type, reference: &RefRangeSet) {
                assert_eq!(set.len(), reference.len());
                assert_eq!(set.is_empty(), reference.is_empty());
                assert_eq!(set.elts().collect::<Vec<_>>()[..], reference.elts()[..]);
            }
        }
    };
}
// Instantiate the shared test suite for both RangeSet implementations
common_set_tests!(range_set, RangeSet);
common_set_tests!(array_range_set, ArrayRangeSet);
/// A very simple reference implementation of a RangeSet
///
/// Backed by one bool per possible element, trading memory for obviousness so the
/// real implementations can be cross-checked against it.
struct RefRangeSet {
    data: Vec<bool>,
}

impl RefRangeSet {
    fn new(capacity: usize) -> Self {
        Self {
            data: vec![false; capacity],
        }
    }

    /// Number of maximal contiguous ranges in the set
    fn len(&self) -> usize {
        // Count rising edges, i.e. positions where a run of set bits begins
        let mut runs = 0;
        let mut previous = false;
        for &present in &self.data {
            if present && !previous {
                runs += 1;
            }
            previous = present;
        }
        runs
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Mark the elements of `x` present; returns whether anything was newly added
    fn insert(&mut self, x: Range<u64>) -> bool {
        assert!(x.end <= self.data.len() as u64);
        let mut changed = false;
        for i in x {
            let slot = &mut self.data[i as usize];
            if !*slot {
                *slot = true;
                changed = true;
            }
        }
        changed
    }

    /// Clear the elements of `x`; returns whether anything was actually removed
    fn remove(&mut self, x: Range<u64>) -> bool {
        assert!(x.end <= self.data.len() as u64);
        let mut changed = false;
        for i in x {
            let slot = &mut self.data[i as usize];
            if *slot {
                *slot = false;
                changed = true;
            }
        }
        changed
    }

    /// All elements of the set, in ascending order
    fn elts(&self) -> Vec<u64> {
        let mut out = Vec::new();
        for (i, &present) in self.data.iter().enumerate() {
            if present {
                out.push(i as u64);
            }
        }
        out
    }
}

180
vendor/quinn-proto/src/shared.rs vendored Normal file
View File

@@ -0,0 +1,180 @@
use std::{fmt, net::SocketAddr};
use bytes::{Buf, BufMut, BytesMut};
use crate::{Instant, MAX_CID_SIZE, ResetToken, coding::BufExt, packet::PartialDecode};
/// Events sent from an Endpoint to a Connection
///
/// Opaque wrapper: the payload is crate-private so the variants can evolve freely.
#[derive(Debug)]
pub struct ConnectionEvent(pub(crate) ConnectionEventInner);
/// Crate-internal payload of a [`ConnectionEvent`]
#[derive(Debug)]
pub(crate) enum ConnectionEventInner {
    /// A datagram has been received for the Connection
    Datagram(DatagramConnectionEvent),
    /// New connection identifiers have been issued for the Connection
    NewIdentifiers(Vec<IssuedCid>, Instant),
}
/// Variant of [`ConnectionEventInner`].
#[derive(Debug)]
pub(crate) struct DatagramConnectionEvent {
    /// Time associated with receipt of the datagram
    pub(crate) now: Instant,
    /// Address the datagram arrived from
    pub(crate) remote: SocketAddr,
    /// ECN codepoint carried by the datagram, if any
    pub(crate) ecn: Option<EcnCodepoint>,
    /// Partially decoded first packet of the datagram
    pub(crate) first_decode: PartialDecode,
    /// Bytes of the datagram following the first packet, if any
    pub(crate) remaining: Option<BytesMut>,
}
/// Events sent from a Connection to an Endpoint
///
/// Opaque wrapper: the payload is crate-private so the variants can evolve freely.
#[derive(Debug)]
pub struct EndpointEvent(pub(crate) EndpointEventInner);
impl EndpointEvent {
    /// Create the event signalling that a `Connection` will emit no further events
    ///
    /// Lets an `Endpoint` know about a `Connection` that was torn down outside the usual
    /// state machine flow, e.g. because the user dropped it.
    pub fn drained() -> Self {
        Self(EndpointEventInner::Drained)
    }

    /// Whether this event is the last one the `Connection` will produce
    ///
    /// When this returns true, per-connection event loop state may be released.
    pub fn is_drained(&self) -> bool {
        matches!(self.0, EndpointEventInner::Drained)
    }
}
/// Crate-internal payload of an [`EndpointEvent`]
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum EndpointEventInner {
    /// The connection has been drained
    Drained,
    /// The reset token and/or address eligible for generating resets has been updated
    ResetToken(SocketAddr, ResetToken),
    /// The connection needs connection identifiers
    NeedIdentifiers(Instant, u64),
    /// Stop routing connection ID for this sequence number to the connection
    /// When `bool == true`, a new connection ID will be issued to peer
    RetireConnectionId(Instant, u64, bool),
}
/// Protocol-level identifier for a connection.
///
/// Mainly useful for identifying this connection's packets on the wire with tools like Wireshark.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ConnectionId {
    /// length of CID
    len: u8,
    /// CID in byte array
    ///
    /// Only the first `len` bytes are meaningful; the rest are zero-initialized.
    bytes: [u8; MAX_CID_SIZE],
}
impl ConnectionId {
    /// Construct a connection ID from a slice of at most `MAX_CID_SIZE` bytes
    pub fn new(bytes: &[u8]) -> Self {
        debug_assert!(bytes.len() <= MAX_CID_SIZE);
        let mut cid = Self {
            len: bytes.len() as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        cid.bytes[..bytes.len()].copy_from_slice(bytes);
        cid
    }

    /// Construct a connection ID by reading `len` bytes from a `Buf`
    ///
    /// Callers need to assure that `buf.remaining() >= len`
    pub fn from_buf(buf: &mut (impl Buf + ?Sized), len: usize) -> Self {
        debug_assert!(len <= MAX_CID_SIZE);
        let mut cid = Self {
            len: len as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        buf.copy_to_slice(&mut cid[..len]);
        cid
    }

    /// Decode from long header format: a one-byte length prefix followed by the ID bytes
    pub(crate) fn decode_long(buf: &mut impl Buf) -> Option<Self> {
        let len = buf.get::<u8>().ok()? as usize;
        if len > MAX_CID_SIZE || buf.remaining() < len {
            return None;
        }
        Some(Self::from_buf(buf, len))
    }

    /// Encode in long header format: a one-byte length prefix followed by the ID bytes
    pub(crate) fn encode_long(&self, buf: &mut impl BufMut) {
        buf.put_u8(self.len() as u8);
        buf.put_slice(self);
    }
}
impl std::ops::Deref for ConnectionId {
    type Target = [u8];
    /// View the meaningful prefix of the backing array as a byte slice
    fn deref(&self) -> &[u8] {
        &self.bytes[..self.len as usize]
    }
}

impl std::ops::DerefMut for ConnectionId {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut self.bytes[..self.len as usize]
    }
}
impl fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the byte slice's Debug impl for the in-use prefix
        (**self).fmt(f)
    }
}

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lowercase hex, two digits per byte, no separators
        self.iter().try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
/// Explicit congestion notification codepoint
#[repr(u8)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum EcnCodepoint {
    /// The ECT(0) codepoint, indicating that an endpoint is ECN-capable
    Ect0 = 0b10,
    /// The ECT(1) codepoint, indicating that an endpoint is ECN-capable
    Ect1 = 0b01,
    /// The CE codepoint, signalling that congestion was experienced
    Ce = 0b11,
}

impl EcnCodepoint {
    /// Parse a codepoint from the low two bits of `x`, returning `None` for `0b00`
    pub fn from_bits(x: u8) -> Option<Self> {
        match x & 0b11 {
            0b10 => Some(Self::Ect0),
            0b01 => Some(Self::Ect1),
            0b11 => Some(Self::Ce),
            _ => None,
        }
    }

    /// Returns whether the codepoint is a CE, signalling that congestion was experienced
    pub fn is_ce(self) -> bool {
        self == Self::Ce
    }
}
/// A connection ID issued to a peer, plus its associated metadata
#[derive(Debug, Copy, Clone)]
pub(crate) struct IssuedCid {
    /// Position of this CID in the issuance sequence
    pub(crate) sequence: u64,
    /// The connection ID itself
    pub(crate) id: ConnectionId,
    /// Stateless reset token associated with this CID
    pub(crate) reset_token: ResetToken,
}

3376
vendor/quinn-proto/src/tests/mod.rs vendored Normal file

File diff suppressed because it is too large Load Diff

333
vendor/quinn-proto/src/tests/token.rs vendored Normal file
View File

@@ -0,0 +1,333 @@
//! Tests specifically for tokens
use super::*;
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
use wasm_bindgen_test::wasm_bindgen_test as test;
#[test]
fn stateless_retry() {
    let _guard = subscribe();
    // The server demands address validation (Retry) before accepting
    let mut pair = Pair::default();
    pair.server.handle_incoming = Box::new(validate_incoming);
    let (client_ch, _server_ch) = pair.connect();
    let client_conn = pair.client.connections.get_mut(&client_ch).unwrap();
    client_conn.close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    // After the close completes, neither side may retain any connection state
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn retry_token_expired() {
    let _guard = subscribe();
    // Controlled clock so the retry token can be expired deterministically
    let fake_time = Arc::new(FakeTimeSource::new());
    let retry_token_lifetime = Duration::from_secs(1);
    let mut pair = Pair::default();
    pair.server.handle_incoming = Box::new(validate_incoming);
    let mut config = server_config();
    config
        .time_source(Arc::clone(&fake_time) as _)
        .retry_token_lifetime(retry_token_lifetime);
    pair.server.set_server_config(Some(Arc::new(config)));
    let client_ch = pair.begin_connect(client_config());
    // Step one side at a time so the clock can be advanced mid-handshake
    pair.drive_client();
    pair.drive_server();
    pair.drive_client();
    // to expire retry token
    fake_time.advance(retry_token_lifetime + Duration::from_millis(1));
    pair.drive();
    // The server must reject the stale token with INVALID_TOKEN
    assert_matches!(
        pair.client_conn_mut(client_ch).poll(),
        Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(err) })
        if err.error_code == TransportErrorCode::INVALID_TOKEN
    );
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn use_token() {
    let _guard = subscribe();
    let mut pair = Pair::default();
    let client_config = client_config();
    // First connection, then tear it down completely
    // NOTE(review): presumably the client stores a token issued during this
    // connection — inferred from the validated second attempt below; confirm
    let (client_ch, _server_ch) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Second connection: the address must arrive already validated,
    // while a Retry remains permitted
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_2, _server_ch_2) = pair.connect_with(client_config);
    pair.client
        .connections
        .get_mut(&client_ch_2)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn retry_then_use_token() {
    let _guard = subscribe();
    let mut pair = Pair::default();
    let client_config = client_config();
    // First connection goes through the Retry path (unvalidated addresses get a Retry)
    pair.server.handle_incoming = Box::new(validate_incoming);
    let (client_ch, _server_ch) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Second connection: the address must arrive already validated this time,
    // while a Retry remains permitted
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_2, _server_ch_2) = pair.connect_with(client_config);
    pair.client
        .connections
        .get_mut(&client_ch_2)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn use_token_then_retry() {
    let _guard = subscribe();
    let mut pair = Pair::default();
    let client_config = client_config();
    // First connection, then tear it down completely
    let (client_ch, _server_ch) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Second connection: first attempt arrives validated and is answered with a
    // Retry anyway; the post-Retry attempt must no longer permit another Retry
    pair.server.handle_incoming = Box::new({
        let mut i = 0;
        move |incoming| {
            if i == 0 {
                assert!(incoming.remote_address_validated());
                assert!(incoming.may_retry());
                i += 1;
                IncomingConnectionBehavior::Retry
            } else if i == 1 {
                assert!(incoming.remote_address_validated());
                assert!(!incoming.may_retry());
                i += 1;
                IncomingConnectionBehavior::Accept
            } else {
                panic!("too many handle_incoming iterations")
            }
        }
    });
    let (client_ch_2, _server_ch_2) = pair.connect_with(client_config);
    pair.client
        .connections
        .get_mut(&client_ch_2)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn use_same_token_twice() {
    /// Token store that keeps only the first token it is given and never discards
    /// it on `take`, so the same token gets replayed on every new connection
    #[derive(Default)]
    struct EvilTokenStore(Mutex<Bytes>);
    impl TokenStore for EvilTokenStore {
        fn insert(&self, _server_name: &str, token: Bytes) {
            // Retain only the first token ever stored
            let mut lock = self.0.lock().unwrap();
            if lock.is_empty() {
                *lock = token;
            }
        }
        fn take(&self, _server_name: &str) -> Option<Bytes> {
            // Clone instead of removing, so the token can be "taken" repeatedly
            let lock = self.0.lock().unwrap();
            if lock.is_empty() {
                None
            } else {
                Some(lock.clone())
            }
        }
    }
    let _guard = subscribe();
    let mut pair = Pair::default();
    let mut client_config = client_config();
    client_config.token_store(Arc::new(EvilTokenStore::default()));
    // First connection, then tear it down completely
    let (client_ch, _server_ch) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Second connection: the first use of the token validates the address
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_2, _server_ch_2) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch_2)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Third connection: replaying the same token must NOT validate the address
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(!incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_3, _server_ch_3) = pair.connect_with(client_config);
    pair.client
        .connections
        .get_mut(&client_ch_3)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
#[test]
fn use_token_expired() {
    let _guard = subscribe();
    // Controlled clock so the validation token can be expired deterministically
    let fake_time = Arc::new(FakeTimeSource::new());
    let lifetime = Duration::from_secs(10000);
    let mut server_config = server_config();
    server_config
        .time_source(Arc::clone(&fake_time) as _)
        .validation_token
        .lifetime(lifetime);
    let mut pair = Pair::new(Default::default(), server_config);
    let client_config = client_config();
    // First connection, then tear it down completely
    let (client_ch, _server_ch) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Second connection within the token lifetime: address arrives validated
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_2, _server_ch_2) = pair.connect_with(client_config.clone());
    pair.client
        .connections
        .get_mut(&client_ch_2)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
    // Expire the token, then connect again: the address must no longer validate
    fake_time.advance(lifetime + Duration::from_secs(1));
    pair.server.handle_incoming = Box::new(|incoming| {
        assert!(!incoming.remote_address_validated());
        assert!(incoming.may_retry());
        IncomingConnectionBehavior::Accept
    });
    let (client_ch_3, _server_ch_3) = pair.connect_with(client_config);
    pair.client
        .connections
        .get_mut(&client_ch_3)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}
/// Manually advanced wall-clock time source for deterministic expiry tests
pub(super) struct FakeTimeSource(Mutex<SystemTime>);

impl FakeTimeSource {
    pub(super) fn new() -> Self {
        Self(Mutex::new(SystemTime::now()))
    }

    /// Move the simulated clock forward by `dur`
    pub(super) fn advance(&self, dur: Duration) {
        let mut now = self.0.lock().unwrap();
        *now += dur;
    }
}

impl TimeSource for FakeTimeSource {
    fn now(&self) -> SystemTime {
        *self.0.lock().unwrap()
    }
}

745
vendor/quinn-proto/src/tests/util.rs vendored Normal file
View File

@@ -0,0 +1,745 @@
use std::{
cmp,
collections::{HashMap, HashSet, VecDeque},
env,
io::{self, Write},
mem,
net::{Ipv6Addr, SocketAddr, UdpSocket},
ops::RangeFrom,
str,
sync::{Arc, Mutex},
};
use assert_matches::assert_matches;
use bytes::BytesMut;
use lazy_static::lazy_static;
use rustls::{
KeyLogFile,
client::WebPkiServerVerifier,
pki_types::{CertificateDer, PrivateKeyDer},
};
use tracing::{info_span, trace};
use super::crypto::rustls::{QuicClientConfig, QuicServerConfig, configured_provider};
use super::*;
use crate::{Duration, Instant};
/// Default maximum UDP payload the simulated link passes without dropping (see `Pair::mtu`)
pub(super) const DEFAULT_MTU: usize = 1452;
/// A client/server endpoint pair exchanging packets over a simulated in-memory link
pub(super) struct Pair {
    /// Simulated server side
    pub(super) server: TestEndpoint,
    /// Simulated client side
    pub(super) client: TestEndpoint,
    /// Start time
    epoch: Instant,
    /// Current time
    pub(super) time: Instant,
    /// Simulates the maximum size allowed for UDP payloads by the link (packets exceeding this size will be dropped)
    pub(super) mtu: usize,
    /// Simulates explicit congestion notification
    pub(super) congestion_experienced: bool,
    // One-way
    pub(super) latency: Duration,
    /// Number of spin bit flips
    pub(super) spins: u64,
    /// Spin bit value seen on the most recent client short-header packet
    last_spin: bool,
}
impl Pair {
    /// A default pair whose transport assigns packet numbers deterministically
    pub(super) fn default_with_deterministic_pns() -> Self {
        let mut cfg = server_config();
        let mut transport = TransportConfig::default();
        transport.deterministic_packet_numbers(true);
        cfg.transport = Arc::new(transport);
        Self::new(Default::default(), cfg)
    }

    /// Build a pair from an endpoint config shared by both sides plus a server config
    pub(super) fn new(endpoint_config: Arc<EndpointConfig>, server_config: ServerConfig) -> Self {
        let server = Endpoint::new(
            endpoint_config.clone(),
            Some(Arc::new(server_config)),
            true,
            None,
        );
        let client = Endpoint::new(endpoint_config, None, true, None);
        Self::new_from_endpoint(client, server)
    }

    /// Pair two existing endpoints, assigning each a fresh localhost port
    pub(super) fn new_from_endpoint(client: Endpoint, server: Endpoint) -> Self {
        let server_addr = SocketAddr::new(
            Ipv6Addr::LOCALHOST.into(),
            SERVER_PORTS.lock().unwrap().next().unwrap(),
        );
        let client_addr = SocketAddr::new(
            Ipv6Addr::LOCALHOST.into(),
            CLIENT_PORTS.lock().unwrap().next().unwrap(),
        );
        let now = Instant::now();
        Self {
            server: TestEndpoint::new(server, server_addr),
            client: TestEndpoint::new(client, client_addr),
            epoch: now,
            time: now,
            mtu: DEFAULT_MTU,
            latency: Duration::ZERO,
            spins: 0,
            last_spin: false,
            congestion_experienced: false,
        }
    }

    /// Returns whether the connection is not idle
    pub(super) fn step(&mut self) -> bool {
        self.drive_client();
        self.drive_server();
        if self.client.is_idle() && self.server.is_idle() {
            return false;
        }
        // Jump the simulated clock straight to the earliest wakeup on either side
        let client_t = self.client.next_wakeup();
        let server_t = self.server.next_wakeup();
        match min_opt(client_t, server_t) {
            Some(t) if Some(t) == client_t => {
                if t != self.time {
                    self.time = self.time.max(t);
                    trace!("advancing to {:?} for client", self.time - self.epoch);
                }
                true
            }
            Some(t) if Some(t) == server_t => {
                if t != self.time {
                    self.time = self.time.max(t);
                    trace!("advancing to {:?} for server", self.time - self.epoch);
                }
                true
            }
            Some(_) => unreachable!(),
            None => false,
        }
    }

    /// Advance time until both connections are idle
    pub(super) fn drive(&mut self) {
        while self.step() {}
    }

    /// Advance time until both connections are idle, or after 100 steps have been executed
    ///
    /// Returns true if the amount of steps exceeds the bounds, because the connections never became
    /// idle
    pub(super) fn drive_bounded(&mut self) -> bool {
        for _ in 0..100 {
            if !self.step() {
                return false;
            }
        }
        true
    }

    /// Deliver the client's queued outbound packets across the simulated link
    pub(super) fn drive_client(&mut self) {
        let span = info_span!("client");
        let _guard = span.enter();
        self.client.drive(self.time, self.server.addr);
        for (packet, buffer) in self.client.outbound.drain(..) {
            let packet_size = packet_size(&packet, &buffer);
            if packet_size > self.mtu {
                // Link-level MTU enforcement: oversized packets are silently lost
                info!(packet_size, "dropping packet (max size exceeded)");
                continue;
            }
            if buffer[0] & packet::LONG_HEADER_FORM == 0 {
                // Track the spin bit on short-header packets
                // NOTE(review): this increments when the bit *matches* the previous
                // value, while the `spins` field doc says "flips" — confirm intent
                let spin = buffer[0] & packet::SPIN_BIT != 0;
                self.spins += (spin == self.last_spin) as u64;
                self.last_spin = spin;
            }
            // Mirror the packet onto the real socket when one is bound (capture aid)
            if let Some(ref socket) = self.client.socket {
                socket.send_to(&buffer, packet.destination).unwrap();
            }
            if self.server.addr == packet.destination {
                let ecn = set_congestion_experienced(packet.ecn, self.congestion_experienced);
                self.server.inbound.push_back((
                    self.time + self.latency,
                    ecn,
                    buffer.as_ref().into(),
                ));
            }
        }
    }

    /// Deliver the server's queued outbound packets across the simulated link
    pub(super) fn drive_server(&mut self) {
        let span = info_span!("server");
        let _guard = span.enter();
        self.server.drive(self.time, self.client.addr);
        for (packet, buffer) in self.server.outbound.drain(..) {
            let packet_size = packet_size(&packet, &buffer);
            if packet_size > self.mtu {
                info!(packet_size, "dropping packet (max size exceeded)");
                continue;
            }
            if let Some(ref socket) = self.server.socket {
                socket.send_to(&buffer, packet.destination).unwrap();
            }
            if self.client.addr == packet.destination {
                let ecn = set_congestion_experienced(packet.ecn, self.congestion_experienced);
                self.client.inbound.push_back((
                    self.time + self.latency,
                    ecn,
                    buffer.as_ref().into(),
                ));
            }
        }
    }

    /// Run a full handshake with the default client config; returns both handles
    pub(super) fn connect(&mut self) -> (ConnectionHandle, ConnectionHandle) {
        self.connect_with(client_config())
    }

    /// Run a full handshake with the given client config; returns both handles
    pub(super) fn connect_with(
        &mut self,
        config: ClientConfig,
    ) -> (ConnectionHandle, ConnectionHandle) {
        info!("connecting");
        let client_ch = self.begin_connect(config);
        self.drive();
        let server_ch = self.server.assert_accept();
        self.finish_connect(client_ch, server_ch);
        (client_ch, server_ch)
    }

    /// Just start connecting the client
    pub(super) fn begin_connect(&mut self, config: ClientConfig) -> ConnectionHandle {
        let span = info_span!("client");
        let _guard = span.enter();
        let (client_ch, client_conn) = self
            .client
            .connect(self.time, config, self.server.addr, "localhost")
            .unwrap();
        self.client.connections.insert(client_ch, client_conn);
        client_ch
    }

    /// Assert both sides report handshake data followed by the Connected event
    fn finish_connect(&mut self, client_ch: ConnectionHandle, server_ch: ConnectionHandle) {
        assert_matches!(
            self.client_conn_mut(client_ch).poll(),
            Some(Event::HandshakeDataReady)
        );
        assert_matches!(
            self.client_conn_mut(client_ch).poll(),
            Some(Event::Connected)
        );
        assert_matches!(
            self.server_conn_mut(server_ch).poll(),
            Some(Event::HandshakeDataReady)
        );
        assert_matches!(
            self.server_conn_mut(server_ch).poll(),
            Some(Event::Connected)
        );
    }

    // Convenience accessors for each side's connections and their sub-APIs

    pub(super) fn client_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection {
        self.client.connections.get_mut(&ch).unwrap()
    }

    pub(super) fn client_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> {
        self.client_conn_mut(ch).streams()
    }

    pub(super) fn client_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> {
        self.client_conn_mut(ch).send_stream(s)
    }

    pub(super) fn client_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> {
        self.client_conn_mut(ch).recv_stream(s)
    }

    pub(super) fn client_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> {
        self.client_conn_mut(ch).datagrams()
    }

    pub(super) fn server_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection {
        self.server.connections.get_mut(&ch).unwrap()
    }

    pub(super) fn server_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> {
        self.server_conn_mut(ch).streams()
    }

    pub(super) fn server_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> {
        self.server_conn_mut(ch).send_stream(s)
    }

    pub(super) fn server_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> {
        self.server_conn_mut(ch).recv_stream(s)
    }

    pub(super) fn server_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> {
        self.server_conn_mut(ch).datagrams()
    }
}
impl Default for Pair {
    /// Default endpoint config plus the test suite's standard server config
    fn default() -> Self {
        Self::new(Default::default(), server_config())
    }
}
/// One side of a simulated connection: an `Endpoint` plus queued I/O and test hooks
pub(super) struct TestEndpoint {
    pub(super) endpoint: Endpoint,
    /// Simulated address of this endpoint
    pub(super) addr: SocketAddr,
    // Real UDP socket; bound only when SSLKEYLOGFILE is set — presumably so
    // external tools can capture the mirrored traffic (NOTE(review): confirm)
    socket: Option<UdpSocket>,
    /// Next timer deadline reported by a connection's `poll_timeout`
    timeout: Option<Instant>,
    /// Packets waiting to be delivered to the peer
    pub(super) outbound: VecDeque<(Transmit, Bytes)>,
    /// Packets held back by `delay_outbound` until `finish_delay`
    delayed: VecDeque<(Transmit, Bytes)>,
    /// Received packets queued as (delivery time, ECN, payload)
    pub(super) inbound: VecDeque<(Instant, Option<EcnCodepoint>, BytesMut)>,
    /// Outcome of the most recent accept attempt, if any
    accepted: Option<Result<ConnectionHandle, ConnectionError>>,
    pub(super) connections: HashMap<ConnectionHandle, Connection>,
    /// Events routed from the endpoint to each connection
    conn_events: HashMap<ConnectionHandle, VecDeque<ConnectionEvent>>,
    /// Decoded inbound packets, recorded when `capture_inbound_packets` is set
    pub(super) captured_packets: Vec<Vec<u8>>,
    pub(super) capture_inbound_packets: bool,
    /// Test hook deciding how each incoming connection is handled
    pub(super) handle_incoming: Box<dyn FnMut(&Incoming) -> IncomingConnectionBehavior>,
    /// Incoming connections parked by `IncomingConnectionBehavior::Wait`
    pub(super) waiting_incoming: Vec<Incoming>,
}
/// What a test's `handle_incoming` hook wants done with an incoming connection
#[derive(Debug, Copy, Clone)]
pub(super) enum IncomingConnectionBehavior {
    /// Accept the connection immediately
    Accept,
    /// Reject the connection
    Reject,
    /// Respond with a Retry
    Retry,
    /// Park the connection in `waiting_incoming` for later handling
    Wait,
}
/// Standard validation policy: accept validated addresses, otherwise demand a Retry
pub(super) fn validate_incoming(incoming: &Incoming) -> IncomingConnectionBehavior {
    match incoming.remote_address_validated() {
        true => IncomingConnectionBehavior::Accept,
        false => IncomingConnectionBehavior::Retry,
    }
}
impl TestEndpoint {
/// Wrap an `Endpoint` with empty queues and default test hooks
fn new(endpoint: Endpoint, addr: SocketAddr) -> Self {
    // Bind a real socket only when SSLKEYLOGFILE is set, with a short read
    // timeout so draining it cannot block the test
    let socket = if env::var_os("SSLKEYLOGFILE").is_some() {
        let socket = UdpSocket::bind(addr).expect("failed to bind UDP socket");
        socket
            .set_read_timeout(Some(Duration::from_millis(10)))
            .unwrap();
        Some(socket)
    } else {
        None
    };
    Self {
        endpoint,
        addr,
        socket,
        timeout: None,
        outbound: VecDeque::new(),
        delayed: VecDeque::new(),
        inbound: VecDeque::new(),
        accepted: None,
        connections: HashMap::default(),
        conn_events: HashMap::default(),
        captured_packets: Vec::new(),
        capture_inbound_packets: false,
        // Default policy: accept every incoming connection
        handle_incoming: Box::new(|_| IncomingConnectionBehavior::Accept),
        waiting_incoming: Vec::new(),
    }
}
/// Process due inbound packets, then flush all resulting outbound work
pub(super) fn drive(&mut self, now: Instant, remote: SocketAddr) {
    self.drive_incoming(now, remote);
    self.drive_outgoing(now);
}
/// Feed every inbound packet due by `now` into the endpoint and dispatch the
/// resulting datagram events (new connections, connection events, responses)
pub(super) fn drive_incoming(&mut self, now: Instant, remote: SocketAddr) {
    if let Some(ref socket) = self.socket {
        // Drain the mirror socket so its buffer cannot fill up; data is discarded
        loop {
            let mut buf = [0; 8192];
            if socket.recv_from(&mut buf).is_err() {
                break;
            }
        }
    }
    let buffer_size = self.endpoint.config().get_max_udp_payload_size() as usize;
    let mut buf = Vec::with_capacity(buffer_size);
    while self.inbound.front().is_some_and(|x| x.0 <= now) {
        let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap();
        if let Some(event) = self
            .endpoint
            .handle(recv_time, remote, None, ecn, packet, &mut buf)
        {
            match event {
                DatagramEvent::NewConnection(incoming) => {
                    // Let the test hook decide how this connection is handled
                    match (self.handle_incoming)(&incoming) {
                        IncomingConnectionBehavior::Accept => {
                            let _ = self.try_accept(incoming, now);
                        }
                        IncomingConnectionBehavior::Reject => {
                            self.reject(incoming);
                        }
                        IncomingConnectionBehavior::Retry => {
                            self.retry(incoming);
                        }
                        IncomingConnectionBehavior::Wait => {
                            self.waiting_incoming.push(incoming);
                        }
                    }
                }
                DatagramEvent::ConnectionEvent(ch, event) => {
                    if self.capture_inbound_packets {
                        let packet = self.connections[&ch].decode_packet(&event);
                        self.captured_packets.extend(packet);
                    }
                    // Queue for delivery to the connection in drive_outgoing
                    self.conn_events.entry(ch).or_default().push_back(event);
                }
                DatagramEvent::Response(transmit) => {
                    // Endpoint-generated reply (no connection involved yet)
                    let size = transmit.size;
                    self.outbound.extend(split_transmit(transmit, &buf[..size]));
                    buf.clear();
                }
            }
        }
    }
}
/// Drive every connection: timeouts, queued events, and transmits; then loop the
/// resulting endpoint events back until no more are produced
pub(super) fn drive_outgoing(&mut self, now: Instant) {
    let buffer_size = self.endpoint.config().get_max_udp_payload_size() as usize;
    let mut buf = Vec::with_capacity(buffer_size);
    loop {
        let mut endpoint_events: Vec<(ConnectionHandle, EndpointEvent)> = vec![];
        for (ch, conn) in self.connections.iter_mut() {
            // NOTE(review): a single `timeout` field is shared across all
            // connections, and below ALL queued events are drained into whichever
            // connection is being iterated — this looks safe only while tests run
            // one connection per endpoint; confirm
            if self.timeout.is_some_and(|x| x <= now) {
                self.timeout = None;
                conn.handle_timeout(now);
            }
            for (_, mut events) in self.conn_events.drain() {
                for event in events.drain(..) {
                    conn.handle_event(event);
                }
            }
            while let Some(event) = conn.poll_endpoint_events() {
                endpoint_events.push((*ch, event));
            }
            while let Some(transmit) = conn.poll_transmit(now, MAX_DATAGRAMS, &mut buf) {
                let size = transmit.size;
                self.outbound.extend(split_transmit(transmit, &buf[..size]));
                buf.clear();
            }
            self.timeout = conn.poll_timeout();
        }
        if endpoint_events.is_empty() {
            break;
        }
        // Feed connection-produced events to the endpoint; any replies go back
        // to the originating connection
        for (ch, event) in endpoint_events {
            if let Some(event) = self.handle_event(ch, event) {
                if let Some(conn) = self.connections.get_mut(&ch) {
                    conn.handle_event(event);
                }
            }
        }
    }
}
pub(super) fn next_wakeup(&self) -> Option<Instant> {
let next_inbound = self.inbound.front().map(|x| x.0);
min_opt(self.timeout, next_inbound)
}
fn is_idle(&self) -> bool {
self.connections.values().all(|x| x.is_idle())
}
pub(super) fn delay_outbound(&mut self) {
assert!(self.delayed.is_empty());
mem::swap(&mut self.delayed, &mut self.outbound);
}
pub(super) fn finish_delay(&mut self) {
self.outbound.extend(self.delayed.drain(..));
}
    /// Accept an incoming connection, recording the outcome in `self.accepted`
    ///
    /// On success the new connection is registered in `connections`; on failure any
    /// response packet produced by the endpoint is queued on `outbound` and the
    /// error cause is returned.
    pub(super) fn try_accept(
        &mut self,
        incoming: Incoming,
        now: Instant,
    ) -> Result<ConnectionHandle, ConnectionError> {
        let mut buf = Vec::new();
        match self.endpoint.accept(incoming, now, &mut buf, None) {
            Ok((ch, conn)) => {
                self.connections.insert(ch, conn);
                self.accepted = Some(Ok(ch));
                Ok(ch)
            }
            Err(error) => {
                if let Some(transmit) = error.response {
                    let size = transmit.size;
                    self.outbound.extend(split_transmit(transmit, &buf[..size]));
                }
                self.accepted = Some(Err(error.cause.clone()));
                Err(error.cause)
            }
        }
    }
pub(super) fn retry(&mut self, incoming: Incoming) {
let mut buf = Vec::new();
let transmit = self.endpoint.retry(incoming, &mut buf).unwrap();
let size = transmit.size;
self.outbound.extend(split_transmit(transmit, &buf[..size]));
}
pub(super) fn reject(&mut self, incoming: Incoming) {
let mut buf = Vec::new();
let transmit = self.endpoint.refuse(incoming, &mut buf);
let size = transmit.size;
self.outbound.extend(split_transmit(transmit, &buf[..size]));
}
pub(super) fn assert_accept(&mut self) -> ConnectionHandle {
self.accepted
.take()
.expect("server didn't try connecting")
.expect("server experienced error connecting")
}
pub(super) fn assert_accept_error(&mut self) -> ConnectionError {
self.accepted
.take()
.expect("server didn't try connecting")
.expect_err("server did unexpectedly connect without error")
}
pub(super) fn assert_no_accept(&self) {
assert!(self.accepted.is_none(), "server did unexpectedly connect")
}
}
/// Allow calling `Endpoint` methods directly on the test wrapper
impl ::std::ops::Deref for TestEndpoint {
    type Target = Endpoint;
    fn deref(&self) -> &Self::Target {
        &self.endpoint
    }
}
impl ::std::ops::DerefMut for TestEndpoint {
fn deref_mut(&mut self) -> &mut Endpoint {
&mut self.endpoint
}
}
/// Install a TRACE-level tracing subscriber for the lifetime of the returned guard
pub(super) fn subscribe() -> tracing::subscriber::DefaultGuard {
    let builder = tracing_subscriber::FmtSubscriber::builder()
        .with_max_level(tracing::Level::TRACE)
        .with_writer(|| TestWriter);
    // tracing uses std::time to trace time, which panics in wasm.
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    let builder = builder.without_time();
    tracing::subscriber::set_default(builder.finish())
}
struct TestWriter;
impl Write for TestWriter {
    /// Forward UTF-8 log bytes to stdout; panics on invalid UTF-8 (test-only sink)
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let text = str::from_utf8(buf).expect("tried to log invalid UTF-8");
        print!("{text}");
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        io::stdout().flush()
    }
}
/// Default server config using the shared self-signed certificate
///
/// Without the `bloom` feature (which supplies a default token log), explicitly
/// enable NEW_TOKEN issuance backed by `SimpleTokenLog` so validation tokens are
/// still exercised.
pub(super) fn server_config() -> ServerConfig {
    let mut config = ServerConfig::with_crypto(Arc::new(server_crypto()));
    if !cfg!(feature = "bloom") {
        config
            .validation_token
            .sent(2)
            .log(Arc::new(SimpleTokenLog::default()));
    }
    config
}
/// Server config using the given certificate/key, with NEW_TOKEN issuance
/// backed by `SimpleTokenLog`
pub(super) fn server_config_with_cert(
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
) -> ServerConfig {
    let crypto = Arc::new(server_crypto_with_cert(cert, key));
    let mut config = ServerConfig::with_crypto(crypto);
    let log = Arc::new(SimpleTokenLog::default());
    config.validation_token.sent(2).log(log);
    config
}
/// Server TLS config using the shared self-signed certificate, no ALPN
pub(super) fn server_crypto() -> QuicServerConfig {
    server_crypto_inner(None, None)
}
/// Server TLS config advertising the given ALPN protocols
pub(super) fn server_crypto_with_alpn(alpn: Vec<Vec<u8>>) -> QuicServerConfig {
    server_crypto_inner(None, Some(alpn))
}
/// Server TLS config using the given certificate/key identity, no ALPN
pub(super) fn server_crypto_with_cert(
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
) -> QuicServerConfig {
    server_crypto_inner(Some((cert, key)), None)
}
/// Build a server TLS config from an optional identity (falling back to the
/// shared self-signed certificate) and optional ALPN protocol list
fn server_crypto_inner(
    identity: Option<(CertificateDer<'static>, PrivateKeyDer<'static>)>,
    alpn: Option<Vec<Vec<u8>>>,
) -> QuicServerConfig {
    let (cert, key) = match identity {
        Some(pair) => pair,
        None => (
            CERTIFIED_KEY.cert.der().clone(),
            PrivateKeyDer::Pkcs8(CERTIFIED_KEY.signing_key.serialize_der().into()),
        ),
    };
    let mut config = QuicServerConfig::inner(vec![cert], key).unwrap();
    if let Some(protocols) = alpn {
        config.alpn_protocols = protocols;
    }
    config.try_into().unwrap()
}
/// Default client config trusting the shared self-signed certificate
pub(super) fn client_config() -> ClientConfig {
    ClientConfig::new(Arc::new(client_crypto()))
}
/// Client config with deterministic packet numbers, for reproducible traces
pub(super) fn client_config_with_deterministic_pns() -> ClientConfig {
    let mut transport = TransportConfig::default();
    transport.deterministic_packet_numbers(true);
    let mut cfg = ClientConfig::new(Arc::new(client_crypto()));
    cfg.transport = Arc::new(transport);
    cfg
}
/// Client config trusting exactly the given root certificates
pub(super) fn client_config_with_certs(certs: Vec<CertificateDer<'static>>) -> ClientConfig {
    ClientConfig::new(Arc::new(client_crypto_inner(Some(certs), None)))
}
/// Client TLS config trusting the shared self-signed certificate, no ALPN
pub(super) fn client_crypto() -> QuicClientConfig {
    client_crypto_inner(None, None)
}
/// Client TLS config advertising the given ALPN protocols
pub(super) fn client_crypto_with_alpn(protocols: Vec<Vec<u8>>) -> QuicClientConfig {
    client_crypto_inner(None, Some(protocols))
}
/// Build a client TLS config from optional trusted roots (falling back to the
/// shared self-signed certificate) and optional ALPN protocol list
fn client_crypto_inner(
    certs: Option<Vec<CertificateDer<'static>>>,
    alpn: Option<Vec<Vec<u8>>>,
) -> QuicClientConfig {
    let trusted = certs.unwrap_or_else(|| vec![CERTIFIED_KEY.cert.der().clone()]);
    let mut roots = rustls::RootCertStore::empty();
    for cert in trusted {
        roots.add(cert).unwrap();
    }
    let verifier =
        WebPkiServerVerifier::builder_with_provider(Arc::new(roots), configured_provider())
            .build()
            .unwrap();
    let mut inner = QuicClientConfig::inner(verifier);
    // Log TLS secrets (to SSLKEYLOGFILE) so captures can be decrypted
    inner.key_log = Arc::new(KeyLogFile::new());
    if let Some(protocols) = alpn {
        inner.alpn_protocols = protocols;
    }
    inner.try_into().unwrap()
}
/// Minimum of two optional values; `None` operands are ignored
pub(super) fn min_opt<T: Ord>(x: Option<T>, y: Option<T>) -> Option<T> {
    match (x, y) {
        (Some(a), Some(b)) => Some(cmp::min(a, b)),
        (a, b) => a.or(b),
    }
}
/// The maximum of datagrams TestEndpoint will produce via `poll_transmit`
const MAX_DATAGRAMS: usize = 10;
/// Split a GSO transmit into individual datagrams of at most `segment_size` bytes
///
/// A transmit without a segment size is returned unchanged as a single datagram.
fn split_transmit(transmit: Transmit, buffer: &[u8]) -> Vec<(Transmit, Bytes)> {
    let mut buffer = Bytes::copy_from_slice(buffer);
    let segment_size = match transmit.segment_size {
        Some(segment_size) => segment_size,
        _ => return vec![(transmit, buffer)],
    };
    let mut transmits = Vec::new();
    while !buffer.is_empty() {
        // The final segment may be shorter than `segment_size`
        let end = segment_size.min(buffer.len());
        let contents = buffer.split_to(end);
        transmits.push((
            Transmit {
                destination: transmit.destination,
                size: contents.len(),
                ecn: transmit.ecn,
                segment_size: None,
                src_ip: transmit.src_ip,
            },
            contents,
        ));
    }
    transmits
}
/// Payload size of a single-datagram transmit
///
/// Panics when the transmit carries a segment size, i.e. must be split first.
fn packet_size(transmit: &Transmit, buffer: &Bytes) -> usize {
    assert!(
        transmit.segment_size.is_none(),
        "This transmit is meant to be split into multiple packets!"
    );
    buffer.len()
}
/// Rewrite an ECN codepoint to CE when simulating congestion; `None` stays `None`
fn set_congestion_experienced(
    x: Option<EcnCodepoint>,
    congestion_experienced: bool,
) -> Option<EcnCodepoint> {
    x.map(|codepoint| {
        if congestion_experienced {
            EcnCodepoint::Ce
        } else {
            codepoint
        }
    })
}
lazy_static! {
    // Port allocators so concurrently-running tests don't collide on addresses
    pub static ref SERVER_PORTS: Mutex<RangeFrom<u16>> = Mutex::new(4433..);
    pub static ref CLIENT_PORTS: Mutex<RangeFrom<u16>> = Mutex::new(44433..);
    // Self-signed "localhost" certificate/key pair shared by all tests
    pub(crate) static ref CERTIFIED_KEY: rcgen::CertifiedKey<rcgen::KeyPair> =
        rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
}
// Exact-match TokenLog backed by a HashSet of seen nonces (no false results)
#[derive(Default)]
struct SimpleTokenLog(Mutex<HashSet<u128>>);
impl TokenLog for SimpleTokenLog {
    /// Accept each nonce exactly once, ignoring issue time and lifetime
    fn check_and_insert(
        &self,
        nonce: u128,
        _issued: SystemTime,
        _lifetime: Duration,
    ) -> Result<(), TokenReuseError> {
        // `insert` returns true iff the nonce was not previously present
        match self.0.lock().unwrap().insert(nonce) {
            true => Ok(()),
            false => Err(TokenReuseError),
        }
    }
}

507
vendor/quinn-proto/src/token.rs vendored Normal file
View File

@@ -0,0 +1,507 @@
use std::{
fmt,
mem::size_of,
net::{IpAddr, SocketAddr},
};
use bytes::{Buf, BufMut, Bytes};
use rand::Rng;
use crate::{
Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
coding::{BufExt, BufMutExt},
crypto::{HandshakeTokenKey, HmacKey},
packet::InitialHeader,
shared::ConnectionId,
};
/// Responsible for limiting clients' ability to reuse validation tokens
///
/// [_RFC 9000 § 8.1.4:_](https://www.rfc-editor.org/rfc/rfc9000.html#section-8.1.4)
///
/// > Attackers could replay tokens to use servers as amplifiers in DDoS attacks. To protect
/// > against such attacks, servers MUST ensure that replay of tokens is prevented or limited.
/// > Servers SHOULD ensure that tokens sent in Retry packets are only accepted for a short time,
/// > as they are returned immediately by clients. Tokens that are provided in NEW_TOKEN frames
/// > (Section 19.7) need to be valid for longer but SHOULD NOT be accepted multiple times.
/// > Servers are encouraged to allow tokens to be used only once, if possible; tokens MAY include
/// > additional information about clients to further narrow applicability or reuse.
///
/// `TokenLog` pertains only to tokens provided in NEW_TOKEN frames.
///
/// See [`NoneTokenLog`] for a null implementation that never accepts tokens.
pub trait TokenLog: Send + Sync {
    /// Record that the token was used and, ideally, return a token reuse error if the token may
    /// have been already used previously
    ///
    /// False negatives and false positives are both permissible. Called when a client uses an
    /// address validation token.
    ///
    /// Parameters:
    /// - `nonce`: A server-generated random unique value for the token.
    /// - `issued`: The time the server issued the token.
    /// - `lifetime`: The expiration time of address validation tokens sent via NEW_TOKEN frames,
    ///   as configured by [`ServerValidationTokenConfig::lifetime`][1].
    ///
    /// [1]: crate::ValidationTokenConfig::lifetime
    ///
    /// ## Security & Performance
    ///
    /// To the extent that it is possible to repeatedly trigger false negatives (returning `Ok` for
    /// a token which has been reused), an attacker could use the server to perform [amplification
    /// attacks][2]. The QUIC specification requires that this be limited, if not prevented fully.
    ///
    /// A false positive (returning `Err` for a token which has never been used) is not a security
    /// vulnerability; it is permissible for a `TokenLog` to always return `Err`. A false positive
    /// causes the token to be ignored, which may cause the transmission of some 0.5-RTT data to be
    /// delayed until the handshake completes, if a sufficient amount of 0.5-RTT data is sent.
    ///
    /// [2]: https://en.wikipedia.org/wiki/Denial-of-service_attack#Amplification
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}
/// Error for when a validation token may have been reused
pub struct TokenReuseError;
/// Null implementation of [`TokenLog`], which never accepts tokens
pub struct NoneTokenLog;
impl TokenLog for NoneTokenLog {
    // Reporting every token as reused causes all NEW_TOKEN tokens to be ignored
    fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
        Err(TokenReuseError)
    }
}
/// Responsible for storing validation tokens received from servers and retrieving them for use in
/// subsequent connections
pub trait TokenStore: Send + Sync {
    /// Potentially store a token for later one-time use
    ///
    /// Called when a NEW_TOKEN frame is received from the server.
    fn insert(&self, server_name: &str, token: Bytes);
    /// Try to find and take a token that was stored with the given server name
    ///
    /// The same token must never be returned from `take` twice, as doing so can be used to
    /// de-anonymize a client's traffic.
    ///
    /// Called when trying to connect to a server. It is always ok for this to return `None`.
    fn take(&self, server_name: &str) -> Option<Bytes>;
}
/// Null implementation of [`TokenStore`], which does not store any tokens
pub struct NoneTokenStore;
impl TokenStore for NoneTokenStore {
    // Tokens are silently dropped and lookups always miss
    fn insert(&self, _: &str, _: Bytes) {}
    fn take(&self, _: &str) -> Option<Bytes> {
        None
    }
}
/// State in an `Incoming` determined by a token or lack thereof
#[derive(Debug)]
pub(crate) struct IncomingToken {
    // Set only when the token round-tripped through a Retry packet we sent
    pub(crate) retry_src_cid: Option<ConnectionId>,
    // Destination CID from the client's very first Initial packet
    pub(crate) orig_dst_cid: ConnectionId,
    // Whether the client's address is considered validated
    pub(crate) validated: bool,
}
impl IncomingToken {
    /// Construct for an `Incoming` given the first packet header, or error if the connection
    /// cannot be established
    pub(crate) fn from_header(
        header: &InitialHeader,
        server_config: &ServerConfig,
        remote_address: SocketAddr,
    ) -> Result<Self, InvalidRetryTokenError> {
        let unvalidated = Self {
            retry_src_cid: None,
            orig_dst_cid: header.dst_cid,
            validated: false,
        };
        // Decode token or short-circuit
        if header.token.is_empty() {
            return Ok(unvalidated);
        }
        // In cases where a token cannot be decrypted/decoded, we must allow for the possibility
        // that this is caused not by client malfeasance, but by the token having been generated by
        // an incompatible endpoint, e.g. a different version or a neighbor behind the same load
        // balancer. In such cases we proceed as if there was no token.
        //
        // [_RFC 9000 § 8.1.3:_](https://www.rfc-editor.org/rfc/rfc9000.html#section-8.1.3-10)
        //
        // > If the token is invalid, then the server SHOULD proceed as if the client did not have
        // > a validated address, including potentially sending a Retry packet.
        let Some(retry) = Token::decode(&*server_config.token_key, &header.token) else {
            return Ok(unvalidated);
        };
        // Validate token, then convert into Self
        match retry.payload {
            TokenPayload::Retry {
                address,
                orig_dst_cid,
                issued,
            } => {
                // A decodable Retry token with a wrong address or past its lifetime is
                // unambiguously invalid, so the connection is refused outright
                if address != remote_address {
                    return Err(InvalidRetryTokenError);
                }
                if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
                    return Err(InvalidRetryTokenError);
                }
                Ok(Self {
                    retry_src_cid: Some(header.dst_cid),
                    orig_dst_cid,
                    validated: true,
                })
            }
            TokenPayload::Validation { ip, issued } => {
                // NEW_TOKEN tokens are best-effort: on IP mismatch, expiry, or
                // suspected reuse the client is simply treated as unvalidated
                if ip != remote_address.ip() {
                    return Ok(unvalidated);
                }
                if issued + server_config.validation_token.lifetime
                    < server_config.time_source.now()
                {
                    return Ok(unvalidated);
                }
                if server_config
                    .validation_token
                    .log
                    .check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
                    .is_err()
                {
                    return Ok(unvalidated);
                }
                Ok(Self {
                    retry_src_cid: None,
                    orig_dst_cid: header.dst_cid,
                    validated: true,
                })
            }
        }
    }
}
/// Error for a token being unambiguously from a Retry packet, and not valid
///
/// The connection cannot be established.
pub(crate) struct InvalidRetryTokenError;
/// Retry or validation token
pub(crate) struct Token {
    /// Content that is encrypted from the client
    pub(crate) payload: TokenPayload,
    /// Randomly generated value, which must be unique, and is visible to the client
    nonce: u128,
}
impl Token {
    /// Construct with newly sampled randomness
    pub(crate) fn new(payload: TokenPayload, rng: &mut impl Rng) -> Self {
        Self {
            nonce: rng.random(),
            payload,
        }
    }
    /// Encode and encrypt
    ///
    /// Wire layout: AEAD-sealed payload followed by the 16-byte little-endian nonce
    /// in the clear; the nonce is the HKDF input that derives the per-token AEAD key.
    pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
        let mut buf = Vec::new();
        // Encode payload
        match self.payload {
            TokenPayload::Retry {
                address,
                orig_dst_cid,
                issued,
            } => {
                buf.put_u8(TokenType::Retry as u8);
                encode_addr(&mut buf, address);
                orig_dst_cid.encode_long(&mut buf);
                encode_unix_secs(&mut buf, issued);
            }
            TokenPayload::Validation { ip, issued } => {
                buf.put_u8(TokenType::Validation as u8);
                encode_ip(&mut buf, ip);
                encode_unix_secs(&mut buf, issued);
            }
        }
        // Encrypt
        let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
        aead_key.seal(&mut buf, &[]).unwrap();
        buf.extend(&self.nonce.to_le_bytes());
        buf
    }
    /// Decode and decrypt
    ///
    /// Returns `None` on any truncation, decryption, or decoding failure; callers
    /// treat that as "no token" rather than an error (see `IncomingToken::from_header`).
    fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
        // Decrypt
        // MSRV: split_at_checked requires 1.80.0
        let nonce_slice_start = raw_token_bytes.len().checked_sub(size_of::<u128>())?;
        let (sealed_token, nonce_bytes) = raw_token_bytes.split_at(nonce_slice_start);
        let nonce = u128::from_le_bytes(nonce_bytes.try_into().unwrap());
        let aead_key = key.aead_from_hkdf(nonce_bytes);
        let mut sealed_token = sealed_token.to_vec();
        let data = aead_key.open(&mut sealed_token, &[]).ok()?;
        // Decode payload
        let mut reader = &data[..];
        let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
            TokenType::Retry => TokenPayload::Retry {
                address: decode_addr(&mut reader)?,
                orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
                issued: decode_unix_secs(&mut reader)?,
            },
            TokenType::Validation => TokenPayload::Validation {
                ip: decode_ip(&mut reader)?,
                issued: decode_unix_secs(&mut reader)?,
            },
        };
        if !reader.is_empty() {
            // Consider extra bytes a decoding error (it may be from an incompatible endpoint)
            return None;
        }
        Some(Self { nonce, payload })
    }
}
/// Content of a [`Token`] that is encrypted from the client
pub(crate) enum TokenPayload {
    /// Token originating from a Retry packet
    Retry {
        /// The client's address
        address: SocketAddr,
        /// The destination connection ID set in the very first packet from the client
        orig_dst_cid: ConnectionId,
        /// The time at which this token was issued
        issued: SystemTime,
    },
    /// Token originating from a NEW_TOKEN frame
    Validation {
        /// The client's IP address (its port is likely to change between sessions)
        ip: IpAddr,
        /// The time at which this token was issued
        issued: SystemTime,
    },
}
/// Variant tag for a [`TokenPayload`]
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Parse a variant from its wire byte; `None` for unknown values
    fn from_byte(n: u8) -> Option<Self> {
        match n {
            0 => Some(Self::Retry),
            1 => Some(Self::Validation),
            _ => None,
        }
    }
}
/// Append an IP address followed by a big-endian port in the token wire format
fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
    encode_ip(buf, address.ip());
    buf.put_u16(address.port());
}
/// Decode an IP address and port; `None` on truncated input
fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
    let ip = decode_ip(buf)?;
    let port = buf.get().ok()?;
    Some(SocketAddr::new(ip, port))
}
/// Append an IP address: tag byte `0` (IPv4) or `1` (IPv6), then the raw octets
fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
    match ip {
        IpAddr::V4(addr) => {
            buf.push(0);
            buf.extend_from_slice(&addr.octets());
        }
        IpAddr::V6(addr) => {
            buf.push(1);
            buf.extend_from_slice(&addr.octets());
        }
    }
}
/// Decode an IP address from its tag byte + octets; `None` on truncation or unknown tag
fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
    match buf.get::<u8>().ok()? {
        0 => buf.get().ok().map(IpAddr::V4),
        1 => buf.get().ok().map(IpAddr::V6),
        _ => None,
    }
}
/// Append a timestamp as whole seconds since the Unix epoch (u64)
///
/// Times before the epoch saturate to zero; sub-second precision is dropped.
fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
    buf.write::<u64>(
        time.duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
    );
}
/// Decode a timestamp written by `encode_unix_secs`
fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
    Some(UNIX_EPOCH + Duration::from_secs(buf.get::<u64>().ok()?))
}
/// Stateless reset token
///
/// Used for an endpoint to securely communicate that it has lost state for a connection.
#[allow(clippy::derived_hash_with_manual_eq)] // Custom PartialEq impl matches derived semantics
#[derive(Debug, Copy, Clone, Hash)]
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
impl ResetToken {
    /// Derive a token for `id` by truncating an HMAC over the connection ID
    pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
        let mut signature = vec![0; key.signature_len()];
        key.sign(&id, &mut signature);
        // TODO: Server ID??
        let mut result = [0; RESET_TOKEN_SIZE];
        result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
        result.into()
    }
}
// Constant-time comparison so token checks don't leak match length via timing
impl PartialEq for ResetToken {
    fn eq(&self, other: &Self) -> bool {
        crate::constant_time::eq(&self.0, &other.0)
    }
}
impl Eq for ResetToken {}
impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
    fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
        Self(x)
    }
}
impl std::ops::Deref for ResetToken {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &self.0
    }
}
// Lowercase-hex rendering, e.g. for logs
impl fmt::Display for ResetToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for byte in self.iter() {
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use super::*;
    #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
    use aws_lc_rs::hkdf;
    use rand::prelude::*;
    #[cfg(feature = "ring")]
    use ring::hkdf;
    // Seal a payload with a freshly derived key, reopen it, and return the
    // decoded payload for the caller to compare field-by-field
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let rng = &mut rand::rng();
        let token = Token::new(payload, rng);
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
        let encoded = token.encode(&prk);
        let decoded = Token::decode(&prk, &encoded).expect("token didn't decrypt / decode");
        assert_eq!(token.nonce, decoded.nonce);
        decoded.payload
    }
    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
        use crate::{Duration, UNIX_EPOCH};
        use std::net::Ipv6Addr;
        let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42); // Fractional seconds would be lost
        let payload_1 = TokenPayload::Retry {
            address: address_1,
            orig_dst_cid: orig_dst_cid_1,
            issued: issued_1,
        };
        let TokenPayload::Retry {
            address: address_2,
            orig_dst_cid: orig_dst_cid_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };
        assert_eq!(address_1, address_2);
        assert_eq!(orig_dst_cid_1, orig_dst_cid_2);
        assert_eq!(issued_1, issued_2);
    }
    #[test]
    fn validation_token_sanity() {
        use crate::{Duration, UNIX_EPOCH};
        use std::net::Ipv6Addr;
        let ip_1 = Ipv6Addr::LOCALHOST.into();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42); // Fractional seconds would be lost
        let payload_1 = TokenPayload::Validation {
            ip: ip_1,
            issued: issued_1,
        };
        let TokenPayload::Validation {
            ip: ip_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };
        assert_eq!(ip_1, ip_2);
        assert_eq!(issued_1, issued_2);
    }
    #[test]
    fn invalid_token_returns_err() {
        use super::*;
        use rand::RngCore;
        let rng = &mut rand::rng();
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
        let mut invalid_token = Vec::new();
        let mut random_data = [0; 32];
        rand::rng().fill_bytes(&mut random_data);
        invalid_token.put_slice(&random_data);
        // Assert: garbage sealed data returns err
        assert!(Token::decode(&prk, &invalid_token).is_none());
    }
}

View File

@@ -0,0 +1,246 @@
//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections
use std::{
collections::{HashMap, VecDeque, hash_map},
sync::{Arc, Mutex},
};
use bytes::Bytes;
use lru_slab::LruSlab;
use tracing::trace;
use crate::token::TokenStore;
/// `TokenStore` implementation that stores up to `N` tokens per server name for up to a
/// limited number of server names, in-memory
#[derive(Debug)]
pub struct TokenMemoryCache(Mutex<State>);
impl TokenMemoryCache {
    /// Construct empty
    pub fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
        Self(Mutex::new(State::new(
            max_server_names,
            max_tokens_per_server,
        )))
    }
}
impl TokenStore for TokenMemoryCache {
    // Locking wrapper around `State::store`, with trace logging
    fn insert(&self, server_name: &str, token: Bytes) {
        trace!(%server_name, "storing token");
        self.0.lock().unwrap().store(server_name, token)
    }
    // Locking wrapper around `State::take`, with trace logging of hit/miss
    fn take(&self, server_name: &str) -> Option<Bytes> {
        let token = self.0.lock().unwrap().take(server_name);
        trace!(%server_name, found=%token.is_some(), "taking token");
        token
    }
}
/// Defaults to a maximum of 256 servers and 2 tokens per server
impl Default for TokenMemoryCache {
    fn default() -> Self {
        const MAX_SERVER_NAMES: u32 = 256;
        const MAX_TOKENS_PER_SERVER: usize = 2;
        Self::new(MAX_SERVER_NAMES, MAX_TOKENS_PER_SERVER)
    }
}
/// Lockable inner state of `TokenMemoryCache`
#[derive(Debug)]
struct State {
    max_server_names: u32,
    max_tokens_per_server: usize,
    // map from server name to index in lru
    lookup: HashMap<Arc<str>, u32>,
    // least-recently-stored-to server names are evicted first
    lru: LruSlab<CacheEntry>,
}
impl State {
    /// Construct empty with the given capacity limits
    fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
        Self {
            max_server_names,
            max_tokens_per_server,
            lookup: HashMap::new(),
            lru: LruSlab::default(),
        }
    }
    /// Append `token` to `server_name`'s FIFO queue, evicting the oldest token in
    /// the queue or the least-recently-used server name when a limit is reached
    fn store(&mut self, server_name: &str, token: Bytes) {
        if self.max_server_names == 0 {
            // the rest of this method assumes that we can always insert a new entry so long as
            // we're willing to evict a pre-existing entry. thus, an entry limit of 0 is an edge
            // case we must short-circuit on now.
            return;
        }
        if self.max_tokens_per_server == 0 {
            // similarly to above, the rest of this method assumes that we can always push a new
            // token to a queue so long as we're willing to evict a pre-existing token, so we
            // short-circuit on the edge case of a token limit of 0.
            return;
        }
        let server_name = Arc::<str>::from(server_name);
        match self.lookup.entry(server_name.clone()) {
            hash_map::Entry::Occupied(hmap_entry) => {
                // key already exists, push the new token to its token queue
                let tokens = &mut self.lru.get_mut(*hmap_entry.get()).tokens;
                if tokens.len() >= self.max_tokens_per_server {
                    debug_assert!(tokens.len() == self.max_tokens_per_server);
                    tokens.pop_front().unwrap();
                }
                tokens.push_back(token);
            }
            hash_map::Entry::Vacant(hmap_entry) => {
                // key does not yet exist, create a new one, evicting the oldest if necessary
                let removed_key = if self.lru.len() >= self.max_server_names {
                    // unwrap safety: max_server_names is > 0, so there's at least one entry, so
                    // lru() is some
                    Some(self.lru.remove(self.lru.lru().unwrap()).server_name)
                } else {
                    None
                };
                hmap_entry.insert(self.lru.insert(CacheEntry::new(server_name, token)));
                // for borrowing reasons, we must defer removing the evicted hmap entry to here
                if let Some(removed_slot) = removed_key {
                    let removed = self.lookup.remove(&removed_slot);
                    debug_assert!(removed.is_some());
                }
            }
        };
    }
    /// Pop the oldest stored token for `server_name`, removing the whole entry
    /// when its queue empties (preserving the non-empty-queue invariant)
    fn take(&mut self, server_name: &str) -> Option<Bytes> {
        let slab_key = *self.lookup.get(server_name)?;
        // pop from entry's token queue
        let entry = self.lru.get_mut(slab_key);
        // unwrap safety: we never leave tokens empty
        let token = entry.tokens.pop_front().unwrap();
        if entry.tokens.is_empty() {
            // token stack emptied, remove entry
            self.lru.remove(slab_key);
            self.lookup.remove(server_name);
        }
        Some(token)
    }
}
/// Cache entry within `TokenMemoryCache`'s LRU slab
#[derive(Debug)]
struct CacheEntry {
    server_name: Arc<str>,
    // invariant: tokens is never empty
    tokens: VecDeque<Bytes>,
}

impl CacheEntry {
    /// Construct with a single token
    fn new(server_name: Arc<str>, token: Bytes) -> Self {
        Self {
            server_name,
            tokens: VecDeque::from([token]),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;
    // Seeded RNG so the randomized test below is reproducible
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes())
    }
    // Randomized differential test: the cache must agree with a naive
    // Vec-of-queues model under interleaved stores and takes
    #[test]
    fn cache_test() {
        let mut rng = new_rng();
        const N: usize = 2;
        for _ in 0..10 {
            let mut cache_1: Vec<(u32, VecDeque<Bytes>)> = Vec::new(); // keep it sorted oldest to newest
            let cache_2 = TokenMemoryCache::new(20, 2);
            for i in 0..200 {
                let server_name = rng.random::<u32>() % 10;
                if rng.random_bool(0.666) {
                    // store
                    let token = Bytes::from(vec![i]);
                    println!("STORE {server_name} {token:?}");
                    if let Some((j, _)) = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                    {
                        let (_, mut queue) = cache_1.remove(j);
                        queue.push_back(token.clone());
                        if queue.len() > N {
                            queue.pop_front();
                        }
                        cache_1.push((server_name, queue));
                    } else {
                        let mut queue = VecDeque::new();
                        queue.push_back(token.clone());
                        cache_1.push((server_name, queue));
                        if cache_1.len() > 20 {
                            cache_1.remove(0);
                        }
                    }
                    cache_2.insert(&server_name.to_string(), token);
                } else {
                    // take
                    println!("TAKE {server_name}");
                    let expecting = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                        .map(|(j, _)| j)
                        .map(|j| {
                            let (_, mut queue) = cache_1.remove(j);
                            let token = queue.pop_front().unwrap();
                            if !queue.is_empty() {
                                cache_1.push((server_name, queue));
                            }
                            token
                        });
                    println!("EXPECTING {expecting:?}");
                    assert_eq!(cache_2.take(&server_name.to_string()), expecting);
                }
            }
        }
    }
    #[test]
    fn zero_max_server_names() {
        // test that this edge case doesn't panic
        let cache = TokenMemoryCache::new(0, 2);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
    #[test]
    fn zero_queue_length() {
        // test that this edge case doesn't panic
        let cache = TokenMemoryCache::new(256, 0);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
}

View File

@@ -0,0 +1,132 @@
use std::fmt;
use bytes::{Buf, BufMut};
use crate::{
coding::{self, BufExt, BufMutExt},
frame,
};
/// Transport-level errors occur when a peer violates the protocol specification
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Error {
    /// Type of error
    pub code: Code,
    /// Frame type that triggered the error (included in the `Display` output)
    pub frame: Option<frame::FrameType>,
    /// Human-readable explanation of the reason (appended to the `Display` output)
    pub reason: String,
}
impl fmt::Display for Error {
    /// Render as "<code> in <frame>: <reason>", omitting absent parts
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.code)?;
        if let Some(frame) = self.frame {
            write!(f, " in {frame}")?;
        }
        match self.reason.is_empty() {
            true => Ok(()),
            false => write!(f, ": {}", self.reason),
        }
    }
}
impl std::error::Error for Error {}
impl From<Code> for Error {
    /// Wrap a bare error code with no triggering frame and an empty reason
    fn from(x: Code) -> Self {
        Self {
            code: x,
            frame: None,
            // `String::new()` is the idiomatic empty string (no allocation),
            // rather than `"".to_string()`
            reason: String::new(),
        }
    }
}
/// Transport-level error code
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct Code(u64);

impl Code {
    /// Create QUIC error code from TLS alert code
    ///
    /// Crypto error codes occupy the 0x100-0x1ff range.
    pub fn crypto(code: u8) -> Self {
        Self(u64::from(code) | 0x100)
    }
}
impl coding::Codec for Code {
    // Decode from a QUIC variable-length integer
    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
        Ok(Self(buf.get_var()?))
    }
    // Encode as a QUIC variable-length integer
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.write_var(self.0)
    }
}
// Expose the raw numeric error code
impl From<Code> for u64 {
    fn from(x: Code) -> Self {
        x.0
    }
}
// Generates, from a list of `NAME(code) "description";` entries:
// - an `Error::NAME(reason)` constructor per code,
// - a `Code::NAME` associated constant per code,
// - `Debug`/`Display` impls for `Code` covering the listed codes, the crypto
//   range 0x100..0x200, and a fallback for unknown codes.
macro_rules! errors {
    {$($name:ident($val:expr) $desc:expr;)*} => {
        #[allow(non_snake_case, unused)]
        impl Error {
            $(
                pub(crate) fn $name<T>(reason: T) -> Self where T: Into<String> {
                    Self {
                        code: Code::$name,
                        frame: None,
                        reason: reason.into(),
                    }
                }
            )*
        }
        impl Code {
            $(#[doc = $desc] pub const $name: Self = Code($val);)*
        }
        impl fmt::Debug for Code {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str(stringify!($name)),)*
                    x if (0x100..0x200).contains(&x) => write!(f, "Code::crypto({:02x})", self.0 as u8),
                    _ => write!(f, "Code({:x})", self.0),
                }
            }
        }
        impl fmt::Display for Code {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str($desc),)*
                    // We're trying to be abstract over the crypto protocol, so human-readable descriptions here is tricky.
                    _ if self.0 >= 0x100 && self.0 < 0x200 => write!(f, "the cryptographic handshake failed: error {}", self.0 & 0xFF),
                    _ => f.write_str("unknown error"),
                }
            }
        }
    }
}
// Standard QUIC transport error code space (codes 0x0-0x10)
errors! {
    NO_ERROR(0x0) "the connection is being closed abruptly in the absence of any error";
    INTERNAL_ERROR(0x1) "the endpoint encountered an internal error and cannot continue with the connection";
    CONNECTION_REFUSED(0x2) "the server refused to accept a new connection";
    FLOW_CONTROL_ERROR(0x3) "received more data than permitted in advertised data limits";
    STREAM_LIMIT_ERROR(0x4) "received a frame for a stream identifier that exceeded advertised the stream limit for the corresponding stream type";
    STREAM_STATE_ERROR(0x5) "received a frame for a stream that was not in a state that permitted that frame";
    FINAL_SIZE_ERROR(0x6) "received a STREAM frame or a RESET_STREAM frame containing a different final size to the one already established";
    FRAME_ENCODING_ERROR(0x7) "received a frame that was badly formatted";
    TRANSPORT_PARAMETER_ERROR(0x8) "received transport parameters that were badly formatted, included an invalid value, was absent even though it is mandatory, was present though it is forbidden, or is otherwise in error";
    CONNECTION_ID_LIMIT_ERROR(0x9) "the number of connection IDs provided by the peer exceeds the advertised active_connection_id_limit";
    PROTOCOL_VIOLATION(0xA) "detected an error with protocol compliance that was not covered by more specific error codes";
    INVALID_TOKEN(0xB) "received an invalid Retry Token in a client Initial";
    APPLICATION_ERROR(0xC) "the application or application protocol caused the connection to be closed during the handshake";
    CRYPTO_BUFFER_EXCEEDED(0xD) "received more data in CRYPTO frames than can be buffered";
    KEY_UPDATE_ERROR(0xE) "key update error";
    AEAD_LIMIT_REACHED(0xF) "the endpoint has reached the confidentiality or integrity limit for the AEAD algorithm";
    NO_VIABLE_PATH(0x10) "no viable network path exists";
}

View File

@@ -0,0 +1,874 @@
//! QUIC connection transport parameters
//!
//! The `TransportParameters` type is used to represent the transport parameters
//! negotiated by peers while establishing a QUIC connection. This process
//! happens as part of the establishment of the TLS session. As such, the types
//! contained in this modules should generally only be referred to by custom
//! implementations of the `crypto::Session` trait.
use std::{
convert::TryFrom,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
};
use bytes::{Buf, BufMut};
use rand::{Rng as _, RngCore, seq::SliceRandom as _};
use thiserror::Error;
use crate::{
LOC_CID_COUNT, MAX_CID_SIZE, MAX_STREAM_COUNT, RESET_TOKEN_SIZE, ResetToken, Side,
TIMER_GRANULARITY, TransportError, VarInt,
cid_generator::ConnectionIdGenerator,
cid_queue::CidQueue,
coding::{BufExt, BufMutExt, UnexpectedEnd},
config::{EndpointConfig, ServerConfig, TransportConfig},
shared::ConnectionId,
};
// Apply a given macro to a list of all the transport parameters having integer types, along with
// their codes and default values. Using this helps us avoid error-prone duplication of the
// contained information across decoding, encoding, and the `Default` impl. Whenever we want to do
// something with transport parameters, we'll handle the bulk of cases by writing a macro that
// takes a list of arguments in this form, then passing it to this macro.
//
// Each entry has the shape `#[doc] name(IdVariant) = default,` where `IdVariant` is a
// `TransportParameterId` variant and `default` is the value assumed when the peer omits
// the parameter.
macro_rules! apply_params {
    ($macro:ident) => {
        $macro! {
            // #[doc] name (id) = default,
            /// Milliseconds, disabled if zero
            max_idle_timeout(MaxIdleTimeout) = 0,
            /// Limits the size of UDP payloads that the endpoint is willing to receive
            max_udp_payload_size(MaxUdpPayloadSize) = 65527,
            /// Initial value for the maximum amount of data that can be sent on the connection
            initial_max_data(InitialMaxData) = 0,
            /// Initial flow control limit for locally-initiated bidirectional streams
            initial_max_stream_data_bidi_local(InitialMaxStreamDataBidiLocal) = 0,
            /// Initial flow control limit for peer-initiated bidirectional streams
            initial_max_stream_data_bidi_remote(InitialMaxStreamDataBidiRemote) = 0,
            /// Initial flow control limit for unidirectional streams
            initial_max_stream_data_uni(InitialMaxStreamDataUni) = 0,
            /// Initial maximum number of bidirectional streams the peer may initiate
            initial_max_streams_bidi(InitialMaxStreamsBidi) = 0,
            /// Initial maximum number of unidirectional streams the peer may initiate
            initial_max_streams_uni(InitialMaxStreamsUni) = 0,
            /// Exponent used to decode the ACK Delay field in the ACK frame
            ack_delay_exponent(AckDelayExponent) = 3,
            /// Maximum amount of time in milliseconds by which the endpoint will delay sending
            /// acknowledgments
            max_ack_delay(MaxAckDelay) = 25,
            /// Maximum number of connection IDs from the peer that an endpoint is willing to store
            active_connection_id_limit(ActiveConnectionIdLimit) = 2,
        }
    };
}
// Consumes the `apply_params!` list to generate the `TransportParameters` struct: one
// `VarInt` field per integer-valued parameter, plus hand-written fields for the
// parameters with non-integer encodings.
macro_rules! make_struct {
    {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
        /// Transport parameters used to negotiate connection-level preferences between peers
        #[derive(Debug, Copy, Clone, Eq, PartialEq)]
        pub struct TransportParameters {
            $($(#[$doc])* pub(crate) $name : VarInt,)*
            /// Does the endpoint support active connection migration
            pub(crate) disable_active_migration: bool,
            /// Maximum size for datagram frames
            pub(crate) max_datagram_frame_size: Option<VarInt>,
            /// The value that the endpoint included in the Source Connection ID field of the first
            /// Initial packet it sends for the connection
            pub(crate) initial_src_cid: Option<ConnectionId>,
            /// The endpoint is willing to receive QUIC packets containing any value for the fixed
            /// bit
            pub(crate) grease_quic_bit: bool,
            /// Minimum amount of time in microseconds by which the endpoint is able to delay
            /// sending acknowledgments
            ///
            /// If a value is provided, it implies that the endpoint supports QUIC Acknowledgement
            /// Frequency
            pub(crate) min_ack_delay: Option<VarInt>,
            // Server-only
            /// The value of the Destination Connection ID field from the first Initial packet sent
            /// by the client
            pub(crate) original_dst_cid: Option<ConnectionId>,
            /// The value that the server included in the Source Connection ID field of a Retry
            /// packet
            pub(crate) retry_src_cid: Option<ConnectionId>,
            /// Token used by the client to verify a stateless reset from the server
            pub(crate) stateless_reset_token: Option<ResetToken>,
            /// The server's preferred address for communication after handshake completion
            pub(crate) preferred_address: Option<PreferredAddress>,
            /// The randomly generated reserved transport parameter to sustain future extensibility
            /// of transport parameter extensions.
            /// When present, it is included during serialization but ignored during deserialization.
            pub(crate) grease_transport_parameter: Option<ReservedTransportParameter>,
            /// Defines the order in which transport parameters are serialized.
            ///
            /// This field is initialized only for outgoing `TransportParameters` instances and
            /// is set to `None` for `TransportParameters` received from a peer.
            pub(crate) write_order: Option<[u8; TransportParameterId::SUPPORTED.len()]>,
        }
        // We deliberately don't implement the `Default` trait, since that would be public, and
        // downstream crates should never construct `TransportParameters` except by decoding those
        // supplied by a peer.
        impl TransportParameters {
            /// Standard defaults, used if the peer does not supply a given parameter.
            pub(crate) fn default() -> Self {
                Self {
                    $($name: VarInt::from_u32($default),)*
                    disable_active_migration: false,
                    max_datagram_frame_size: None,
                    initial_src_cid: None,
                    grease_quic_bit: false,
                    min_ack_delay: None,
                    original_dst_cid: None,
                    retry_src_cid: None,
                    stateless_reset_token: None,
                    preferred_address: None,
                    grease_transport_parameter: None,
                    write_order: None,
                }
            }
        }
    }
}
// Expand the parameter list into the `TransportParameters` struct and its `default()`
apply_params!(make_struct);
impl TransportParameters {
    /// Builds the transport parameters to advertise to a peer, derived from the
    /// local transport, endpoint, and (server-side) server configuration.
    pub(crate) fn new(
        config: &TransportConfig,
        endpoint_config: &EndpointConfig,
        cid_gen: &dyn ConnectionIdGenerator,
        initial_src_cid: ConnectionId,
        server_config: Option<&ServerConfig>,
        rng: &mut impl RngCore,
    ) -> Self {
        // With zero-length local CIDs there is nothing to rotate, so keep the
        // protocol default of 2 (which is never serialized); otherwise advertise
        // capacity for the full local CID queue.
        let cid_limit = if cid_gen.cid_len() == 0 {
            2u32
        } else {
            CidQueue::LEN as u32
        };

        // Draw the GREASE parameter before shuffling so the RNG is consumed in
        // the same order as before this refactoring.
        let grease_transport_parameter = Some(ReservedTransportParameter::random(rng));

        // Randomize serialization order to exercise peers' tolerance of
        // arbitrary parameter ordering.
        let write_order = Some({
            let mut order = std::array::from_fn(|i| i as u8);
            order.shuffle(rng);
            order
        });

        Self {
            initial_src_cid: Some(initial_src_cid),
            initial_max_streams_bidi: config.max_concurrent_bidi_streams,
            initial_max_streams_uni: config.max_concurrent_uni_streams,
            initial_max_data: config.receive_window,
            initial_max_stream_data_bidi_local: config.stream_receive_window,
            initial_max_stream_data_bidi_remote: config.stream_receive_window,
            initial_max_stream_data_uni: config.stream_receive_window,
            max_udp_payload_size: endpoint_config.max_udp_payload_size,
            max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)),
            disable_active_migration: server_config.is_some_and(|c| !c.migration),
            active_connection_id_limit: cid_limit.into(),
            // Clamp the configured buffer size into the u16 wire representation
            max_datagram_frame_size: config
                .datagram_receive_buffer_size
                .map(|x| (x.min(u16::MAX.into()) as u16).into()),
            grease_quic_bit: endpoint_config.grease_quic_bit,
            // TIMER_GRANULARITY is a small constant, so these conversions cannot fail
            min_ack_delay: Some(
                VarInt::from_u64(u64::try_from(TIMER_GRANULARITY.as_micros()).unwrap()).unwrap(),
            ),
            grease_transport_parameter,
            write_order,
            ..Self::default()
        }
    }

    /// Check that these parameters are legal when resuming from
    /// certain cached parameters
    pub(crate) fn validate_resumption_from(&self, cached: &Self) -> Result<(), TransportError> {
        // Every limit the peer previously advertised must still be honored, and
        // permission to grease the fixed bit must not have been revoked.
        let compatible = cached.active_connection_id_limit <= self.active_connection_id_limit
            && cached.initial_max_data <= self.initial_max_data
            && cached.initial_max_stream_data_bidi_local <= self.initial_max_stream_data_bidi_local
            && cached.initial_max_stream_data_bidi_remote
                <= self.initial_max_stream_data_bidi_remote
            && cached.initial_max_stream_data_uni <= self.initial_max_stream_data_uni
            && cached.initial_max_streams_bidi <= self.initial_max_streams_bidi
            && cached.initial_max_streams_uni <= self.initial_max_streams_uni
            && cached.max_datagram_frame_size <= self.max_datagram_frame_size
            && (self.grease_quic_bit || !cached.grease_quic_bit);
        if compatible {
            Ok(())
        } else {
            Err(TransportError::PROTOCOL_VIOLATION(
                "0-RTT accepted with incompatible transport parameters",
            ))
        }
    }

    /// Maximum number of CIDs to issue to this peer
    ///
    /// Consider both a) the active_connection_id_limit from the other end; and
    /// b) LOC_CID_COUNT used locally
    pub(crate) fn issue_cids_limit(&self) -> u64 {
        LOC_CID_COUNT.min(self.active_connection_id_limit.0)
    }
}
/// A server's preferred address
///
/// This is communicated as a transport parameter during TLS session establishment.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) struct PreferredAddress {
    // `None` is encoded on the wire as the unspecified address with port 0
    pub(crate) address_v4: Option<SocketAddrV4>,
    pub(crate) address_v6: Option<SocketAddrV6>,
    // CID the client should use when migrating to the preferred address
    pub(crate) connection_id: ConnectionId,
    pub(crate) stateless_reset_token: ResetToken,
}
impl PreferredAddress {
    /// Encoded length in bytes: IPv4 addr + port, IPv6 addr + port, CID length
    /// prefix, CID bytes, and the 16-byte stateless reset token
    fn wire_size(&self) -> u16 {
        4 + 2 + 16 + 2 + 1 + self.connection_id.len() as u16 + 16
    }
    /// Serialize in wire order; absent addresses become unspecified-address/port-0
    fn write<W: BufMut>(&self, w: &mut W) {
        w.write(self.address_v4.map_or(Ipv4Addr::UNSPECIFIED, |x| *x.ip()));
        w.write::<u16>(self.address_v4.map_or(0, |x| x.port()));
        w.write(self.address_v6.map_or(Ipv6Addr::UNSPECIFIED, |x| *x.ip()));
        w.write::<u16>(self.address_v6.map_or(0, |x| x.port()));
        w.write::<u8>(self.connection_id.len() as u8);
        w.put_slice(&self.connection_id);
        w.put_slice(&self.stateless_reset_token);
    }
    /// Parse the wire form; fields must appear in exactly the order `write` emits them
    fn read<R: Buf>(r: &mut R) -> Result<Self, Error> {
        let ip_v4 = r.get::<Ipv4Addr>()?;
        let port_v4 = r.get::<u16>()?;
        let ip_v6 = r.get::<Ipv6Addr>()?;
        let port_v6 = r.get::<u16>()?;
        let cid_len = r.get::<u8>()?;
        // Bounds-check before copying the CID out of the buffer
        if r.remaining() < cid_len as usize || cid_len > MAX_CID_SIZE as u8 {
            return Err(Error::Malformed);
        }
        let mut stage = [0; MAX_CID_SIZE];
        r.copy_to_slice(&mut stage[0..cid_len as usize]);
        let cid = ConnectionId::new(&stage[0..cid_len as usize]);
        if r.remaining() < 16 {
            return Err(Error::Malformed);
        }
        let mut token = [0; RESET_TOKEN_SIZE];
        r.copy_to_slice(&mut token);
        // Unspecified address + port 0 encodes "no address of this family"
        let address_v4 = if ip_v4.is_unspecified() && port_v4 == 0 {
            None
        } else {
            Some(SocketAddrV4::new(ip_v4, port_v4))
        };
        let address_v6 = if ip_v6.is_unspecified() && port_v6 == 0 {
            None
        } else {
            Some(SocketAddrV6::new(ip_v6, port_v6, 0, 0))
        };
        // At least one address family must be present
        if address_v4.is_none() && address_v6.is_none() {
            return Err(Error::IllegalValue);
        }
        Ok(Self {
            address_v4,
            address_v6,
            connection_id: cid,
            stateless_reset_token: token.into(),
        })
    }
}
/// Errors encountered while decoding `TransportParameters`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
pub enum Error {
    /// Parameters that are semantically invalid
    #[error("parameter had illegal value")]
    IllegalValue,
    /// Catch-all error for problems while decoding transport parameters
    /// (truncation, bad lengths, duplicates)
    #[error("parameters were malformed")]
    Malformed,
}
impl From<Error> for TransportError {
fn from(e: Error) -> Self {
match e {
Error::IllegalValue => Self::TRANSPORT_PARAMETER_ERROR("illegal value"),
Error::Malformed => Self::TRANSPORT_PARAMETER_ERROR("malformed"),
}
}
}
impl From<UnexpectedEnd> for Error {
    /// A truncated buffer is just one form of malformed parameters
    fn from(_: UnexpectedEnd) -> Self {
        Self::Malformed
    }
}
impl TransportParameters {
    /// Encode `TransportParameters` into buffer
    pub fn write<W: BufMut>(&self, w: &mut W) {
        // Emit parameters following `write_order` (randomized for locally-built
        // parameter sets); fall back to the canonical `SUPPORTED` order otherwise.
        for idx in self
            .write_order
            .as_ref()
            .unwrap_or(&std::array::from_fn(|i| i as u8))
        {
            let id = TransportParameterId::SUPPORTED[*idx as usize];
            match id {
                TransportParameterId::ReservedTransportParameter => {
                    // GREASE parameter; written with its own (id, len, payload) framing
                    if let Some(param) = self.grease_transport_parameter {
                        param.write(w);
                    }
                }
                TransportParameterId::StatelessResetToken => {
                    if let Some(ref x) = self.stateless_reset_token {
                        w.write_var(id as u64);
                        w.write_var(16);
                        w.put_slice(x);
                    }
                }
                TransportParameterId::DisableActiveMigration => {
                    // Zero-length parameter: presence alone carries the meaning
                    if self.disable_active_migration {
                        w.write_var(id as u64);
                        w.write_var(0);
                    }
                }
                TransportParameterId::MaxDatagramFrameSize => {
                    if let Some(x) = self.max_datagram_frame_size {
                        w.write_var(id as u64);
                        w.write_var(x.size() as u64);
                        w.write(x);
                    }
                }
                TransportParameterId::PreferredAddress => {
                    if let Some(ref x) = self.preferred_address {
                        w.write_var(id as u64);
                        w.write_var(x.wire_size() as u64);
                        x.write(w);
                    }
                }
                TransportParameterId::OriginalDestinationConnectionId => {
                    if let Some(ref cid) = self.original_dst_cid {
                        w.write_var(id as u64);
                        w.write_var(cid.len() as u64);
                        w.put_slice(cid);
                    }
                }
                TransportParameterId::InitialSourceConnectionId => {
                    if let Some(ref cid) = self.initial_src_cid {
                        w.write_var(id as u64);
                        w.write_var(cid.len() as u64);
                        w.put_slice(cid);
                    }
                }
                TransportParameterId::RetrySourceConnectionId => {
                    if let Some(ref cid) = self.retry_src_cid {
                        w.write_var(id as u64);
                        w.write_var(cid.len() as u64);
                        w.put_slice(cid);
                    }
                }
                TransportParameterId::GreaseQuicBit => {
                    // Zero-length parameter, like DisableActiveMigration
                    if self.grease_quic_bit {
                        w.write_var(id as u64);
                        w.write_var(0);
                    }
                }
                TransportParameterId::MinAckDelayDraft07 => {
                    if let Some(x) = self.min_ack_delay {
                        w.write_var(id as u64);
                        w.write_var(x.size() as u64);
                        w.write(x);
                    }
                }
                id => {
                    // All remaining parameters are plain varint values from the
                    // `apply_params!` list; values still at their default are omitted.
                    macro_rules! write_params {
                        {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
                            match id {
                                $(TransportParameterId::$id => {
                                    if self.$name.0 != $default {
                                        w.write_var(id as u64);
                                        w.write(VarInt::try_from(self.$name.size()).unwrap());
                                        w.write(self.$name);
                                    }
                                })*,
                                _ => {
                                    unimplemented!("Missing implementation of write for transport parameter with code {id:?}");
                                }
                            }
                        }
                    }
                    apply_params!(write_params);
                }
            }
        }
    }
    /// Decode `TransportParameters` from buffer
    pub fn read<R: Buf>(side: Side, r: &mut R) -> Result<Self, Error> {
        // Initialize to protocol-specified defaults
        let mut params = Self::default();
        // State to check for duplicate transport parameters.
        macro_rules! param_state {
            {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {{
                struct ParamState {
                    $($name: bool,)*
                }
                ParamState {
                    $($name: false,)*
                }
            }}
        }
        let mut got = apply_params!(param_state);
        while r.has_remaining() {
            let id = r.get_var()?;
            let len = r.get_var()?;
            // Declared length must fit in the remaining buffer
            if (r.remaining() as u64) < len {
                return Err(Error::Malformed);
            }
            let len = len as usize;
            let Ok(id) = TransportParameterId::try_from(id) else {
                // unknown transport parameters are ignored
                r.advance(len);
                continue;
            };
            match id {
                TransportParameterId::OriginalDestinationConnectionId => {
                    decode_cid(len, &mut params.original_dst_cid, r)?
                }
                TransportParameterId::StatelessResetToken => {
                    if len != 16 || params.stateless_reset_token.is_some() {
                        return Err(Error::Malformed);
                    }
                    let mut tok = [0; RESET_TOKEN_SIZE];
                    r.copy_to_slice(&mut tok);
                    params.stateless_reset_token = Some(tok.into());
                }
                TransportParameterId::DisableActiveMigration => {
                    if len != 0 || params.disable_active_migration {
                        return Err(Error::Malformed);
                    }
                    params.disable_active_migration = true;
                }
                TransportParameterId::PreferredAddress => {
                    if params.preferred_address.is_some() {
                        return Err(Error::Malformed);
                    }
                    // `take(len)` caps the nested read at this parameter's length
                    params.preferred_address = Some(PreferredAddress::read(&mut r.take(len))?);
                }
                TransportParameterId::InitialSourceConnectionId => {
                    decode_cid(len, &mut params.initial_src_cid, r)?
                }
                TransportParameterId::RetrySourceConnectionId => {
                    decode_cid(len, &mut params.retry_src_cid, r)?
                }
                TransportParameterId::MaxDatagramFrameSize => {
                    if len > 8 || params.max_datagram_frame_size.is_some() {
                        return Err(Error::Malformed);
                    }
                    params.max_datagram_frame_size = Some(r.get()?);
                }
                TransportParameterId::GreaseQuicBit => match len {
                    0 => params.grease_quic_bit = true,
                    _ => return Err(Error::Malformed),
                },
                // NOTE(review): unlike the parameters above, this arm performs no
                // duplicate or length-consistency check — confirm this is intentional
                TransportParameterId::MinAckDelayDraft07 => params.min_ack_delay = Some(r.get()?),
                _ => {
                    // Integer-valued parameters: reject duplicates and lengths that
                    // disagree with the varint's self-describing size
                    macro_rules! parse {
                        {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
                            match id {
                                $(TransportParameterId::$id => {
                                    let value = r.get::<VarInt>()?;
                                    if len != value.size() || got.$name { return Err(Error::Malformed); }
                                    params.$name = value.into();
                                    got.$name = true;
                                })*
                                _ => r.advance(len),
                            }
                        }
                    }
                    apply_params!(parse);
                }
            }
        }
        // Semantic validation
        // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.26.1
        if params.ack_delay_exponent.0 > 20
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1
            || params.max_ack_delay.0 >= 1 << 14
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-6.2.1
            || params.active_connection_id_limit.0 < 2
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.10.1
            || params.max_udp_payload_size.0 < 1200
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2
            || params.initial_max_streams_bidi.0 > MAX_STREAM_COUNT
            || params.initial_max_streams_uni.0 > MAX_STREAM_COUNT
            // https://www.ietf.org/archive/id/draft-ietf-quic-ack-frequency-08.html#section-3-4
            || params.min_ack_delay.is_some_and(|min_ack_delay| {
                // min_ack_delay uses microseconds, whereas max_ack_delay uses milliseconds
                min_ack_delay.0 > params.max_ack_delay.0 * 1_000
            })
            // Server-only parameters must not be sent by clients
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-8
            || (side.is_server()
                && (params.original_dst_cid.is_some()
                    || params.preferred_address.is_some()
                    || params.retry_src_cid.is_some()
                    || params.stateless_reset_token.is_some()))
            // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.38.1
            || params
                .preferred_address.is_some_and(|x| x.connection_id.is_empty())
        {
            return Err(Error::IllegalValue);
        }
        Ok(params)
    }
}
/// A reserved transport parameter.
///
/// It has an identifier of the form 31 * N + 27 for the integer value of N.
/// Such identifiers are reserved to exercise the requirement that unknown transport parameters be ignored.
/// The reserved transport parameter has no semantics and can carry arbitrary values.
/// It may be included in transport parameters sent to the peer, and should be ignored when received.
///
/// See spec: <https://www.rfc-editor.org/rfc/rfc9000.html#section-18.1>
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) struct ReservedTransportParameter {
    /// The reserved identifier of the transport parameter
    id: VarInt,
    /// Buffer to store the parameter payload
    payload: [u8; Self::MAX_PAYLOAD_LEN],
    /// The number of bytes to include in the wire format from the `payload` buffer
    payload_len: usize,
}
impl ReservedTransportParameter {
    /// The maximum length of the payload to include as the parameter payload.
    /// This value is not a specification-imposed limit but is chosen to match
    /// the limit used by other implementations of QUIC, e.g., quic-go and quiche.
    const MAX_PAYLOAD_LEN: usize = 16;

    /// Generates a transport parameter with a random payload and a reserved ID.
    ///
    /// The implementation is inspired by quic-go and quiche:
    /// 1. <https://github.com/quic-go/quic-go/blob/3e0a67b2476e1819752f04d75968de042b197b56/internal/wire/transport_parameters.go#L338-L344>
    /// 2. <https://github.com/google/quiche/blob/cb1090b20c40e2f0815107857324e99acf6ec567/quiche/quic/core/crypto/transport_parameters.cc#L843-L860>
    fn random(rng: &mut impl RngCore) -> Self {
        // RNG draws happen in the same order as before: id, length, payload bytes
        let id = Self::generate_reserved_id(rng);
        let payload_len = rng.random_range(0..Self::MAX_PAYLOAD_LEN);
        let mut payload = [0u8; Self::MAX_PAYLOAD_LEN];
        rng.fill_bytes(&mut payload[..payload_len]);
        Self {
            id,
            payload,
            payload_len,
        }
    }

    /// Serializes the parameter as (id, length, payload)
    fn write(&self, w: &mut impl BufMut) {
        w.write_var(self.id.0);
        w.write_var(self.payload_len as u64);
        w.put_slice(&self.payload[..self.payload_len]);
    }

    /// Generates a random reserved identifier of the form `31 * N + 27`, as required by RFC 9000.
    /// Reserved transport parameter identifiers are used to test compliance with the requirement
    /// that unknown transport parameters must be ignored by peers.
    /// See: <https://www.rfc-editor.org/rfc/rfc9000.html#section-18.1> and <https://www.rfc-editor.org/rfc/rfc9000.html#section-22.3>
    fn generate_reserved_id(rng: &mut impl RngCore) -> VarInt {
        // Draw uniformly from [0, 2^62 - 27), then snap onto the 31 * N + 27 grid
        // so the result stays within the varint range.
        let rand = rng.random_range(0u64..(1 << 62) - 27);
        let id = (rand / 31) * 31 + 27;
        debug_assert!(
            id % 31 == 27,
            "generated id does not have the form of 31 * N + 27"
        );
        VarInt::from_u64(id).expect(
            "generated id does fit into range of allowed transport parameter IDs: [0; 2^62)",
        )
    }
}
/// Wire identifiers for the transport parameters this implementation understands
#[repr(u64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TransportParameterId {
    // https://www.rfc-editor.org/rfc/rfc9000.html#iana-tp-table
    OriginalDestinationConnectionId = 0x00,
    MaxIdleTimeout = 0x01,
    StatelessResetToken = 0x02,
    MaxUdpPayloadSize = 0x03,
    InitialMaxData = 0x04,
    InitialMaxStreamDataBidiLocal = 0x05,
    InitialMaxStreamDataBidiRemote = 0x06,
    InitialMaxStreamDataUni = 0x07,
    InitialMaxStreamsBidi = 0x08,
    InitialMaxStreamsUni = 0x09,
    AckDelayExponent = 0x0A,
    MaxAckDelay = 0x0B,
    DisableActiveMigration = 0x0C,
    PreferredAddress = 0x0D,
    ActiveConnectionIdLimit = 0x0E,
    InitialSourceConnectionId = 0x0F,
    RetrySourceConnectionId = 0x10,
    // Smallest possible ID of reserved transport parameter https://datatracker.ietf.org/doc/html/rfc9000#section-22.3
    ReservedTransportParameter = 0x1B,
    // https://www.rfc-editor.org/rfc/rfc9221.html#section-3
    MaxDatagramFrameSize = 0x20,
    // https://www.rfc-editor.org/rfc/rfc9287.html#section-3
    GreaseQuicBit = 0x2AB2,
    // https://datatracker.ietf.org/doc/html/draft-ietf-quic-ack-frequency#section-10.1
    MinAckDelayDraft07 = 0xFF04DE1B,
}
impl TransportParameterId {
    /// Array with all supported transport parameter IDs
    ///
    /// The ordering here is only a fallback; outgoing parameter sets index into
    /// this array via the shuffled `write_order` field.
    const SUPPORTED: [Self; 21] = [
        Self::MaxIdleTimeout,
        Self::MaxUdpPayloadSize,
        Self::InitialMaxData,
        Self::InitialMaxStreamDataBidiLocal,
        Self::InitialMaxStreamDataBidiRemote,
        Self::InitialMaxStreamDataUni,
        Self::InitialMaxStreamsBidi,
        Self::InitialMaxStreamsUni,
        Self::AckDelayExponent,
        Self::MaxAckDelay,
        Self::ActiveConnectionIdLimit,
        Self::ReservedTransportParameter,
        Self::StatelessResetToken,
        Self::DisableActiveMigration,
        Self::MaxDatagramFrameSize,
        Self::PreferredAddress,
        Self::OriginalDestinationConnectionId,
        Self::InitialSourceConnectionId,
        Self::RetrySourceConnectionId,
        Self::GreaseQuicBit,
        Self::MinAckDelayDraft07,
    ];
}
impl std::cmp::PartialEq<u64> for TransportParameterId {
    /// Compare a parameter ID against its raw wire value
    fn eq(&self, other: &u64) -> bool {
        (*self as u64) == *other
    }
}
impl TryFrom<u64> for TransportParameterId {
    type Error = ();

    /// Map a raw wire value onto a known parameter ID; `SUPPORTED` enumerates
    /// every variant, so a linear scan over it is exhaustive.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        Self::SUPPORTED
            .iter()
            .find(|&&id| id == value)
            .copied()
            .ok_or(())
    }
}
/// Read a `len`-byte connection ID parameter into `value`
///
/// Rejects duplicate occurrences, CIDs longer than `MAX_CID_SIZE`, and
/// truncated input.
fn decode_cid(len: usize, value: &mut Option<ConnectionId>, r: &mut impl Buf) -> Result<(), Error> {
    let acceptable = value.is_none() && len <= MAX_CID_SIZE && r.remaining() >= len;
    if !acceptable {
        return Err(Error::Malformed);
    }
    value.replace(ConnectionId::from_buf(r, len));
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    /// Round-trip: a parameter set with every optional field populated must
    /// survive write-then-read unchanged
    #[test]
    fn coding() {
        let mut buf = Vec::new();
        let params = TransportParameters {
            initial_src_cid: Some(ConnectionId::new(&[])),
            original_dst_cid: Some(ConnectionId::new(&[])),
            initial_max_streams_bidi: 16u32.into(),
            initial_max_streams_uni: 16u32.into(),
            ack_delay_exponent: 2u32.into(),
            max_udp_payload_size: 1200u32.into(),
            preferred_address: Some(PreferredAddress {
                address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)),
                address_v6: Some(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 24, 0, 0)),
                connection_id: ConnectionId::new(&[0x42]),
                stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(),
            }),
            grease_quic_bit: true,
            min_ack_delay: Some(2_000u32.into()),
            ..TransportParameters::default()
        };
        params.write(&mut buf);
        assert_eq!(
            TransportParameters::read(Side::Client, &mut buf.as_slice()).unwrap(),
            params
        );
    }
    /// Reserved IDs must have the form 31 * N + 27 for a spread of RNG outputs,
    /// including values near the u32/u64/varint boundaries
    #[test]
    fn reserved_transport_parameter_generate_reserved_id() {
        let mut rngs = [
            StepRng(0),
            StepRng(1),
            StepRng(27),
            StepRng(31),
            StepRng(u32::MAX as u64),
            StepRng(u32::MAX as u64 - 1),
            StepRng(u32::MAX as u64 + 1),
            StepRng(u32::MAX as u64 - 27),
            StepRng(u32::MAX as u64 + 27),
            StepRng(u32::MAX as u64 - 31),
            StepRng(u32::MAX as u64 + 31),
            StepRng(u64::MAX),
            StepRng(u64::MAX - 1),
            StepRng(u64::MAX - 27),
            StepRng(u64::MAX - 31),
            StepRng(1 << 62),
            StepRng((1 << 62) - 1),
            StepRng((1 << 62) + 1),
            StepRng((1 << 62) - 27),
            StepRng((1 << 62) + 27),
            StepRng((1 << 62) - 31),
            StepRng((1 << 62) + 31),
        ];
        for rng in &mut rngs {
            let id = ReservedTransportParameter::generate_reserved_id(rng);
            assert!(id.0 % 31 == 27)
        }
    }
    /// Deterministic RNG that counts up from its seed, one increment per draw
    struct StepRng(u64);
    impl RngCore for StepRng {
        #[inline]
        fn next_u32(&mut self) -> u32 {
            self.next_u64() as u32
        }
        #[inline]
        fn next_u64(&mut self) -> u64 {
            let res = self.0;
            self.0 = self.0.wrapping_add(1);
            res
        }
        #[inline]
        fn fill_bytes(&mut self, dst: &mut [u8]) {
            // Fill 8 bytes at a time, then top up any remainder from a 32-bit draw
            let mut left = dst;
            while left.len() >= 8 {
                let (l, r) = left.split_at_mut(8);
                left = r;
                l.copy_from_slice(&self.next_u64().to_le_bytes());
            }
            let n = left.len();
            if n > 0 {
                left.copy_from_slice(&self.next_u32().to_le_bytes()[..n]);
            }
        }
    }
    /// A serialized GREASE parameter must decode to the defaults, i.e. be ignored
    #[test]
    fn reserved_transport_parameter_ignored_when_read() {
        let mut buf = Vec::new();
        let reserved_parameter = ReservedTransportParameter::random(&mut rand::rng());
        assert!(reserved_parameter.payload_len < ReservedTransportParameter::MAX_PAYLOAD_LEN);
        assert!(reserved_parameter.id.0 % 31 == 27);
        reserved_parameter.write(&mut buf);
        assert!(!buf.is_empty());
        let read_params = TransportParameters::read(Side::Server, &mut buf.as_slice()).unwrap();
        assert_eq!(read_params, TransportParameters::default());
    }
    /// Each semantically-invalid parameter set must be rejected with IllegalValue
    #[test]
    fn read_semantic_validation() {
        #[allow(clippy::type_complexity)]
        let illegal_params_builders: Vec<Box<dyn FnMut(&mut TransportParameters)>> = vec![
            Box::new(|t| {
                // This min_ack_delay is bigger than max_ack_delay!
                let min_ack_delay = t.max_ack_delay.0 * 1_000 + 1;
                t.min_ack_delay = Some(VarInt::from_u64(min_ack_delay).unwrap())
            }),
            Box::new(|t| {
                // Preferred address can only be sent by senders (and we are reading the transport
                // params as a client)
                t.preferred_address = Some(PreferredAddress {
                    address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)),
                    address_v6: None,
                    connection_id: ConnectionId::new(&[]),
                    stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(),
                })
            }),
        ];
        for mut builder in illegal_params_builders {
            let mut buf = Vec::new();
            let mut params = TransportParameters::default();
            builder(&mut params);
            params.write(&mut buf);
            assert_eq!(
                TransportParameters::read(Side::Server, &mut buf.as_slice()),
                Err(Error::IllegalValue)
            );
        }
    }
    /// 0-RTT resumption is legal only when current limits cover the cached ones
    #[test]
    fn resumption_params_validation() {
        let high_limit = TransportParameters {
            initial_max_streams_uni: 32u32.into(),
            ..TransportParameters::default()
        };
        let low_limit = TransportParameters {
            initial_max_streams_uni: 16u32.into(),
            ..TransportParameters::default()
        };
        high_limit.validate_resumption_from(&low_limit).unwrap();
        low_limit.validate_resumption_from(&high_limit).unwrap_err();
    }
}

193
vendor/quinn-proto/src/varint.rs vendored Normal file
View File

@@ -0,0 +1,193 @@
use std::{convert::TryInto, fmt};
use bytes::{Buf, BufMut};
use thiserror::Error;
use crate::coding::{self, Codec, UnexpectedEnd};
#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;
/// An integer less than 2^62
///
/// Values of this type are suitable for encoding as QUIC variable-length integer.
// It would be neat if we could express to Rust that the top two bits are available for use as enum
// discriminants
// Invariant: the wrapped value is < 2^62 — enforced by `from_u64`; callers of
// `from_u64_unchecked` must uphold it themselves.
#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct VarInt(pub(crate) u64);
impl VarInt {
    /// The largest representable value
    pub const MAX: Self = Self((1 << 62) - 1);
    /// The largest encoded value length
    pub const MAX_SIZE: usize = 8;

    /// Construct a `VarInt` infallibly
    pub const fn from_u32(x: u32) -> Self {
        Self(x as u64)
    }

    /// Succeeds iff `x` < 2^62
    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        if x <= Self::MAX.0 {
            Ok(Self(x))
        } else {
            Err(VarIntBoundsExceeded)
        }
    }

    /// Create a VarInt without ensuring it's in range
    ///
    /// # Safety
    ///
    /// `x` must be less than 2^62.
    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
        Self(x)
    }

    /// Extract the integer value
    pub const fn into_inner(self) -> u64 {
        self.0
    }

    /// Compute the number of bytes needed to encode this value
    pub(crate) const fn size(self) -> usize {
        // Thresholds are the largest values expressible with 6-, 14-, 30-, and
        // 62-bit payloads — the four QUIC varint encodings.
        let x = self.0;
        if x <= 0x3f {
            1
        } else if x <= 0x3fff {
            2
        } else if x <= 0x3fff_ffff {
            4
        } else if x <= 0x3fff_ffff_ffff_ffff {
            8
        } else {
            panic!("malformed VarInt");
        }
    }
}
// Infallible widening: every varint fits in a u64
impl From<VarInt> for u64 {
    fn from(x: VarInt) -> Self {
        x.0
    }
}
// u8/u16/u32 always fit below 2^62, so these conversions cannot fail
impl From<u8> for VarInt {
    fn from(x: u8) -> Self {
        Self(x.into())
    }
}
impl From<u16> for VarInt {
    fn from(x: u16) -> Self {
        Self(x.into())
    }
}
impl From<u32> for VarInt {
    fn from(x: u32) -> Self {
        Self(x.into())
    }
}
impl std::convert::TryFrom<u64> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x)
    }
}
impl std::convert::TryFrom<u128> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
    }
}
impl std::convert::TryFrom<usize> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
        Self::try_from(x as u64)
    }
}
// Both Debug and Display delegate to the inner integer's formatting
impl fmt::Debug for VarInt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}
impl fmt::Display for VarInt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}
// Fuzzing support: generate only in-range values so the VarInt invariant holds
#[cfg(feature = "arbitrary")]
impl<'arbitrary> Arbitrary<'arbitrary> for VarInt {
    fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result<Self> {
        Ok(Self(u.int_in_range(0..=Self::MAX.0)?))
    }
}
/// Error returned when constructing a `VarInt` from a value >= 2^62
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
#[error("value too large for varint encoding")]
pub struct VarIntBoundsExceeded;
impl Codec for VarInt {
    /// Decode a QUIC varint: the first byte's top two bits give the total
    /// encoded length as `1 << tag` bytes; the rest of the bits are the value
    /// in network byte order.
    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
        if !r.has_remaining() {
            return Err(UnexpectedEnd);
        }
        let mut bytes = [0; 8];
        bytes[0] = r.get_u8();
        let tag = bytes[0] >> 6;
        // Mask off the length tag; the remaining six bits are value bits
        bytes[0] &= 0b0011_1111;
        let len = 1usize << tag;
        if r.remaining() < len - 1 {
            return Err(UnexpectedEnd);
        }
        r.copy_to_slice(&mut bytes[1..len]);
        let x = match len {
            1 => u64::from(bytes[0]),
            2 => u64::from(u16::from_be_bytes(bytes[..2].try_into().unwrap())),
            4 => u64::from(u32::from_be_bytes(bytes[..4].try_into().unwrap())),
            _ => u64::from_be_bytes(bytes),
        };
        Ok(Self(x))
    }

    /// Encode using the shortest of the four varint forms, folding the length
    /// tag into the top two bits.
    fn encode<B: BufMut>(&self, w: &mut B) {
        let x = self.0;
        if x <= 0x3f {
            w.put_u8(x as u8);
        } else if x <= 0x3fff {
            w.put_u16(0x4000 | x as u16);
        } else if x <= 0x3fff_ffff {
            w.put_u32(0x8000_0000 | x as u32);
        } else if x <= 0x3fff_ffff_ffff_ffff {
            w.put_u64(0xc000_0000_0000_0000 | x);
        } else {
            unreachable!("malformed VarInt")
        }
    }
}

Binary file not shown.

File diff suppressed because one or more lines are too long