chore: checkpoint before Python removal

This commit is contained in:
2026-03-26 22:33:59 +00:00
parent 683cec9307
commit e568ddf82a
29972 changed files with 11269302 additions and 2 deletions

View File

@@ -0,0 +1,387 @@
#![allow(unused_imports)]
use crate::compression::CompressionLevel;
use crate::{
compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody},
BoxError,
};
#[cfg(feature = "compression-br")]
use async_compression::tokio::bufread::BrotliEncoder;
#[cfg(feature = "compression-gzip")]
use async_compression::tokio::bufread::GzipEncoder;
#[cfg(feature = "compression-deflate")]
use async_compression::tokio::bufread::ZlibEncoder;
#[cfg(feature = "compression-zstd")]
use async_compression::tokio::bufread::ZstdEncoder;
use bytes::{Buf, Bytes};
use http::HeaderMap;
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
io,
marker::PhantomData,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio_util::io::StreamReader;
use super::pin_project_cfg::pin_project_cfg;
pin_project! {
    /// Response body of [`Compression`].
    ///
    /// [`Compression`]: super::Compression
    pub struct CompressionBody<B>
    where
        B: Body,
    {
        // The underlying body, possibly wrapped in a streaming encoder.
        // Pinned because the encoder variants of `BodyInner` are not `Unpin`.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
impl<B> Default for CompressionBody<B>
where
    B: Body + Default,
{
    /// Produces a pass-through (identity) body wrapping `B::default()`.
    fn default() -> Self {
        Self::new(BodyInner::identity(B::default()))
    }
}
impl<B> CompressionBody<B>
where
    B: Body,
{
    /// Wraps an already-classified body variant.
    pub(crate) fn new(inner: BodyInner<B>) -> Self {
        Self { inner }
    }

    /// Get a reference to the inner body
    //
    // NOTE: the repeated accessor calls in the encoder arms below peel off
    // one wrapper layer each (wrap body -> encoder -> reader -> body adapter)
    // until the original body `B` is reached.
    pub fn get_ref(&self) -> &B {
        match &self.inner {
            #[cfg(feature = "compression-gzip")]
            BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
            #[cfg(feature = "compression-deflate")]
            BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
            #[cfg(feature = "compression-br")]
            BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
            #[cfg(feature = "compression-zstd")]
            BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
            BodyInner::Identity { inner } => inner,
        }
    }

    /// Get a mutable reference to the inner body
    pub fn get_mut(&mut self) -> &mut B {
        match &mut self.inner {
            #[cfg(feature = "compression-gzip")]
            BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
            #[cfg(feature = "compression-deflate")]
            BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
            #[cfg(feature = "compression-br")]
            BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
            #[cfg(feature = "compression-zstd")]
            BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
            BodyInner::Identity { inner } => inner,
        }
    }

    /// Get a pinned mutable reference to the inner body
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
        match self.project().inner.project() {
            #[cfg(feature = "compression-gzip")]
            BodyInnerProj::Gzip { inner } => inner
                .project()
                .read
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut(),
            #[cfg(feature = "compression-deflate")]
            BodyInnerProj::Deflate { inner } => inner
                .project()
                .read
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut(),
            #[cfg(feature = "compression-br")]
            BodyInnerProj::Brotli { inner } => inner
                .project()
                .read
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut(),
            #[cfg(feature = "compression-zstd")]
            BodyInnerProj::Zstd { inner } => inner
                .project()
                .read
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut()
                .get_pin_mut(),
            BodyInnerProj::Identity { inner } => inner,
        }
    }

    /// Consume `self`, returning the inner body
    pub fn into_inner(self) -> B {
        match self.inner {
            #[cfg(feature = "compression-gzip")]
            BodyInner::Gzip { inner } => inner
                .read
                .into_inner()
                .into_inner()
                .into_inner()
                .into_inner(),
            #[cfg(feature = "compression-deflate")]
            BodyInner::Deflate { inner } => inner
                .read
                .into_inner()
                .into_inner()
                .into_inner()
                .into_inner(),
            #[cfg(feature = "compression-br")]
            BodyInner::Brotli { inner } => inner
                .read
                .into_inner()
                .into_inner()
                .into_inner()
                .into_inner(),
            #[cfg(feature = "compression-zstd")]
            BodyInner::Zstd { inner } => inner
                .read
                .into_inner()
                .into_inner()
                .into_inner()
                .into_inner(),
            BodyInner::Identity { inner } => inner,
        }
    }
}
// Convenience aliases: a body whose bytes are run through the given
// streaming encoder. One alias per supported encoding, feature-gated.
#[cfg(feature = "compression-gzip")]
type GzipBody<B> = WrapBody<GzipEncoder<B>>;

#[cfg(feature = "compression-deflate")]
type DeflateBody<B> = WrapBody<ZlibEncoder<B>>;

#[cfg(feature = "compression-br")]
type BrotliBody<B> = WrapBody<BrotliEncoder<B>>;

#[cfg(feature = "compression-zstd")]
type ZstdBody<B> = WrapBody<ZstdEncoder<B>>;
pin_project_cfg! {
    #[project = BodyInnerProj]
    // One variant per supported encoding plus `Identity` for pass-through.
    // Encoder variants only exist when the matching cargo feature is on,
    // which is why the cfg-aware `pin_project_cfg!` macro is used here
    // instead of plain `pin_project!`.
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        #[cfg(feature = "compression-gzip")]
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        #[cfg(feature = "compression-deflate")]
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        #[cfg(feature = "compression-br")]
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        #[cfg(feature = "compression-zstd")]
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        Identity {
            #[pin]
            inner: B,
        },
    }
}
impl<B: Body> BodyInner<B> {
    /// Body that will be gzip-compressed while streaming.
    #[cfg(feature = "compression-gzip")]
    pub(crate) fn gzip(inner: WrapBody<GzipEncoder<B>>) -> Self {
        Self::Gzip { inner }
    }

    /// Body that will be deflate-compressed while streaming.
    #[cfg(feature = "compression-deflate")]
    pub(crate) fn deflate(inner: WrapBody<ZlibEncoder<B>>) -> Self {
        Self::Deflate { inner }
    }

    /// Body that will be brotli-compressed while streaming.
    #[cfg(feature = "compression-br")]
    pub(crate) fn brotli(inner: WrapBody<BrotliEncoder<B>>) -> Self {
        Self::Brotli { inner }
    }

    /// Body that will be zstd-compressed while streaming.
    #[cfg(feature = "compression-zstd")]
    pub(crate) fn zstd(inner: WrapBody<ZstdEncoder<B>>) -> Self {
        Self::Zstd { inner }
    }

