about summary refs log tree commit diff
path: root/src/https.rs
diff options
context:
space:
mode:
author    Alex Auvolat <alex@adnab.me>    2021-12-09 15:43:19 +0100
committer Alex Auvolat <alex@adnab.me>    2021-12-09 15:43:19 +0100
commit    9b30f2b7d17cbee39c271d159524202e0ffa297c (patch)
tree      4f523a832ab3e18e87241c1e3f2d28d5a332f180 /src/https.rs
parent    e4942490ee6f51573223772ceee8a8ac46b55ae6 (diff)
download  tricot-9b30f2b7d17cbee39c271d159524202e0ffa297c.tar.gz
          tricot-9b30f2b7d17cbee39c271d159524202e0ffa297c.zip
Compression
Diffstat (limited to 'src/https.rs')
-rw-r--r--    src/https.rs    133
1 file changed, 110 insertions, 23 deletions
diff --git a/src/https.rs b/src/https.rs
index a389e72..1b467c0 100644
--- a/src/https.rs
+++ b/src/https.rs
@@ -1,44 +1,56 @@
+use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use anyhow::Result;
use log::*;
-use futures::FutureExt;
+use accept_encoding_fork::Encoding;
+use async_compression::tokio::bufread::*;
+use futures::TryStreamExt;
use http::header::{HeaderName, HeaderValue};
use hyper::server::conn::Http;
use hyper::service::service_fn;
-use hyper::{Body, Request, Response, StatusCode};
+use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
+use tokio_util::io::{ReaderStream, StreamReader};
use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;
+pub struct HttpsConfig {
+ pub bind_addr: SocketAddr,
+ pub enable_compression: bool,
+ pub compress_mime_types: Vec<String>,
+}
+
pub async fn serve_https(
- bind_addr: SocketAddr,
+ config: HttpsConfig,
cert_store: Arc<CertStore>,
- proxy_config: watch::Receiver<Arc<ProxyConfig>>,
+ rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
) -> Result<()> {
- let mut cfg = rustls::ServerConfig::builder()
+ let config = Arc::new(config);
+
+ let mut tls_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(StoreResolver(cert_store)));
- cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
- let tls_cfg = Arc::new(cfg);
- let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg));
+ tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));
- info!("Starting to serve on https://{}.", bind_addr);
+ info!("Starting to serve on https://{}.", config.bind_addr);
- let tcp = TcpListener::bind(bind_addr).await?;
+ let tcp = TcpListener::bind(config.bind_addr).await?;
loop {
let (socket, remote_addr) = tcp.accept().await?;
- let proxy_config = proxy_config.clone();
+ let rx_proxy_config = rx_proxy_config.clone();
let tls_acceptor = tls_acceptor.clone();
+ let config = config.clone();
tokio::spawn(async move {
match tls_acceptor.accept(socket).await {
@@ -48,17 +60,10 @@ pub async fn serve_https(
.serve_connection(
stream,
service_fn(move |req: Request<Body>| {
- let proxy_config: Arc<ProxyConfig> = proxy_config.borrow().clone();
- handle(remote_addr, req, proxy_config).map(|res| match res {
- Err(e) => {
- warn!("Handler error: {}", e);
- Response::builder()
- .status(StatusCode::INTERNAL_SERVER_ERROR)
- .body(Body::from(format!("{}", e)))
- .map_err(Into::into)
- }
- x => x,
- })
+ let https_config = config.clone();
+ let proxy_config: Arc<ProxyConfig> =
+ rx_proxy_config.borrow().clone();
+ handle_outer(remote_addr, req, https_config, proxy_config)
}),
)
.await;
@@ -72,11 +77,30 @@ pub async fn serve_https(
}
}
+async fn handle_outer(
+ remote_addr: SocketAddr,
+ req: Request<Body>,
+ https_config: Arc<HttpsConfig>,
+ proxy_config: Arc<ProxyConfig>,
+) -> Result<Response<Body>, Infallible> {
+ match handle(remote_addr, req, https_config, proxy_config).await {
+ Err(e) => {
+ warn!("Handler error: {}", e);
+ Ok(Response::builder()
+ .status(StatusCode::INTERNAL_SERVER_ERROR)
+ .body(Body::from(format!("{}", e)))
+ .unwrap())
+ }
+ Ok(r) => Ok(r),
+ }
+}
+
// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn handle(
remote_addr: SocketAddr,
req: Request<Body>,
+ https_config: Arc<HttpsConfig>,
proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
let method = req.method().clone();
@@ -91,6 +115,7 @@ async fn handle(
.to_str()?
};
let path = req.uri().path();
+ let accept_encoding = accept_encoding_fork::parse(req.headers()).unwrap_or(None);
let best_match = proxy_config
.entries
@@ -137,7 +162,11 @@ async fn handle(
trace!("Response: {:?}", response);
info!("{} {} {}", method, response.status().as_u16(), uri);
- Ok(response)
+ if https_config.enable_compression {
+ try_compress(response, accept_encoding, &https_config)
+ } else {
+ Ok(response)
+ }
} else {
debug!("{}{} -> NOT FOUND", host, path);
info!("{} 404 {}", method, uri);
@@ -147,3 +176,61 @@ async fn handle(
.body(Body::from("No matching proxy entry"))?)
}
}
+
+fn try_compress(
+ response: Response<Body>,
+ accept_encoding: Option<Encoding>,
+ https_config: &HttpsConfig,
+) -> Result<Response<Body>> {
+ // Check if a compression encoding is accepted
+ let encoding = match accept_encoding {
+ None | Some(Encoding::Identity) => return Ok(response),
+ Some(enc) => enc,
+ };
+
+ // If already compressed, return as is
+ if response.headers().get(header::CONTENT_ENCODING).is_some() {
+ return Ok(response);
+ }
+
+ // If content type not in mime types for which to compress, return as is
+ match response.headers().get(header::CONTENT_TYPE) {
+ Some(ct) => {
+ let ct_str = ct.to_str()?;
+ if !https_config.compress_mime_types.iter().any(|x| x == ct_str) {
+ return Ok(response);
+ }
+ }
+ None => return Ok(response),
+ };
+
+ debug!("Compressing response body as {:?}", encoding);
+
+ let (mut head, body) = response.into_parts();
+ let body_rd =
+ StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
+ let compressed_body = match encoding {
+ Encoding::Gzip => {
+ head.headers
+ .insert(header::CONTENT_ENCODING, "gzip".parse()?);
+ Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd)))
+ }
+ Encoding::Brotli => {
+ head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
+ Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
+ }
+ Encoding::Deflate => {
+ head.headers
+ .insert(header::CONTENT_ENCODING, "deflate".parse()?);
+ Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd)))
+ }
+ Encoding::Zstd => {
+ head.headers
+ .insert(header::CONTENT_ENCODING, "zstd".parse()?);
+ Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd)))
+ }
+ _ => unreachable!(),
+ };
+
+ Ok(Response::from_parts(head, compressed_body))
+}