diff options
Diffstat (limited to 'src/https.rs')
-rw-r--r-- | src/https.rs | 133 |
1 files changed, 110 insertions, 23 deletions
diff --git a/src/https.rs b/src/https.rs index a389e72..1b467c0 100644 --- a/src/https.rs +++ b/src/https.rs @@ -1,44 +1,56 @@ +use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use anyhow::Result; use log::*; -use futures::FutureExt; +use accept_encoding_fork::Encoding; +use async_compression::tokio::bufread::*; +use futures::TryStreamExt; use http::header::{HeaderName, HeaderValue}; use hyper::server::conn::Http; use hyper::service::service_fn; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; +use tokio_util::io::{ReaderStream, StreamReader}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; +pub struct HttpsConfig { + pub bind_addr: SocketAddr, + pub enable_compression: bool, + pub compress_mime_types: Vec<String>, +} + pub async fn serve_https( - bind_addr: SocketAddr, + config: HttpsConfig, cert_store: Arc<CertStore>, - proxy_config: watch::Receiver<Arc<ProxyConfig>>, + rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, ) -> Result<()> { - let mut cfg = rustls::ServerConfig::builder() + let config = Arc::new(config); + + let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); - cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let tls_cfg = Arc::new(cfg); - let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg)); + tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); - info!("Starting to serve on https://{}.", bind_addr); + info!("Starting to serve on https://{}.", config.bind_addr); - let tcp = TcpListener::bind(bind_addr).await?; + let tcp = TcpListener::bind(config.bind_addr).await?; loop { let (socket, remote_addr) = tcp.accept().await?; - let proxy_config = proxy_config.clone(); + let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); + let config = config.clone(); tokio::spawn(async move { match tls_acceptor.accept(socket).await { @@ -48,17 +60,10 @@ pub async fn serve_https( .serve_connection( stream, service_fn(move |req: Request<Body>| { - let proxy_config: Arc<ProxyConfig> = proxy_config.borrow().clone(); - handle(remote_addr, req, proxy_config).map(|res| match res { - Err(e) => { - warn!("Handler error: {}", e); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(format!("{}", e))) - .map_err(Into::into) - } - x => x, - }) + let https_config = config.clone(); + let proxy_config: Arc<ProxyConfig> = + rx_proxy_config.borrow().clone(); + handle_outer(remote_addr, req, https_config, proxy_config) }), ) .await; @@ -72,11 +77,30 @@ pub async fn serve_https( } } +async fn handle_outer( + remote_addr: SocketAddr, + req: Request<Body>, + https_config: Arc<HttpsConfig>, + proxy_config: Arc<ProxyConfig>, +) -> Result<Response<Body>, Infallible> { + match handle(remote_addr, req, https_config, proxy_config).await { + Err(e) => { + warn!("Handler error: {}", e); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(format!("{}", e))) + .unwrap()) + } + Ok(r) => Ok(r), + } +} + // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn handle( remote_addr: SocketAddr, req: Request<Body>, + https_config: Arc<HttpsConfig>, proxy_config: Arc<ProxyConfig>, ) -> Result<Response<Body>, anyhow::Error> { let method = req.method().clone(); @@ -91,6 +115,7 @@ async fn handle( .to_str()? }; let path = req.uri().path(); + let accept_encoding = accept_encoding_fork::parse(req.headers()).unwrap_or(None); let best_match = proxy_config .entries @@ -137,7 +162,11 @@ async fn handle( trace!("Response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); - Ok(response) + if https_config.enable_compression { + try_compress(response, accept_encoding, &https_config) + } else { + Ok(response) + } } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); @@ -147,3 +176,61 @@ async fn handle( .body(Body::from("No matching proxy entry"))?) } } + +fn try_compress( + response: Response<Body>, + accept_encoding: Option<Encoding>, + https_config: &HttpsConfig, +) -> Result<Response<Body>> { + // Check if a compression encoding is accepted + let encoding = match accept_encoding { + None | Some(Encoding::Identity) => return Ok(response), + Some(enc) => enc, + }; + + // If already compressed, return as is + if response.headers().get(header::CONTENT_ENCODING).is_some() { + return Ok(response); + } + + // If content type not in mime types for which to compress, return as is + match response.headers().get(header::CONTENT_TYPE) { + Some(ct) => { + let ct_str = ct.to_str()?; + if !https_config.compress_mime_types.iter().any(|x| x == ct_str) { + return Ok(response); + } + } + None => return Ok(response), + }; + + debug!("Compressing response body as {:?}", encoding); + + let (mut head, body) = response.into_parts(); + let body_rd = + StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); + let compressed_body = match encoding { + Encoding::Gzip => { + head.headers + .insert(header::CONTENT_ENCODING, "gzip".parse()?); + Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))) + } + Encoding::Brotli => { + head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); + Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) + } + Encoding::Deflate => { + head.headers + .insert(header::CONTENT_ENCODING, "deflate".parse()?); + Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))) + } + Encoding::Zstd => { + head.headers + .insert(header::CONTENT_ENCODING, "zstd".parse()?); + Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))) + } + _ => unreachable!(), + }; + + Ok(Response::from_parts(head, compressed_body)) +} |