    /// Body passed through without compression.
    pub(crate) fn identity(inner: B) -> Self {
        Self::Identity { inner }
    }
}
impl<B> Body for CompressionBody<B>
where
    B: Body,
    B::Error: Into<BoxError>,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Polls the next frame, delegating to whichever encoder wraps the body.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        match self.project().inner.project() {
            #[cfg(feature = "compression-gzip")]
            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
            #[cfg(feature = "compression-deflate")]
            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
            #[cfg(feature = "compression-br")]
            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
            #[cfg(feature = "compression-zstd")]
            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
                Some(Ok(frame)) => {
                    // Copy whatever buffer type `B` yields into contiguous
                    // `Bytes` so `Self::Data` is uniform across all variants.
                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
                    Poll::Ready(Some(Ok(frame)))
                }
                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
                None => Poll::Ready(None),
            },
        }
    }

    fn size_hint(&self) -> http_body::SizeHint {
        // Only the pass-through variant can forward an exact size hint; the
        // compressed output length is not known up front, so encoder variants
        // report an unbounded default hint.
        if let BodyInner::Identity { inner } = &self.inner {
            inner.size_hint()
        } else {
            http_body::SizeHint::new()
        }
    }

    fn is_end_stream(&self) -> bool {
        // Conservatively report `false` for encoder variants; only the
        // pass-through variant delegates to the inner body.
        if let BodyInner::Identity { inner } = &self.inner {
            inner.is_end_stream()
        } else {
            false
        }
    }
}
#[cfg(feature = "compression-gzip")]
impl<B> DecorateAsyncRead for GzipEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = GzipEncoder<Self::Input>;

    /// Wraps the body-backed reader in a gzip encoder at the given quality.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        GzipEncoder::with_quality(input, quality.into_async_compression())
    }

    /// Projects through the encoder to the underlying reader.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
#[cfg(feature = "compression-deflate")]
impl<B> DecorateAsyncRead for ZlibEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZlibEncoder<Self::Input>;

    /// Wraps the body-backed reader in a zlib (deflate) encoder at the given
    /// quality.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        ZlibEncoder::with_quality(input, quality.into_async_compression())
    }

    /// Projects through the encoder to the underlying reader.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
#[cfg(feature = "compression-br")]
impl<B> DecorateAsyncRead for BrotliEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = BrotliEncoder<Self::Input>;

    /// Wraps the body-backed reader in a brotli encoder, capping the default
    /// quality for acceptable on-the-fly performance.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        // The brotli crate used under the hood here has a default compression level of 11,
        // which is the max for brotli. This causes extremely slow compression times, so we
        // manually set a default of 4 here.
        //
        // This is the same default used by NGINX for on-the-fly brotli compression.
        let level = match quality {
            CompressionLevel::Default => async_compression::Level::Precise(4),
            other => other.into_async_compression(),
        };
        BrotliEncoder::with_quality(input, level)
    }

    /// Projects through the encoder to the underlying reader.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
#[cfg(feature = "compression-zstd")]
impl<B> DecorateAsyncRead for ZstdEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZstdEncoder<Self::Input>;

    /// Wraps the body-backed reader in a zstd encoder, limiting the window
    /// size at high levels so browsers can decode the output.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        // See https://issues.chromium.org/issues/41493659:
        // "For memory usage reasons, Chromium limits the window size to 8MB"
        // See https://datatracker.ietf.org/doc/html/rfc8878#name-window-descriptor
        // "For improved interoperability, it's recommended for decoders to support values
        // of Window_Size up to 8 MB and for encoders not to generate frames requiring a
        // Window_Size larger than 8 MB."
        // Level 17 in zstd (as of v1.5.6) is the first level with a window size of 8 MB (2^23):
        // https://github.com/facebook/zstd/blob/v1.5.6/lib/compress/clevels.h#L25-L51
        // Set the parameter for all levels >= 17. This will either have no effect (but reduce
        // the risk of future changes in zstd) or limit the window log to 8MB.
        let needs_window_limit = match quality {
            CompressionLevel::Best => true, // level 20
            CompressionLevel::Precise(level) => level >= 17,
            _ => false,
        };
        // The parameter is not set for levels below 17 as it will increase the window size
        // for those levels.
        if needs_window_limit {
            let params = [async_compression::zstd::CParameter::window_log(23)];
            ZstdEncoder::with_quality_and_params(input, quality.into_async_compression(), &params)
        } else {
            ZstdEncoder::with_quality(input, quality.into_async_compression())
        }
    }

    /// Projects through the encoder to the underlying reader.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}

View File

@@ -0,0 +1,133 @@
#![allow(unused_imports)]
use super::{body::BodyInner, CompressionBody};
use crate::compression::predicate::Predicate;
use crate::compression::CompressionLevel;
use crate::compression_utils::WrapBody;
use crate::content_encoding::Encoding;
use http::{header, HeaderMap, HeaderValue, Response};
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
pin_project! {
    /// Response future of [`Compression`].
    ///
    /// [`Compression`]: super::Compression
    #[derive(Debug)]
    pub struct ResponseFuture<F, P> {
        // The inner service's response future.
        #[pin]
        pub(crate) inner: F,
        // Encoding chosen from the request's `Accept-Encoding` header.
        pub(crate) encoding: Encoding,
        // Decides per-response whether compression should be applied.
        pub(crate) predicate: P,
        // Compression level forwarded to the encoder.
        pub(crate) quality: CompressionLevel,
    }
}
impl<F, B, E, P> Future for ResponseFuture<F, P>
where
    F: Future<Output = Result<Response<B>, E>>,
    B: Body,
    P: Predicate,
{
    type Output = Result<Response<CompressionBody<B>>, E>;

    /// Waits for the inner response, then decides whether — and with which
    /// encoding — to wrap its body in a streaming compressor, adjusting the
    /// relevant headers accordingly.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let res = ready!(self.as_mut().project().inner.poll(cx)?);

        // never recompress responses that are already compressed
        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
            // never compress responses that are ranges
            && !res.headers().contains_key(header::CONTENT_RANGE)
            && self.predicate.should_compress(&res);

        let (mut parts, body) = res.into_parts();

        // Advertise that the response varies by `Accept-Encoding`, unless an
        // existing `Vary` header already mentions it (case-insensitively).
        if should_compress
            && !parts.headers.get_all(header::VARY).iter().any(|value| {
                contains_ignore_ascii_case(
                    value.as_bytes(),
                    header::ACCEPT_ENCODING.as_str().as_bytes(),
                )
            })
        {
            parts
                .headers
                .append(header::VARY, header::ACCEPT_ENCODING.into());
        }

        let body = match (should_compress, self.encoding) {
            // if compression is _not_ supported or the client doesn't accept it
            (false, _) | (_, Encoding::Identity) => {
                return Poll::Ready(Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                )))
            }

            #[cfg(feature = "compression-gzip")]
            (_, Encoding::Gzip) => {
                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-deflate")]
            (_, Encoding::Deflate) => {
                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-br")]
            (_, Encoding::Brotli) => {
                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-zstd")]
            (_, Encoding::Zstd) => {
                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "fs")]
            #[allow(unreachable_patterns)]
            (true, _) => {
                // This should never happen because the `AcceptEncoding` struct which is used to determine
                // `self.encoding` will only enable the different compression algorithms if the
                // corresponding crate feature has been enabled. This means
                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
                // features enabled.
                //
                // The match arm is still required though because the `fs` feature uses the
                // Encoding struct independently and requires no compression logic to be enabled.
                // This means a combination of an individual compression feature and `fs` will fail
                // to compile without this branch even though it will never be reached.
                //
                // To safeguard against refactors that changes this relationship or other bugs the
                // server will return an uncompressed response instead of panicking since that could
                // become a ddos attack vector.
                return Poll::Ready(Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                )));
            }
        };

        // The compressed length is unknown and byte ranges would no longer
        // line up with the compressed stream, so drop these headers.
        parts.headers.remove(header::ACCEPT_RANGES);
        parts.headers.remove(header::CONTENT_LENGTH);

        parts
            .headers
            .insert(header::CONTENT_ENCODING, self.encoding.into_header_value());

        let res = Response::from_parts(parts, body);
        Poll::Ready(Ok(res))
    }
}
/// Case-insensitive (ASCII) substring search over raw header bytes.
///
/// Returns `true` when `needle` occurs anywhere within `haystack`. An empty
/// needle always matches, mirroring the usual substring-search convention.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `slice::windows` panics on a zero-length window, so handle the empty
    // needle up front (it trivially matches).
    if needle.is_empty() {
        return true;
    }
    // `windows` yields nothing when the needle is longer than the haystack,
    // so that case falls out as `false` naturally.
    haystack
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}

