Diffstat (limited to 'src')
-rw-r--r--  src/https.rs          133
-rw-r--r--  src/main.rs            30
-rw-r--r--  src/reverse_proxy.rs   61
3 files changed, 159 insertions(+), 65 deletions(-)
diff --git a/src/https.rs b/src/https.rs
index a389e72..1b467c0 100644
--- a/src/https.rs
+++ b/src/https.rs
@@ -1,44 +1,56 @@
+use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::{atomic::Ordering, Arc};
 
 use anyhow::Result;
 use log::*;
 
-use futures::FutureExt;
+use accept_encoding_fork::Encoding;
+use async_compression::tokio::bufread::*;
+use futures::TryStreamExt;
 
 use http::header::{HeaderName, HeaderValue};
 use hyper::server::conn::Http;
 use hyper::service::service_fn;
-use hyper::{Body, Request, Response, StatusCode};
+use hyper::{header, Body, Request, Response, StatusCode};
 use tokio::net::TcpListener;
 use tokio::sync::watch;
 use tokio_rustls::TlsAcceptor;
+use tokio_util::io::{ReaderStream, StreamReader};
 
 use crate::cert_store::{CertStore, StoreResolver};
 use crate::proxy_config::ProxyConfig;
 use crate::reverse_proxy;
 
+pub struct HttpsConfig {
+    pub bind_addr: SocketAddr,
+    pub enable_compression: bool,
+    pub compress_mime_types: Vec<String>,
+}
+
 pub async fn serve_https(
-    bind_addr: SocketAddr,
+    config: HttpsConfig,
     cert_store: Arc<CertStore>,
-    proxy_config: watch::Receiver<Arc<ProxyConfig>>,
+    rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
 ) -> Result<()> {
-    let mut cfg = rustls::ServerConfig::builder()
+    let config = Arc::new(config);
+
+    let mut tls_cfg = rustls::ServerConfig::builder()
         .with_safe_defaults()
         .with_no_client_auth()
         .with_cert_resolver(Arc::new(StoreResolver(cert_store)));
-    cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
-    let tls_cfg = Arc::new(cfg);
-    let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg));
+    tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+    let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));
 
-    info!("Starting to serve on https://{}.", bind_addr);
+    info!("Starting to serve on https://{}.", config.bind_addr);
 
-    let tcp = TcpListener::bind(bind_addr).await?;
+    let tcp = TcpListener::bind(config.bind_addr).await?;
     loop {
         let (socket, remote_addr) = tcp.accept().await?;
 
-        let proxy_config = proxy_config.clone();
+        let rx_proxy_config = rx_proxy_config.clone();
         let tls_acceptor = tls_acceptor.clone();
+        let config = config.clone();
 
         tokio::spawn(async move {
             match tls_acceptor.accept(socket).await {
@@ -48,17 +60,10 @@ pub async fn serve_https(
                         .serve_connection(
                             stream,
                             service_fn(move |req: Request<Body>| {
-                                let proxy_config: Arc<ProxyConfig> = proxy_config.borrow().clone();
-                                handle(remote_addr, req, proxy_config).map(|res| match res {
-                                    Err(e) => {
-                                        warn!("Handler error: {}", e);
-                                        Response::builder()
-                                            .status(StatusCode::INTERNAL_SERVER_ERROR)
-                                            .body(Body::from(format!("{}", e)))
-                                            .map_err(Into::into)
-                                    }
-                                    x => x,
-                                })
+                                let https_config = config.clone();
+                                let proxy_config: Arc<ProxyConfig> =
+                                    rx_proxy_config.borrow().clone();
+                                handle_outer(remote_addr, req, https_config, proxy_config)
                             }),
                         )
                         .await;
@@ -72,11 +77,30 @@ pub async fn serve_https(
     }
 }
 
+async fn handle_outer(
+    remote_addr: SocketAddr,
+    req: Request<Body>,
+    https_config: Arc<HttpsConfig>,
+    proxy_config: Arc<ProxyConfig>,
+) -> Result<Response<Body>, Infallible> {
+    match handle(remote_addr, req, https_config, proxy_config).await {
+        Err(e) => {
+            warn!("Handler error: {}", e);
+            Ok(Response::builder()
+                .status(StatusCode::INTERNAL_SERVER_ERROR)
+                .body(Body::from(format!("{}", e)))
+                .unwrap())
+        }
+        Ok(r) => Ok(r),
+    }
+}
+
 // Custom echo service, handling two different routes and a
 // catch-all 404 responder.
 async fn handle(
     remote_addr: SocketAddr,
     req: Request<Body>,
+    https_config: Arc<HttpsConfig>,
     proxy_config: Arc<ProxyConfig>,
 ) -> Result<Response<Body>, anyhow::Error> {
     let method = req.method().clone();
@@ -91,6 +115,7 @@ async fn handle(
             .to_str()?
     };
     let path = req.uri().path();
+    let accept_encoding = accept_encoding_fork::parse(req.headers()).unwrap_or(None);
 
     let best_match = proxy_config
         .entries
@@ -137,7 +162,11 @@ async fn handle(
 
         trace!("Response: {:?}", response);
         info!("{} {} {}", method, response.status().as_u16(), uri);
-        Ok(response)
+        if https_config.enable_compression {
+            try_compress(response, accept_encoding, &https_config)
+        } else {
+            Ok(response)
+        }
     } else {
         debug!("{}{} -> NOT FOUND", host, path);
         info!("{} 404 {}", method, uri);
@@ -147,3 +176,61 @@ async fn handle(
             .body(Body::from("No matching proxy entry"))?)
     }
 }
+
+fn try_compress(
+    response: Response<Body>,
+    accept_encoding: Option<Encoding>,
+    https_config: &HttpsConfig,
+) -> Result<Response<Body>> {
+    // Check if a compression encoding is accepted
+    let encoding = match accept_encoding {
+        None | Some(Encoding::Identity) => return Ok(response),
+        Some(enc) => enc,
+    };
+
+    // If already compressed, return as is
+    if response.headers().get(header::CONTENT_ENCODING).is_some() {
+        return Ok(response);
+    }
+
+    // If content type not in mime types for which to compress, return as is
+    match response.headers().get(header::CONTENT_TYPE) {
+        Some(ct) => {
+            let ct_str = ct.to_str()?;
+            if !https_config.compress_mime_types.iter().any(|x| x == ct_str) {
+                return Ok(response);
+            }
+        }
+        None => return Ok(response),
+    };
+
+    debug!("Compressing response body as {:?}", encoding);
+
+    let (mut head, body) = response.into_parts();
+    let body_rd =
+        StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
+    let compressed_body = match encoding {
+        Encoding::Gzip => {
+            head.headers
+                .insert(header::CONTENT_ENCODING, "gzip".parse()?);
+            Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd)))
+        }
+        Encoding::Brotli => {
+            head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
+            Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
+        }
+        Encoding::Deflate => {
+            head.headers
+                .insert(header::CONTENT_ENCODING, "deflate".parse()?);
+            Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd)))
+        }
+        Encoding::Zstd => {
+            head.headers
+                .insert(header::CONTENT_ENCODING, "zstd".parse()?);
+            Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd)))
+        }
+        _ => unreachable!(),
+    };
+
+    Ok(Response::from_parts(head, compressed_body))
+}
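
A note on the stream plumbing in try_compress: hyper's Body is a stream of byte chunks, while async-compression's encoders operate on an AsyncBufRead, so the commit bridges the two with tokio-util in both directions. Below is a minimal standalone sketch of that round trip, assuming hyper 0.14, futures, tokio-util, and async-compression with its "tokio" and "gzip" features enabled; the gzip_body helper is illustrative and not part of the commit.

```rust
use async_compression::tokio::bufread::GzipEncoder;
use futures::TryStreamExt;
use hyper::Body;
use tokio_util::io::{ReaderStream, StreamReader};

/// Wrap a hyper Body so that its bytes are gzip-compressed on the fly.
fn gzip_body(body: Body) -> Body {
    // Body is a Stream of Bytes; StreamReader turns it into an AsyncBufRead,
    // after mapping hyper's error type into std::io::Error.
    let body_rd =
        StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
    // GzipEncoder compresses whatever it reads; ReaderStream turns the
    // compressed reader back into a byte stream that Body can wrap.
    Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd)))
}
```

Because the encoder is driven only as the body is polled, the response is compressed incrementally and never buffered in full.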
diff --git a/src/main.rs b/src/main.rs
index 61fc747..febe540 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -58,6 +58,18 @@ struct Opt {
     /// E-mail address for Let's Encrypt certificate requests
     #[structopt(long = "letsencrypt-email", env = "TRICOT_LETSENCRYPT_EMAIL")]
     pub letsencrypt_email: String,
+
+    /// Enable compression of responses
+    #[structopt(long = "enable-compression", env = "TRICOT_ENABLE_COMPRESSION")]
+    pub enable_compression: bool,
+
+    /// Mime types for which to enable compression (comma-separated list)
+    #[structopt(
+        long = "compress-mime-types",
+        env = "TRICOT_COMPRESS_MIME_TYPES",
+        default_value = "text/html,text/plain,text/css,text/javascript,application/javascript,image/svg+xml"
+    )]
+    pub compress_mime_types: String,
 }
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
@@ -87,13 +99,19 @@ async fn main() {
     );
 
     tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err));
+
+    let https_config = https::HttpsConfig {
+        bind_addr: opt.https_bind_addr,
+        enable_compression: opt.enable_compression,
+        compress_mime_types: opt
+            .compress_mime_types
+            .split(",")
+            .map(|x| x.to_string())
+            .collect(),
+    };
     tokio::spawn(
-        https::serve_https(
-            opt.https_bind_addr,
-            cert_store.clone(),
-            rx_proxy_config.clone(),
-        )
-        .map_err(exit_on_err),
+        https::serve_https(https_config, cert_store.clone(), rx_proxy_config.clone())
+            .map_err(exit_on_err),
     );
 
     while rx_proxy_config.changed().await.is_ok() {
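
For reference, the new compress_mime_types option arrives as a single comma-separated string (from --compress-mime-types or TRICOT_COMPRESS_MIME_TYPES) and is split into a Vec<String> once at startup, as above. A tiny sketch of the resulting value, using a shortened example list rather than the real default:

```rust
fn main() {
    // Same split as in the HttpsConfig construction above; values are made up.
    let raw = "text/html,text/css,image/svg+xml";
    let compress_mime_types: Vec<String> = raw.split(",").map(|x| x.to_string()).collect();
    assert_eq!(
        compress_mime_types,
        vec!["text/html", "text/css", "image/svg+xml"]
    );
}
```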
diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs
index 72644b7..445f6ef 100644
--- a/src/reverse_proxy.rs
+++ b/src/reverse_proxy.rs
@@ -12,33 +12,25 @@ use log::*;
 
 use http::header::HeaderName;
 use hyper::header::{HeaderMap, HeaderValue};
-use hyper::{Body, Client, Request, Response, Uri};
-use lazy_static::lazy_static;
+use hyper::{header, Body, Client, Request, Response, Uri};
 use rustls::client::{ServerCertVerified, ServerCertVerifier};
 use rustls::{Certificate, ServerName};
 
 use crate::tls_util::HttpsConnectorFixedDnsname;
 
-fn is_hop_header(name: &str) -> bool {
-    use unicase::Ascii;
-
-    // A list of the headers, using `unicase` to help us compare without
-    // worrying about the case, and `lazy_static!` to prevent reallocation
-    // of the vector.
-    lazy_static! {
-        static ref HOP_HEADERS: Vec<Ascii<&'static str>> = vec![
-            Ascii::new("Connection"),
-            Ascii::new("Keep-Alive"),
-            Ascii::new("Proxy-Authenticate"),
-            Ascii::new("Proxy-Authorization"),
-            Ascii::new("Te"),
-            Ascii::new("Trailers"),
-            Ascii::new("Transfer-Encoding"),
-            Ascii::new("Upgrade"),
-        ];
-    }
-
-    HOP_HEADERS.iter().any(|h| h == &name)
+const HOP_HEADERS: &[HeaderName] = &[
+    header::CONNECTION,
+    //header::KEEP_ALIVE,
+    header::PROXY_AUTHENTICATE,
+    header::PROXY_AUTHORIZATION,
+    header::TE,
+    header::TRAILER,
+    header::TRANSFER_ENCODING,
+    header::UPGRADE,
+];
+
+fn is_hop_header(name: &HeaderName) -> bool {
+    HOP_HEADERS.iter().any(|h| h == name)
 }
 
 /// Returns a clone of the headers without the [hop-by-hop headers].
@@ -47,7 +39,7 @@ fn is_hop_header(name: &str) -> bool {
 fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
     let mut result = HeaderMap::new();
     for (k, v) in headers.iter() {
-        if !is_hop_header(k.as_str()) {
+        if !is_hop_header(&k) {
             result.append(k.clone(), v.clone());
         }
     }
@@ -80,10 +72,7 @@ fn create_proxied_request<B>(
     *builder.headers_mut().unwrap() = remove_hop_headers(request.headers());
 
     // If request does not have host header, add it from original URI authority
-    let host_header_name = "host";
-    if let hyper::header::Entry::Vacant(entry) =
-        builder.headers_mut().unwrap().entry(host_header_name)
-    {
+    if let header::Entry::Vacant(entry) = builder.headers_mut().unwrap().entry(header::HOST) {
         if let Some(authority) = request.uri().authority() {
             entry.insert(authority.as_str().parse()?);
         }
@@ -96,11 +85,11 @@ fn create_proxied_request<B>(
         .unwrap()
         .entry(x_forwarded_for_header_name)
     {
-        hyper::header::Entry::Vacant(entry) => {
+        header::Entry::Vacant(entry) => {
             entry.insert(client_ip.to_string().parse()?);
         }
 
-        hyper::header::Entry::Occupied(mut entry) => {
+        header::Entry::Occupied(mut entry) => {
             let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
             entry.insert(addr.parse()?);
         }
@@ -112,17 +101,17 @@ fn create_proxied_request<B>(
     );
 
     // Proxy upgrade requests properly
-    if let Some(conn) = request.headers().get("connection") {
+    if let Some(conn) = request.headers().get(header::CONNECTION) {
         if conn.to_str()?.to_lowercase() == "upgrade" {
-            if let Some(upgrade) = request.headers().get("upgrade") {
-                builder.headers_mut().unwrap().insert(
-                    HeaderName::from_bytes(b"connection")?,
-                    "Upgrade".try_into()?,
-                );
+            if let Some(upgrade) = request.headers().get(header::UPGRADE) {
+                builder
+                    .headers_mut()
+                    .unwrap()
+                    .insert(header::CONNECTION, "Upgrade".try_into()?);
                 builder
                     .headers_mut()
                     .unwrap()
-                    .insert(HeaderName::from_bytes(b"upgrade")?, upgrade.clone());
+                    .insert(header::UPGRADE, upgrade.clone());
             }
         }
     }
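
With HOP_HEADERS now a slice of HeaderName constants, the hop-by-hop check becomes a plain equality scan: HeaderName comparison is effectively case-insensitive because parsed header names are stored lowercased, which is what the unicase and lazy_static dependencies previously provided. A self-contained sketch of the filtering behaviour, with an abbreviated constant list and invented header values:

```rust
use hyper::header::{self, HeaderMap, HeaderName, HeaderValue};

// Abbreviated for the example; the real list above has eight entries.
const HOP_HEADERS: &[HeaderName] = &[header::CONNECTION, header::TRANSFER_ENCODING];

fn is_hop_header(name: &HeaderName) -> bool {
    HOP_HEADERS.iter().any(|h| h == name)
}

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
    headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));

    // Connection is hop-by-hop and is dropped; Content-Type is end-to-end.
    let kept: Vec<_> = headers
        .iter()
        .filter(|(k, _)| !is_hop_header(k))
        .map(|(k, _)| k.as_str())
        .collect();
    assert_eq!(kept, vec!["content-type"]);
}
```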