use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use anyhow::Result; use log::*; use accept_encoding_fork::Encoding; use async_compression::tokio::bufread::*; use futures::StreamExt; use futures::TryStreamExt; use http::header::{HeaderName, HeaderValue}; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; use tokio_util::io::{ReaderStream, StreamReader}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; pub struct HttpsConfig { pub bind_addr: SocketAddr, pub enable_compression: bool, pub compress_mime_types: Vec, } pub async fn serve_https( config: HttpsConfig, cert_store: Arc, rx_proxy_config: watch::Receiver>, ) -> Result<()> { let config = Arc::new(config); let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); info!("Starting to serve on https://{}.", config.bind_addr); let tcp = TcpListener::bind(config.bind_addr).await?; loop { let (socket, remote_addr) = tcp.accept().await?; let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); let config = config.clone(); tokio::spawn(async move { match tls_acceptor.accept(socket).await { Ok(stream) => { debug!("TLS handshake was successfull"); let http_result = Http::new() .serve_connection( stream, service_fn(move |req: Request| { let https_config = config.clone(); let proxy_config: Arc = rx_proxy_config.borrow().clone(); handle_outer(remote_addr, req, https_config, proxy_config) }), ) .await; if let Err(http_err) = http_result { warn!("HTTP error: {}", http_err); } } Err(e) => warn!("Error in TLS connection: {}", e), } }); } } async fn handle_outer( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, ) -> Result, Infallible> { match handle(remote_addr, req, https_config, proxy_config).await { Err(e) => { warn!("Handler error: {}", e); Ok(Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(format!("{}", e))) .unwrap()) } Ok(r) => Ok(r), } } // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn handle( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, ) -> Result, anyhow::Error> { let method = req.method().clone(); let uri = req.uri().to_string(); let host = if let Some(auth) = req.uri().authority() { auth.as_str() } else { req.headers() .get("host") .ok_or_else(|| anyhow!("Missing host header"))? .to_str()? }; let path = req.uri().path(); let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]); let best_match = proxy_config .entries .iter() .filter(|ent| { ent.host.matches(host) && ent .path_prefix .as_ref() .map(|prefix| path.starts_with(prefix)) .unwrap_or(true) }) .max_by_key(|ent| { ( ent.priority, ent.path_prefix .as_ref() .map(|x| x.len() as i32) .unwrap_or(0), -ent.calls.load(Ordering::SeqCst), ) }); if let Some(proxy_to) = best_match { proxy_to.calls.fetch_add(1, Ordering::SeqCst); debug!("{}{} -> {}", host, path, proxy_to); trace!("Request: {:?}", req); let mut response = if proxy_to.https_target { let to_addr = format!("https://{}", proxy_to.target_addr); reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await? } else { let to_addr = format!("http://{}", proxy_to.target_addr); reverse_proxy::call(remote_addr.ip(), &to_addr, req).await? }; for (header, value) in proxy_to.add_headers.iter() { response.headers_mut().insert( HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?, ); } trace!("Response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); if https_config.enable_compression { try_compress(response, accept_encoding, &https_config).await } else { Ok(response) } } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("No matching proxy entry"))?) } } async fn try_compress( response: Response, accept_encoding: Vec<(Option, f32)>, https_config: &HttpsConfig, ) -> Result> { let max_q: f32 = accept_encoding .iter() .max_by_key(|(_, q)| (q * 10000f32) as i64) .unwrap_or(&(None, 1.)) .1; let preference = [ Encoding::Zstd, //Encoding::Brotli, Encoding::Deflate, Encoding::Gzip, ]; #[allow(clippy::float_cmp)] let encoding_opt = accept_encoding .iter() .filter(|(_, q)| *q == max_q) .filter_map(|(enc, _)| *enc) .filter(|enc| preference.contains(enc)) .min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap()); // If preferred encoding is none, return as is let encoding = match encoding_opt { None | Some(Encoding::Identity) => return Ok(response), Some(enc) => enc, }; // If already compressed, return as is if response.headers().get(header::CONTENT_ENCODING).is_some() { return Ok(response); } // If content type not in mime types for which to compress, return as is match response.headers().get(header::CONTENT_TYPE) { Some(ct) => { let ct_str = ct.to_str()?; let mime_type = match ct_str.split_once(';') { Some((mime_type, _params)) => mime_type, None => ct_str, }; if !https_config .compress_mime_types .iter() .any(|x| x == mime_type) { return Ok(response); } } None => return Ok(response), }; let (mut head, mut body) = response.into_parts(); // ---- If body is smaller than 1400 bytes, don't compress ---- let mut chunks = vec![]; let mut sum_lengths = 0; while sum_lengths < 1400 { match body.next().await { Some(chunk) => { let chunk = chunk?; sum_lengths += chunk.len(); chunks.push(chunk); } None => { return Ok(Response::from_parts(head, Body::from(chunks.concat()))); } } } // put beginning chunks back into body let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body); // make an async reader from that for compressor let body_rd = StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); debug!( "Compressing response body as {:?} (at least {} bytes)", encoding, sum_lengths ); head.headers.remove(header::CONTENT_LENGTH); let compressed_body = match encoding { Encoding::Gzip => { head.headers .insert(header::CONTENT_ENCODING, "gzip".parse()?); Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))) } // Encoding::Brotli => { // head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); // Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) // } Encoding::Deflate => { head.headers .insert(header::CONTENT_ENCODING, "deflate".parse()?); Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))) } Encoding::Zstd => { head.headers .insert(header::CONTENT_ENCODING, "zstd".parse()?); Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))) } _ => unreachable!(), }; Ok(Response::from_parts(head, compressed_body)) }