View File

@@ -0,0 +1,240 @@
use super::{Compression, Predicate};
use crate::compression::predicate::DefaultPredicate;
use crate::compression::CompressionLevel;
use crate::compression_utils::AcceptEncoding;
use tower_layer::Layer;
/// Compress response bodies of the underlying service.
///
/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
/// `Content-Encoding` header to responses.
///
/// See the [module docs](crate::compression) for more details.
#[derive(Clone, Debug, Default)]
pub struct CompressionLayer<P = DefaultPredicate> {
    // Which encodings this layer is willing to negotiate.
    accept: AcceptEncoding,
    // Per-response check deciding whether compression should be applied.
    predicate: P,
    // Compression level forwarded to the encoders.
    quality: CompressionLevel,
}
impl<S, P> Layer<S> for CompressionLayer<P>
where
    P: Predicate,
{
    type Service = Compression<S, P>;

    /// Wraps `inner` in a [`Compression`] service carrying this layer's
    /// configuration. `accept` and `quality` are copied; the predicate is
    /// cloned so the layer remains reusable.
    fn layer(&self, inner: S) -> Self::Service {
        Compression {
            inner,
            accept: self.accept,
            predicate: self.predicate.clone(),
            quality: self.quality,
        }
    }
}
impl CompressionLayer {
    /// Creates a new [`CompressionLayer`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether to enable the gzip encoding.
    #[cfg(feature = "compression-gzip")]
    pub fn gzip(mut self, enable: bool) -> Self {
        self.accept.set_gzip(enable);
        self
    }

    /// Sets whether to enable the Deflate encoding.
    #[cfg(feature = "compression-deflate")]
    pub fn deflate(mut self, enable: bool) -> Self {
        self.accept.set_deflate(enable);
        self
    }

    /// Sets whether to enable the Brotli encoding.
    #[cfg(feature = "compression-br")]
    pub fn br(mut self, enable: bool) -> Self {
        self.accept.set_br(enable);
        self
    }

    /// Sets whether to enable the Zstd encoding.
    #[cfg(feature = "compression-zstd")]
    pub fn zstd(mut self, enable: bool) -> Self {
        self.accept.set_zstd(enable);
        self
    }

    /// Sets the compression quality.
    pub fn quality(mut self, quality: CompressionLevel) -> Self {
        self.quality = quality;
        self
    }

    // NOTE: unlike the enable/disable setters above, the `no_*` methods are
    // not feature-gated, so callers can always opt out of an encoding.

    /// Disables the gzip encoding.
    ///
    /// This method is available even if the `gzip` crate feature is disabled.
    pub fn no_gzip(mut self) -> Self {
        self.accept.set_gzip(false);
        self
    }

    /// Disables the Deflate encoding.
    ///
    /// This method is available even if the `deflate` crate feature is disabled.
    pub fn no_deflate(mut self) -> Self {
        self.accept.set_deflate(false);
        self
    }

    /// Disables the Brotli encoding.
    ///
    /// This method is available even if the `br` crate feature is disabled.
    pub fn no_br(mut self) -> Self {
        self.accept.set_br(false);
        self
    }

    /// Disables the Zstd encoding.
    ///
    /// This method is available even if the `zstd` crate feature is disabled.
    pub fn no_zstd(mut self) -> Self {
        self.accept.set_zstd(false);
        self
    }

    /// Replace the current compression predicate.
    ///
    /// See [`Compression::compress_when`] for more details.
    pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
    where
        C: Predicate,
    {
        CompressionLayer {
            accept: self.accept,
            predicate,
            quality: self.quality,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_helpers::Body;
    use http::{header::ACCEPT_ENCODING, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;
    use tower::{Service, ServiceBuilder, ServiceExt};

    // Test handler: serves this crate's `Cargo.toml` as a streamed body.
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        // Open the file.
        let file = File::open("Cargo.toml").await.expect("file missing");
        // Convert the file into a `Stream`.
        let stream = ReaderStream::new(file);
        // Convert the `Stream` into a `Body`.
        let body = Body::from_stream(stream);
        // Create response.
        Ok(Response::new(body))
    }

    // Disabling encodings via `no_*` must steer negotiation to the remaining
    // one, and each algorithm must actually be used (checked via output size).
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> {
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_gzip();

        let mut service = ServiceBuilder::new()
            // Compress responses based on the `Accept-Encoding` header.
            .layer(deflate_only_layer)
            .service_fn(handle);

        // Call the service with the deflate only layer
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "deflate");

        // Read the body
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();

        let deflate_bytes_len = bytes.len();

        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_gzip()
            .no_deflate();

        let mut service = ServiceBuilder::new()
            // Compress responses based on the `Accept-Encoding` header.
            .layer(br_only_layer)
            .service_fn(handle);

        // Call the service with the br only layer
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "br");

        // Read the body
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();

        let br_byte_length = bytes.len();

        // check the corresponding algorithms are actually used
        // br should compress better than deflate
        assert!(br_byte_length < deflate_bytes_len * 9 / 10);

        Ok(())
    }

