diff options
author | Alex Auvolat <alex@adnab.me> | 2021-12-09 15:43:19 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2021-12-09 15:43:19 +0100 |
commit | 9b30f2b7d17cbee39c271d159524202e0ffa297c (patch) | |
tree | 4f523a832ab3e18e87241c1e3f2d28d5a332f180 /src | |
parent | e4942490ee6f51573223772ceee8a8ac46b55ae6 (diff) | |
download | tricot-9b30f2b7d17cbee39c271d159524202e0ffa297c.tar.gz tricot-9b30f2b7d17cbee39c271d159524202e0ffa297c.zip |
Compression
Diffstat (limited to 'src')
-rw-r--r-- | src/https.rs | 133 | ||||
-rw-r--r-- | src/main.rs | 30 | ||||
-rw-r--r-- | src/reverse_proxy.rs | 61 |
3 files changed, 159 insertions, 65 deletions
diff --git a/src/https.rs b/src/https.rs index a389e72..1b467c0 100644 --- a/src/https.rs +++ b/src/https.rs @@ -1,44 +1,56 @@ +use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use anyhow::Result; use log::*; -use futures::FutureExt; +use accept_encoding_fork::Encoding; +use async_compression::tokio::bufread::*; +use futures::TryStreamExt; use http::header::{HeaderName, HeaderValue}; use hyper::server::conn::Http; use hyper::service::service_fn; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; +use tokio_util::io::{ReaderStream, StreamReader}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; +pub struct HttpsConfig { + pub bind_addr: SocketAddr, + pub enable_compression: bool, + pub compress_mime_types: Vec<String>, +} + pub async fn serve_https( - bind_addr: SocketAddr, + config: HttpsConfig, cert_store: Arc<CertStore>, - proxy_config: watch::Receiver<Arc<ProxyConfig>>, + rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, ) -> Result<()> { - let mut cfg = rustls::ServerConfig::builder() + let config = Arc::new(config); + + let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); - cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let tls_cfg = Arc::new(cfg); - let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg)); + tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); - info!("Starting to serve on https://{}.", bind_addr); + info!("Starting to serve on https://{}.", config.bind_addr); - let tcp = TcpListener::bind(bind_addr).await?; + let tcp = TcpListener::bind(config.bind_addr).await?; loop { let (socket, remote_addr) = tcp.accept().await?; - let proxy_config = proxy_config.clone(); + let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); + let config = config.clone(); tokio::spawn(async move { match tls_acceptor.accept(socket).await { @@ -48,17 +60,10 @@ pub async fn serve_https( .serve_connection( stream, service_fn(move |req: Request<Body>| { - let proxy_config: Arc<ProxyConfig> = proxy_config.borrow().clone(); - handle(remote_addr, req, proxy_config).map(|res| match res { - Err(e) => { - warn!("Handler error: {}", e); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(format!("{}", e))) - .map_err(Into::into) - } - x => x, - }) + let https_config = config.clone(); + let proxy_config: Arc<ProxyConfig> = + rx_proxy_config.borrow().clone(); + handle_outer(remote_addr, req, https_config, proxy_config) }), ) .await; @@ -72,11 +77,30 @@ pub async fn serve_https( } } +async fn handle_outer( + remote_addr: SocketAddr, + req: Request<Body>, + https_config: Arc<HttpsConfig>, + proxy_config: Arc<ProxyConfig>, +) -> Result<Response<Body>, Infallible> { + match handle(remote_addr, req, https_config, proxy_config).await { + Err(e) => { + warn!("Handler error: {}", e); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(format!("{}", e))) + .unwrap()) + } + Ok(r) => Ok(r), + } +} + // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn handle( remote_addr: SocketAddr, req: Request<Body>, + https_config: Arc<HttpsConfig>, proxy_config: Arc<ProxyConfig>, ) -> Result<Response<Body>, anyhow::Error> { let method = req.method().clone(); @@ -91,6 +115,7 @@ async fn handle( .to_str()? }; let path = req.uri().path(); + let accept_encoding = accept_encoding_fork::parse(req.headers()).unwrap_or(None); let best_match = proxy_config .entries @@ -137,7 +162,11 @@ async fn handle( trace!("Response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); - Ok(response) + if https_config.enable_compression { + try_compress(response, accept_encoding, &https_config) + } else { + Ok(response) + } } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); @@ -147,3 +176,61 @@ async fn handle( .body(Body::from("No matching proxy entry"))?) } } + +fn try_compress( + response: Response<Body>, + accept_encoding: Option<Encoding>, + https_config: &HttpsConfig, +) -> Result<Response<Body>> { + // Check if a compression encoding is accepted + let encoding = match accept_encoding { + None | Some(Encoding::Identity) => return Ok(response), + Some(enc) => enc, + }; + + // If already compressed, return as is + if response.headers().get(header::CONTENT_ENCODING).is_some() { + return Ok(response); + } + + // If content type not in mime types for which to compress, return as is + match response.headers().get(header::CONTENT_TYPE) { + Some(ct) => { + let ct_str = ct.to_str()?; + if !https_config.compress_mime_types.iter().any(|x| x == ct_str) { + return Ok(response); + } + } + None => return Ok(response), + }; + + debug!("Compressing response body as {:?}", encoding); + + let (mut head, body) = response.into_parts(); + let body_rd = + StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); + let compressed_body = match encoding { + Encoding::Gzip => { + head.headers + .insert(header::CONTENT_ENCODING, "gzip".parse()?); + Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))) + } + Encoding::Brotli => { + head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); + Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) + } + Encoding::Deflate => { + head.headers + .insert(header::CONTENT_ENCODING, "deflate".parse()?); + Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))) + } + Encoding::Zstd => { + head.headers + .insert(header::CONTENT_ENCODING, "zstd".parse()?); + Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))) + } + _ => unreachable!(), + }; + + Ok(Response::from_parts(head, compressed_body)) +} diff --git a/src/main.rs b/src/main.rs index 61fc747..febe540 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,6 +58,18 @@ struct Opt { /// E-mail address for Let's Encrypt certificate requests #[structopt(long = "letsencrypt-email", env = "TRICOT_LETSENCRYPT_EMAIL")] pub letsencrypt_email: String, + + /// Enable compression of responses + #[structopt(long = "enable-compression", env = "TRICOT_ENABLE_COMPRESSION")] + pub enable_compression: bool, + + /// Mime types for which to enable compression (comma-separated list) + #[structopt( + long = "compress-mime-types", + env = "TRICOT_COMPRESS_MIME_TYPES", + default_value = "text/html,text/plain,text/css,text/javascript,application/javascript,image/svg+xml" + )] + pub compress_mime_types: String, } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] @@ -87,13 +99,19 @@ async fn main() { ); tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err)); + + let https_config = https::HttpsConfig { + bind_addr: opt.https_bind_addr, + enable_compression: opt.enable_compression, + compress_mime_types: opt + .compress_mime_types + .split(",") + .map(|x| x.to_string()) + .collect(), + }; tokio::spawn( - https::serve_https( - opt.https_bind_addr, - cert_store.clone(), - rx_proxy_config.clone(), - ) - .map_err(exit_on_err), + https::serve_https(https_config, cert_store.clone(), rx_proxy_config.clone()) + .map_err(exit_on_err), ); while rx_proxy_config.changed().await.is_ok() { diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs index 72644b7..445f6ef 100644 --- a/src/reverse_proxy.rs +++ b/src/reverse_proxy.rs @@ -12,33 +12,25 @@ use log::*; use http::header::HeaderName; use hyper::header::{HeaderMap, HeaderValue}; -use hyper::{Body, Client, Request, Response, Uri}; -use lazy_static::lazy_static; +use hyper::{header, Body, Client, Request, Response, Uri}; use rustls::client::{ServerCertVerified, ServerCertVerifier}; use rustls::{Certificate, ServerName}; use crate::tls_util::HttpsConnectorFixedDnsname; -fn is_hop_header(name: &str) -> bool { - use unicase::Ascii; - - // A list of the headers, using `unicase` to help us compare without - // worrying about the case, and `lazy_static!` to prevent reallocation - // of the vector. - lazy_static! { - static ref HOP_HEADERS: Vec<Ascii<&'static str>> = vec![ - Ascii::new("Connection"), - Ascii::new("Keep-Alive"), - Ascii::new("Proxy-Authenticate"), - Ascii::new("Proxy-Authorization"), - Ascii::new("Te"), - Ascii::new("Trailers"), - Ascii::new("Transfer-Encoding"), - Ascii::new("Upgrade"), - ]; - } - - HOP_HEADERS.iter().any(|h| h == &name) +const HOP_HEADERS: &[HeaderName] = &[ + header::CONNECTION, + //header::KEEP_ALIVE, + header::PROXY_AUTHENTICATE, + header::PROXY_AUTHORIZATION, + header::TE, + header::TRAILER, + header::TRANSFER_ENCODING, + header::UPGRADE, +]; + +fn is_hop_header(name: &HeaderName) -> bool { + HOP_HEADERS.iter().any(|h| h == name) } /// Returns a clone of the headers without the [hop-by-hop headers]. @@ -47,7 +39,7 @@ fn is_hop_header(name: &str) -> bool { fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> { let mut result = HeaderMap::new(); for (k, v) in headers.iter() { - if !is_hop_header(k.as_str()) { + if !is_hop_header(&k) { result.append(k.clone(), v.clone()); } } @@ -80,10 +72,7 @@ fn create_proxied_request<B>( *builder.headers_mut().unwrap() = remove_hop_headers(request.headers()); // If request does not have host header, add it from original URI authority - let host_header_name = "host"; - if let hyper::header::Entry::Vacant(entry) = - builder.headers_mut().unwrap().entry(host_header_name) - { + if let header::Entry::Vacant(entry) = builder.headers_mut().unwrap().entry(header::HOST) { if let Some(authority) = request.uri().authority() { entry.insert(authority.as_str().parse()?); } @@ -96,11 +85,11 @@ fn create_proxied_request<B>( .unwrap() .entry(x_forwarded_for_header_name) { - hyper::header::Entry::Vacant(entry) => { + header::Entry::Vacant(entry) => { entry.insert(client_ip.to_string().parse()?); } - hyper::header::Entry::Occupied(mut entry) => { + header::Entry::Occupied(mut entry) => { let addr = format!("{}, {}", entry.get().to_str()?, client_ip); entry.insert(addr.parse()?); } @@ -112,17 +101,17 @@ fn create_proxied_request<B>( ); // Proxy upgrade requests properly - if let Some(conn) = request.headers().get("connection") { + if let Some(conn) = request.headers().get(header::CONNECTION) { if conn.to_str()?.to_lowercase() == "upgrade" { - if let Some(upgrade) = request.headers().get("upgrade") { - builder.headers_mut().unwrap().insert( - HeaderName::from_bytes(b"connection")?, - "Upgrade".try_into()?, - ); + if let Some(upgrade) = request.headers().get(header::UPGRADE) { + builder + .headers_mut() + .unwrap() + .insert(header::CONNECTION, "Upgrade".try_into()?); builder .headers_mut() .unwrap() - .insert(HeaderName::from_bytes(b"upgrade")?, upgrade.clone()); + .insert(header::UPGRADE, upgrade.clone()); } } } |