    /// Test ensuring that zstd compression will not exceed an 8MiB window size; browsers do not
    /// accept responses using 16MiB+ window sizes.
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), crate::BoxError> {
        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18_874_368])))
        }
        // zstd will (I believe) lower its window size if a larger one isn't beneficial and
        // it knows the size of the input; use an 18MiB body to ensure it would want a
        // >=16MiB window (though it might not be able to see the input size here).
        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_deflate()
            .no_gzip();

        let mut service = ServiceBuilder::new().layer(zstd_layer).service_fn(zeroes);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();

        let mut dec = zstd::Decoder::new(&*bytes)?;
        dec.window_log_max(23)?; // Limit window size accepted by decoder to 2 ^ 23 bytes (8MiB)
        std::io::copy(&mut dec, &mut std::io::sink())?;

        Ok(())
    }
}

511
vendor/tower-http/src/compression/mod.rs vendored Normal file
View File

@@ -0,0 +1,511 @@
//! Middleware that compresses response bodies.
//!
//! # Example
//!
//! Example showing how to respond with the compressed contents of a file.
//!
//! ```rust
//! use bytes::{Bytes, BytesMut};
//! use http::{Request, Response, header::ACCEPT_ENCODING};
//! use http_body_util::{Full, BodyExt, StreamBody, combinators::UnsyncBoxBody};
//! use http_body::Frame;
//! use std::convert::Infallible;
//! use tokio::fs::{self, File};
//! use tokio_util::io::ReaderStream;
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//! use tower_http::{compression::CompressionLayer, BoxError};
//! use futures_util::TryStreamExt;
//!
//! type BoxBody = UnsyncBoxBody<Bytes, std::io::Error>;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<BoxBody>, Infallible> {
//! // Open the file.
//! let file = File::open("Cargo.toml").await.expect("file missing");
//! // Convert the file into a `Stream` of `Bytes`.
//! let stream = ReaderStream::new(file);
//! // Convert the stream into a stream of data `Frame`s.
//! let stream = stream.map_ok(Frame::data);
//! // Convert the `Stream` into a `Body`.
//! let body = StreamBody::new(stream);
//! // Erase the type because it's very hard to name in the function signature.
//! let body = body.boxed_unsync();
//! // Create response.
//! Ok(Response::new(body))
//! }
//!
//! let mut service = ServiceBuilder::new()
//! // Compress responses based on the `Accept-Encoding` header.
//! .layer(CompressionLayer::new())
//! .service_fn(handle);
//!
//! // Call the service.
//! let request = Request::builder()
//! .header(ACCEPT_ENCODING, "gzip")
//! .body(Full::<Bytes>::default())?;
//!
//! let response = service
//! .ready()
//! .await?
//! .call(request)
//! .await?;
//!
//! assert_eq!(response.headers()["content-encoding"], "gzip");
//!
//! // Read the body
//! let bytes = response
//! .into_body()
//! .collect()
//! .await?
//! .to_bytes();
//!
//! // The compressed body should be smaller 🤞
//! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len();
//! assert!(bytes.len() < uncompressed_len);
//! #
//! # Ok(())
//! # }
//! ```
//!
pub mod predicate;
mod body;
mod future;
mod layer;
mod pin_project_cfg;
mod service;
#[doc(inline)]
pub use self::{
body::CompressionBody,
future::ResponseFuture,
layer::CompressionLayer,
predicate::{DefaultPredicate, Predicate},
service::Compression,
};
pub use crate::compression_utils::CompressionLevel;
#[cfg(test)]
mod tests {
use crate::compression::predicate::SizeAbove;
use super::*;
use crate::test_helpers::{Body, WithTrailers};
use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
use flate2::read::GzDecoder;
use http::header::{
ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
};
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use http_body::Body as _;
use http_body_util::BodyExt;
use std::convert::Infallible;
use std::io::Read;
use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::io::StreamReader;
use tower::{service_fn, Service, ServiceExt};
// Predicate that unconditionally allows every response to be compressed.
#[derive(Clone)]
struct Always;

impl Predicate for Always {
    fn should_compress<B>(&self, _: &http::Response<B>) -> bool
    where
        B: http_body::Body,
    {
        true
    }
}
// Round-trip: compress via the middleware, decompress with flate2, and check
// both the payload and that trailers survive compression.
#[tokio::test]
async fn gzip_works() {
    let svc = service_fn(handle);
    let mut svc = Compression::new(svc).compress_when(Always);

    // call the service
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // read the compressed body
    let collected = res.into_body().collect().await.unwrap();
    let trailers = collected.trailers().cloned().unwrap();
    let compressed_data = collected.to_bytes();

    // decompress the body
    // doing this with flate2 as that is much easier than async-compression and blocking during
    // tests is fine
    let mut decoder = GzDecoder::new(&compressed_data[..]);
    let mut decompressed = String::new();
    decoder.read_to_string(&mut decompressed).unwrap();

    assert_eq!(decompressed, "Hello, World!");

    // trailers are maintained
    assert_eq!(trailers["foo"], "bar");
}
// `x-gzip` in `Accept-Encoding` must be treated as `gzip`, and the response
// must be labelled `gzip` (not `x-gzip`).
#[tokio::test]
async fn x_gzip_works() {
    let svc = service_fn(handle);
    let mut svc = Compression::new(svc).compress_when(Always);

    // call the service
    let req = Request::builder()
        .header("accept-encoding", "x-gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // we treat x-gzip as equivalent to gzip and don't have to return x-gzip
    // taking extra caution by checking all headers with this name
    assert_eq!(
        res.headers()
            .get_all("content-encoding")
            .iter()
            .collect::<Vec<&HeaderValue>>(),
        vec!(HeaderValue::from_static("gzip"))
    );

    // read the compressed body
    let collected = res.into_body().collect().await.unwrap();
    let trailers = collected.trailers().cloned().unwrap();
    let compressed_data = collected.to_bytes();

    // decompress the body
    // doing this with flate2 as that is much easier than async-compression and blocking during
    // tests is fine
    let mut decoder = GzDecoder::new(&compressed_data[..]);
    let mut decompressed = String::new();
    decoder.read_to_string(&mut decompressed).unwrap();

    assert_eq!(decompressed, "Hello, World!");

    // trailers are maintained
    assert_eq!(trailers["foo"], "bar");
}
// Round-trip: compress via the middleware, decompress with the zstd crate.
#[tokio::test]
async fn zstd_works() {
    let svc = service_fn(handle);
    let mut svc = Compression::new(svc).compress_when(Always);

    // call the service
    let req = Request::builder()
        .header("accept-encoding", "zstd")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // read the compressed body
    let body = res.into_body();
    let compressed_data = body.collect().await.unwrap().to_bytes();

    // decompress the body
    let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
    let decompressed = String::from_utf8(decompressed).unwrap();

    assert_eq!(decompressed, "Hello, World!");
}
// A response that already carries `content-encoding` must not be re-encoded.
#[tokio::test]
async fn no_recompress() {
    const DATA: &str = "Hello, World! I'm already compressed with br!";

    let svc = service_fn(|_| async {
        // Pre-compress the payload with brotli ourselves.
        let buf = {
            let mut buf = Vec::new();
            let mut enc = BrotliEncoder::new(&mut buf);
            enc.write_all(DATA.as_bytes()).await?;
            enc.flush().await?;
            buf
        };

        let resp = Response::builder()
            .header("content-encoding", "br")
            .body(Body::from(buf))
            .unwrap();
        Ok::<_, std::io::Error>(resp)
    });
    let mut svc = Compression::new(svc);

    // call the service
    //
    // note: the accept-encoding doesn't match the content-encoding above, so that
    // we're able to see if the compression layer triggered or not
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // check we didn't recompress
    assert_eq!(
        res.headers()
            .get("content-encoding")
            .and_then(|h| h.to_str().ok())
            .unwrap_or_default(),
        "br",
    );

    // read the compressed body
    let body = res.into_body();
    let data = body.collect().await.unwrap().to_bytes();

    // decompress the body
    let data = {
        let mut output_buf = Vec::new();
        let mut decoder = BrotliDecoder::new(&mut output_buf);
        decoder
            .write_all(&data)
            .await
            .expect("couldn't brotli-decode");
        decoder.flush().await.expect("couldn't flush");
        output_buf
    };

    assert_eq!(data, DATA.as_bytes());
}
// Shared test handler: responds with "Hello, World!" plus a `foo: bar`
// trailer so tests can verify trailers survive compression.
async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
    let mut trailers = HeaderMap::new();
    trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
    let body = Body::from("Hello, World!").with_trailers(trailers);
    Ok(Response::builder().body(body).unwrap())
}
// A predicate returning `false` must leave the response uncompressed, while
// `true` on the next call must compress it.
#[tokio::test]
async fn will_not_compress_if_filtered_out() {
    use predicate::Predicate;

    const DATA: &str = "Hello world uncompressed";

    let svc_fn = service_fn(|_| async {
        let resp = Response::builder()
            // .header("content-encoding", "br")
            .body(Body::from(DATA.as_bytes()))
            .unwrap();
        Ok::<_, std::io::Error>(resp)
    });

    // Compression filter allows every other request to be compressed
    #[derive(Default, Clone)]
    struct EveryOtherResponse(Arc<RwLock<u64>>);

    #[allow(clippy::dbg_macro)]
    impl Predicate for EveryOtherResponse {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            // Odd-numbered calls compress, even-numbered ones do not.
            let mut guard = self.0.write().unwrap();
            let should_compress = *guard % 2 != 0;
            *guard += 1;
            dbg!(should_compress)
        }
    }

    let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

    let req = Request::builder()
        .header("accept-encoding", "br")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // read the uncompressed body
    let body = res.into_body();
    let data = body.collect().await.unwrap().to_bytes();
    let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
    assert_eq!(DATA, &still_uncompressed);

    // Compression filter will compress the next body
    let req = Request::builder()
        .header("accept-encoding", "br")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // read the compressed body
    let body = res.into_body();
    let data = body.collect().await.unwrap().to_bytes();
    // brotli output is not valid UTF-8, so a failed conversion implies the
    // body really was compressed this time
    assert!(String::from_utf8(data.to_vec()).is_err());
}
// The default predicate must skip already-compressed image content types.
#[tokio::test]
async fn doesnt_compress_images() {
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        // Body large enough to clear the default size threshold, so only the
        // content type decides the outcome.
        let mut res = Response::new(Body::from(
            "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
        ));
        res.headers_mut()
            .insert(CONTENT_TYPE, "image/png".parse().unwrap());
        Ok(res)
    }

    let svc = Compression::new(service_fn(handle));
    let res = svc
        .oneshot(
            Request::builder()
                .header(ACCEPT_ENCODING, "gzip")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert!(res.headers().get(CONTENT_ENCODING).is_none());
}
// SVG is a text-based image format, so unlike raster images it must be
// compressed by the default predicate.
#[tokio::test]
async fn does_compress_svg() {
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let mut res = Response::new(Body::from(
            "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
        ));
        res.headers_mut()
            .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
        Ok(res)
    }

    let svc = Compression::new(service_fn(handle));
    let res = svc
        .oneshot(
            Request::builder()
                .header(ACCEPT_ENCODING, "gzip")
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
}
// Compressing through the middleware at `Best` quality must yield exactly the
// bytes produced by driving the brotli encoder directly at that quality.
#[tokio::test]
async fn compress_with_quality() {
    const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
    let level = CompressionLevel::Best;

    let svc = service_fn(|_| async {
        let resp = Response::builder()
            .body(Body::from(DATA.as_bytes()))
            .unwrap();
        Ok::<_, std::io::Error>(resp)
    });
    let mut svc = Compression::new(svc).quality(level);

    // call the service
    let req = Request::builder()
        .header("accept-encoding", "br")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();

    // read the compressed body
    let body = res.into_body();
    let compressed_data = body.collect().await.unwrap().to_bytes();

    // build the compressed body with the same quality level
    let compressed_with_level = {
        use async_compression::tokio::bufread::BrotliEncoder;

        let stream = Box::pin(futures_util::stream::once(async move {
            Ok::<_, std::io::Error>(DATA.as_bytes())
        }));
        let reader = StreamReader::new(stream);
        let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

        let mut buf = Vec::new();
        enc.read_to_end(&mut buf).await.unwrap();
        buf
    };

    assert_eq!(
        compressed_data,
        compressed_with_level.as_slice(),
        "Compression level is not respected"
    );
}
// Responses carrying `Content-Range` must pass through untouched: compressing
// a byte range would corrupt it and `Accept-Ranges` must be preserved.
#[tokio::test]
async fn should_not_compress_ranges() {
    let svc = service_fn(|_| async {
        let mut res = Response::new(Body::from("Hello"));
        let headers = res.headers_mut();
        headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap());
        headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
        Ok::<_, std::io::Error>(res)
    });
    let mut svc = Compression::new(svc).compress_when(Always);

    // call the service
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .header(RANGE, "bytes=0-4")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();
    let headers = res.headers().clone();

    // read the uncompressed body
    let collected = res.into_body().collect().await.unwrap().to_bytes();

    assert_eq!(headers[ACCEPT_RANGES], "bytes");
    assert!(!headers.contains_key(CONTENT_ENCODING));
    assert_eq!(collected, "Hello");
}
#[tokio::test]
async fn should_strip_accept_ranges_header_when_compressing() {
    // Once a body is compressed, byte-range requests against the original
    // representation no longer make sense, so `accept-ranges` must be removed.
    let svc = service_fn(|_| async {
        let mut res = Response::new(Body::from("Hello, World!"));
        res.headers_mut()
            .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
        Ok::<_, std::io::Error>(res)
    });
    let mut svc = Compression::new(svc).compress_when(Always);

    // call the service
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.ready().await.unwrap().call(req).await.unwrap();
    let headers = res.headers().clone();

    // read the compressed body
    let collected = res.into_body().collect().await.unwrap();
    let compressed_data = collected.to_bytes();

    // decompress the body
    // doing this with flate2 as that is much easier than async-compression and blocking during
    // tests is fine
    let mut decoder = GzDecoder::new(&compressed_data[..]);
    let mut decompressed = String::new();
    decoder.read_to_string(&mut decompressed).unwrap();

    // `accept-ranges` is gone, the body was gzip-encoded, and it round-trips
    assert!(!headers.contains_key(ACCEPT_RANGES));
    assert_eq!(headers[CONTENT_ENCODING], "gzip");
    assert_eq!(decompressed, "Hello, World!");
}
#[tokio::test]
async fn size_hint_identity() {
let msg = "Hello, world!";
let svc = service_fn(|_| async { Ok::<_, std::io::Error>(Response::new(Body::from(msg))) });
let mut svc = Compression::new(svc);
let req = Request::new(Body::empty());
let res = svc.ready().await.unwrap().call(req).await.unwrap();
let body = res.into_body();
assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64);
}
}

View File

@@ -0,0 +1,144 @@
// Full credit to @tesaguri who posted this gist under CC0 1.0 Universal licence
// https://gist.github.com/tesaguri/2a1c0790a48bbda3dd7f71c26d02a793
//
// Extends `pin_project_lite::pin_project!` with support for `#[cfg(...)]`
// attributes on enum variants and on variant fields, which `pin_project_lite`
// cannot parse on its own. For every `cfg` predicate encountered, the
// expansion forks into a "predicate holds" and a "predicate does not hold"
// branch, and one `pin_project!` invocation is emitted per combination, each
// gated behind a `#[cfg(all(...))]` built from the accumulated (possibly
// negated) predicates — so exactly one combination is active in any build.
macro_rules! pin_project_cfg {
    // Entry point: capture outer attributes, visibility and the `enum` keyword.
    ($(#[$($attr:tt)*])* $vis:vis enum $($rest:tt)+) => {
        pin_project_cfg! {
            @outer [$(#[$($attr)*])* $vis enum] $($rest)+
        }
    };
    // Accumulate type parameters and `where` clause.
    (@outer [$($accum:tt)*] $tt:tt $($rest:tt)+) => {
        pin_project_cfg! {
            @outer [$($accum)* $tt] $($rest)+
        }
    };
    // Reached the enum body: start processing variants with an empty (always
    // true) predicate accumulator and an empty variant accumulator.
    (@outer [$($accum:tt)*] { $($body:tt)* }) => {
        pin_project_cfg! {
            @body #[cfg(all())] [$($accum)*] {} $($body)*
        }
    };
    // Process a variant with `cfg`.
    (
        @body
        #[cfg(all($($pred_accum:tt)*))]
        $outer:tt
        { $($accum:tt)* }
        #[cfg($($pred:tt)*)]
        $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* },
        $($rest:tt)*
    ) => {
        // Create two versions of the enum with `cfg($pred)` and `cfg(not($pred))`.
        pin_project_cfg! {
            @variant_body
            { $($body)* }
            {}
            #[cfg(all($($pred_accum)* $($pred)*,))]
            $outer
            { $($accum)* $(#[$($attr)*])* $variant }
            $($rest)*
        }
        pin_project_cfg! {
            @body
            #[cfg(all($($pred_accum)* not($($pred)*),))]
            $outer
            { $($accum)* }
            $($rest)*
        }
    };
    // Process a variant without `cfg`: keep the current predicate set and
    // descend into the variant's fields.
    (
        @body
        #[cfg(all($($pred_accum:tt)*))]
        $outer:tt
        { $($accum:tt)* }
        $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* },
        $($rest:tt)*
    ) => {
        pin_project_cfg! {
            @variant_body
            { $($body)* }
            {}
            #[cfg(all($($pred_accum)*))]
            $outer
            { $($accum)* $(#[$($attr)*])* $variant }
            $($rest)*
        }
    };
    // Process a variant field with `cfg`: fork into field-present and
    // field-absent expansions, extending the predicate accumulator.
    (
        @variant_body
        {
            #[cfg($($pred:tt)*)]
            $(#[$($attr:tt)*])* $field:ident: $ty:ty,
            $($rest:tt)*
        }
        { $($accum:tt)* }
        #[cfg(all($($pred_accum:tt)*))]
        $($outer:tt)*
    ) => {
        pin_project_cfg! {
            @variant_body
            {$($rest)*}
            { $($accum)* $(#[$($attr)*])* $field: $ty, }
            #[cfg(all($($pred_accum)* $($pred)*,))]
            $($outer)*
        }
        pin_project_cfg! {
            @variant_body
            { $($rest)* }
            { $($accum)* }
            #[cfg(all($($pred_accum)* not($($pred)*),))]
            $($outer)*
        }
    };
    // Process a variant field without `cfg`.
    (
        @variant_body
        {
            $(#[$($attr:tt)*])* $field:ident: $ty:ty,
            $($rest:tt)*
        }
        { $($accum:tt)* }
        $($outer:tt)*
    ) => {
        pin_project_cfg! {
            @variant_body
            {$($rest)*}
            { $($accum)* $(#[$($attr)*])* $field: $ty, }
            $($outer)*
        }
    };
    // All fields of the current variant processed: attach the finished variant
    // body to the accumulated variants and resume processing the rest.
    (
        @variant_body
        {}
        $body:tt
        #[cfg(all($($pred_accum:tt)*))]
        $outer:tt
        { $($accum:tt)* }
        $($rest:tt)*
    ) => {
        pin_project_cfg! {
            @body
            #[cfg(all($($pred_accum)*))]
            $outer
            { $($accum)* $body, }
            $($rest)*
        }
    };
    // All variants processed: emit one `cfg`-gated `pin_project!` invocation.
    (
        @body
        #[$cfg:meta]
        [$($outer:tt)*]
        $body:tt
    ) => {
        #[$cfg]
        pin_project_lite::pin_project! {
            $($outer)* $body
        }
    };
}
pub(crate) use pin_project_cfg;

View File

@@ -0,0 +1,272 @@
//! Predicates for disabling compression of responses.
//!
//! Predicates are applied with [`Compression::compress_when`] or
//! [`CompressionLayer::compress_when`].
//!
//! [`Compression::compress_when`]: super::Compression::compress_when
//! [`CompressionLayer::compress_when`]: super::CompressionLayer::compress_when
use http::{header, Extensions, HeaderMap, StatusCode, Version};
use http_body::Body;
use std::{fmt, sync::Arc};
/// Predicate used to determine if a response should be compressed or not.
//
// NOTE(review): `Clone` is required because the middleware clones the
// predicate into each response future (see `Compression::call`).
pub trait Predicate: Clone {
    /// Should this response be compressed or not?
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body;

    /// Combine two predicates into one.
    ///
    /// The resulting predicate enables compression if both inner predicates do.
    fn and<Other>(self, other: Other) -> And<Self, Other>
    where
        Self: Sized,
        Other: Predicate,
    {
        And {
            lhs: self,
            rhs: other,
        }
    }
}
// Any compatible closure can act as a predicate: it is handed the pieces of
// the response it may want to inspect.
impl<F> Predicate for F
where
    F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        self(
            response.status(),
            response.version(),
            response.headers(),
            response.extensions(),
        )
    }
}
// An absent predicate places no restriction: `None` always allows compression.
impl<T> Predicate for Option<T>
where
    T: Predicate,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        match self {
            Some(inner) => inner.should_compress(response),
            None => true,
        }
    }
}
/// Two predicates combined into one.
///
/// Created with [`Predicate::and`]
#[derive(Debug, Clone, Default, Copy)]
pub struct And<Lhs, Rhs> {
    // Left-hand predicate; evaluated first.
    lhs: Lhs,
    // Right-hand predicate; only consulted if `lhs` allows compression.
    rhs: Rhs,
}
// Compression is allowed only when both combined predicates agree.
impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
where
    Lhs: Predicate,
    Rhs: Predicate,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        // Short-circuit: if the left predicate vetoes, skip the right one.
        if !self.lhs.should_compress(response) {
            return false;
        }
        self.rhs.should_compress(response)
    }
}
/// The default predicate used by [`Compression`] and [`CompressionLayer`].
///
/// This will compress responses unless:
///
/// - They're gRPC, which has its own protocol specific compression scheme.
/// - It's an image as determined by the `content-type` starting with `image/`.
/// - They're Server-Sent Events (SSE) as determined by the `content-type` being `text/event-stream`.
/// - The response is less than 32 bytes.
///
/// # Configuring the defaults
///
/// `DefaultPredicate` doesn't support any configuration. Instead you can build your own predicate
/// by combining types in this module:
///
/// ```rust
/// use tower_http::compression::predicate::{SizeAbove, NotForContentType, Predicate};
///
/// // slightly larger min size than the default 32
/// let predicate = SizeAbove::new(256)
///     // still don't compress gRPC
///     .and(NotForContentType::GRPC)
///     // still don't compress images
///     .and(NotForContentType::IMAGES)
///     // also don't compress JSON
///     .and(NotForContentType::const_new("application/json"));
/// ```
///
/// [`Compression`]: super::Compression
/// [`CompressionLayer`]: super::CompressionLayer
#[derive(Clone)]
pub struct DefaultPredicate(
    // The four default rules chained together with `Predicate::and`.
    And<And<And<SizeAbove, NotForContentType>, NotForContentType>, NotForContentType>,
);
impl DefaultPredicate {
    /// Create a new `DefaultPredicate`.
    ///
    /// Chains the size threshold with the gRPC, image and SSE content-type
    /// exclusions.
    pub fn new() -> Self {
        Self(
            SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
                .and(NotForContentType::GRPC)
                .and(NotForContentType::IMAGES)
                .and(NotForContentType::SSE),
        )
    }
}
impl Default for DefaultPredicate {
fn default() -> Self {
Self::new()
}
}
impl Predicate for DefaultPredicate {
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        // Delegate to the composed default rules.
        let Self(inner) = self;
        inner.should_compress(response)
    }
}
/// [`Predicate`] that will only allow compression of responses above a certain size.
// The wrapped value is the minimum size, in bytes, required for compression.
#[derive(Clone, Copy, Debug)]
pub struct SizeAbove(u16);
impl SizeAbove {
    // Default threshold: responses smaller than this many bytes are not
    // considered worth compressing.
    pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;

    /// Create a new `SizeAbove` predicate that will only compress responses larger than
    /// `min_size_bytes`.
    ///
    /// The response will be compressed if the exact size cannot be determined through either the
    /// `content-length` header or [`Body::size_hint`].
    pub const fn new(min_size_bytes: u16) -> Self {
        Self(min_size_bytes)
    }
}
impl Default for SizeAbove {
fn default() -> Self {
Self(Self::DEFAULT_MIN_SIZE)
}
}
impl Predicate for SizeAbove {
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        // Prefer the body's own exact size hint; fall back to parsing the
        // `content-length` header.
        let known_size = response.body().size_hint().exact().or_else(|| {
            let header = response.headers().get(header::CONTENT_LENGTH)?;
            header.to_str().ok()?.parse().ok()
        });
        match known_size {
            Some(size) => size >= u64::from(self.0),
            // Unknown size: compress optimistically.
            None => true,
        }
    }
}
/// Predicate that won't allow responses with a specific `content-type` to be compressed.
#[derive(Clone, Debug)]
pub struct NotForContentType {
    // Content-type prefix that is blocked from compression.
    content_type: Str,
    // Optional exact content-type that is exempt from the block
    // (e.g. `image/svg+xml` within the blocked `image/` prefix).
    exception: Option<Str>,
}
impl NotForContentType {
/// Predicate that wont compress gRPC responses.
pub const GRPC: Self = Self::const_new("application/grpc");
/// Predicate that wont compress images.
pub const IMAGES: Self = Self {
content_type: Str::Static("image/"),
exception: Some(Str::Static("image/svg+xml")),
};
/// Predicate that wont compress Server-Sent Events (SSE) responses.
pub const SSE: Self = Self::const_new("text/event-stream");
/// Create a new `NotForContentType`.
pub fn new(content_type: &str) -> Self {
Self {
content_type: Str::Shared(content_type.into()),
exception: None,
}
}
/// Create a new `NotForContentType` from a static string.
pub const fn const_new(content_type: &'static str) -> Self {
Self {
content_type: Str::Static(content_type),
exception: None,
}
}
}
impl Predicate for NotForContentType {
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        let ct = content_type(response);
        // An exact match on the exception overrides the blocked prefix.
        if let Some(except) = &self.exception {
            if ct == except.as_str() {
                return true;
            }
        }
        !ct.starts_with(self.content_type.as_str())
    }
}
// A string that is either borrowed for the whole program (`'static`) or
// shared behind an `Arc`, keeping predicate clones cheap either way.
#[derive(Clone)]
enum Str {
    Static(&'static str),
    Shared(Arc<str>),
}
impl Str {
fn as_str(&self) -> &str {
match self {
Str::Static(s) => s,
Str::Shared(s) => s,
}
}
}
impl fmt::Debug for Str {
    // Both variants render as the quoted string contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Static(s) => fmt::Debug::fmt(s, f),
            Self::Shared(s) => fmt::Debug::fmt(s, f),
        }
    }
}
// Extract the response's `content-type` value; missing or non-ASCII headers
// are treated as the empty string.
fn content_type<B>(response: &http::Response<B>) -> &str {
    match response.headers().get(header::CONTENT_TYPE) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    }
}

View File

@@ -0,0 +1,185 @@
use super::{CompressionBody, CompressionLayer, ResponseFuture};
use crate::compression::predicate::{DefaultPredicate, Predicate};
use crate::compression::CompressionLevel;
use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding};
use http::{Request, Response};
use http_body::Body;
use std::task::{Context, Poll};
use tower_service::Service;
/// Compress response bodies of the underlying service.
///
/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
/// `Content-Encoding` header to responses.
///
/// See the [module docs](crate::compression) for more details.
#[derive(Clone, Copy)]
pub struct Compression<S, P = DefaultPredicate> {
    // The wrapped service whose responses are compressed.
    pub(crate) inner: S,
    // Which encodings this middleware may negotiate from `Accept-Encoding`.
    pub(crate) accept: AcceptEncoding,
    // Decides per-response whether compression should be applied.
    pub(crate) predicate: P,
    // Compression level forwarded to the encoder.
    pub(crate) quality: CompressionLevel,
}
impl<S> Compression<S, DefaultPredicate> {
/// Creates a new `Compression` wrapping the `service`.
pub fn new(service: S) -> Compression<S, DefaultPredicate> {
Self {
inner: service,
accept: AcceptEncoding::default(),
predicate: DefaultPredicate::default(),
quality: CompressionLevel::default(),
}
}
}
impl<S, P> Compression<S, P> {
    // Generates `get_ref`/`get_mut`/`into_inner`-style accessors for `inner`.
    define_inner_service_accessors!();

    /// Returns a new [`Layer`] that wraps services with a `Compression` middleware.
    ///
    /// [`Layer`]: tower_layer::Layer
    pub fn layer() -> CompressionLayer {
        CompressionLayer::new()
    }

    /// Sets whether to enable the gzip encoding.
    #[cfg(feature = "compression-gzip")]
    pub fn gzip(mut self, enable: bool) -> Self {
        self.accept.set_gzip(enable);
        self
    }

    /// Sets whether to enable the Deflate encoding.
    #[cfg(feature = "compression-deflate")]
    pub fn deflate(mut self, enable: bool) -> Self {
        self.accept.set_deflate(enable);
        self
    }

    /// Sets whether to enable the Brotli encoding.
    #[cfg(feature = "compression-br")]
    pub fn br(mut self, enable: bool) -> Self {
        self.accept.set_br(enable);
        self
    }

    /// Sets whether to enable the Zstd encoding.
    #[cfg(feature = "compression-zstd")]
    pub fn zstd(mut self, enable: bool) -> Self {
        self.accept.set_zstd(enable);
        self
    }

    /// Sets the compression quality.
    pub fn quality(mut self, quality: CompressionLevel) -> Self {
        self.quality = quality;
        self
    }

    /// Disables the gzip encoding.
    ///
    /// This method is available even if the `gzip` crate feature is disabled.
    pub fn no_gzip(mut self) -> Self {
        self.accept.set_gzip(false);
        self
    }

    /// Disables the Deflate encoding.
    ///
    /// This method is available even if the `deflate` crate feature is disabled.
    pub fn no_deflate(mut self) -> Self {
        self.accept.set_deflate(false);
        self
    }

    /// Disables the Brotli encoding.
    ///
    /// This method is available even if the `br` crate feature is disabled.
    pub fn no_br(mut self) -> Self {
        self.accept.set_br(false);
        self
    }

    /// Disables the Zstd encoding.
    ///
    /// This method is available even if the `zstd` crate feature is disabled.
    pub fn no_zstd(mut self) -> Self {
        self.accept.set_zstd(false);
        self
    }

    /// Replace the current compression predicate.
    ///
    /// Predicates are used to determine whether a response should be compressed or not.
    ///
    /// The default predicate is [`DefaultPredicate`]. See its documentation for more
    /// details on which responses it won't compress.
    ///
    /// # Changing the compression predicate
    ///
    /// ```
    /// use tower_http::compression::{
    ///     Compression,
    ///     predicate::{Predicate, NotForContentType, DefaultPredicate},
    /// };
    /// use tower::util::service_fn;
    ///
    /// // Placeholder service_fn
    /// let service = service_fn(|_: ()| async {
    ///     Ok::<_, std::io::Error>(http::Response::new(()))
    /// });
    ///
    /// // build our custom compression predicate
    /// // it's recommended to still include `DefaultPredicate` as part of
    /// // custom predicates
    /// let predicate = DefaultPredicate::new()
    ///     // don't compress responses whose `content-type` starts with `application/json`
    ///     .and(NotForContentType::new("application/json"));
    ///
    /// let service = Compression::new(service).compress_when(predicate);
    /// ```
    ///
    /// See [`predicate`](super::predicate) for more utilities for building compression predicates.
    ///
    /// Responses that are already compressed (i.e. have a `content-encoding` header) will
    /// _never_ be recompressed, regardless of what the predicate says.
    pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
    where
        C: Predicate,
    {
        Compression {
            inner: self.inner,
            accept: self.accept,
            predicate,
            quality: self.quality,
        }
    }
}
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Body,
    P: Predicate,
{
    type Response = Response<CompressionBody<ResBody>>;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future, P>;

    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Negotiate the encoding from the request's headers *before* handing
        // the request off to the inner service (which consumes it).
        let encoding = Encoding::from_headers(req.headers(), self.accept);
        let inner = self.inner.call(req);
        ResponseFuture {
            inner,
            encoding,
            predicate: self.predicate.clone(),
            quality: self.quality,
        }
    }
}