use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use anyhow::Result;
use log::*;
use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::StreamExt;
use futures::TryStreamExt;
use http::header::{HeaderName, HeaderValue};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};
use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;
/// Settings for the HTTPS front end started by `serve_https`.
pub struct HttpsConfig {
/// Socket address the TCP listener is bound to.
pub bind_addr: SocketAddr,
/// When `true`, proxied responses are passed through `try_compress`
/// before being sent back to the client.
pub enable_compression: bool,
/// Media types (the part of `Content-Type` before any `;` parameters)
/// whose responses may be compressed; other responses are left untouched.
pub compress_mime_types: Vec<String>,
}
pub async fn serve_https(
config: HttpsConfig,
cert_store: Arc<CertStore>,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
) -> Result<()> {
let config = Arc::new(config);
let mut tls_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(StoreResolver(cert_store)));
tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));
info!("Starting to serve on https://{}.", config.bind_addr);
let tcp = TcpListener::bind(config.bind_addr).await?;
loop {
let (socket, remote_addr) = tcp.accept().await?;
let rx_proxy_config = rx_proxy_config.clone();
let tls_acceptor = tls_acceptor.clone();
let config = config.clone();
tokio::spawn(async move {
match tls_acceptor.accept(socket).await {
Ok(stream) => {
debug!("TLS handshake was successfull");
let http_result = Http::new()
.serve_connection(
stream,
service_fn(move |req: Request<Body>| {
let https_config = config.clone();
let proxy_config: Arc<ProxyConfig> =
rx_proxy_config.borrow().clone();
handle_outer(remote_addr, req, https_config, proxy_config)
}),
)
.await;
if let Err(http_err) = http_result {
warn!("HTTP error: {}", http_err);
}
}
Err(e) => warn!("Error in TLS connection: {}", e),
}
});
}
}
/// Infallible wrapper around [`handle`], suitable for use as a hyper service.
///
/// A successful response is forwarded as is. If `handle` fails, the error is
/// logged and turned into a `500 Internal Server Error` response whose body
/// is the error's display text.
async fn handle_outer(
    remote_addr: SocketAddr,
    req: Request<Body>,
    https_config: Arc<HttpsConfig>,
    proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, Infallible> {
    let outcome = handle(remote_addr, req, https_config, proxy_config).await;
    let response = outcome.unwrap_or_else(|e| {
        warn!("Handler error: {}", e);
        Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(Body::from(e.to_string()))
            .unwrap()
    });
    Ok(response)
}
/// Routes one request to the best matching proxy entry and relays the answer.
///
/// The host is taken from the URI authority when present, otherwise from the
/// `Host` header. Among the entries of `proxy_config` whose host matches and
/// whose path prefix (if any) is a prefix of the request path, the winner is
/// the one with the highest priority, then the longest path prefix, then the
/// lowest call counter. The request is forwarded to that entry's target over
/// HTTP or HTTPS, the entry's extra headers are set on the response, and the
/// response is optionally compressed. When no entry matches, a 404 response
/// is returned.
///
/// # Errors
/// Fails when the host cannot be determined, when the upstream call fails,
/// when a configured extra header is not a valid header name/value, or when
/// compression fails.
async fn handle(
    remote_addr: SocketAddr,
    req: Request<Body>,
    https_config: Arc<HttpsConfig>,
    proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
    // Captured up front for the access log: `req` is consumed by the proxy call.
    let method = req.method().clone();
    let uri = req.uri().to_string();

    let host = match req.uri().authority() {
        Some(auth) => auth.as_str(),
        None => req
            .headers()
            .get("host")
            .ok_or_else(|| anyhow!("Missing host header"))?
            .to_str()?,
    };
    let path = req.uri().path();
    // An unparsable Accept-Encoding header is treated as "no encodings listed".
    let accept_encoding =
        accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| Vec::new());

    let best_match = proxy_config
        .entries
        .iter()
        .filter(|ent| {
            let prefix_ok = ent
                .path_prefix
                .as_ref()
                .map_or(true, |prefix| path.starts_with(prefix));
            ent.host.matches(host) && prefix_ok
        })
        .max_by_key(|ent| {
            let prefix_len = ent.path_prefix.as_ref().map_or(0, |x| x.len() as i32);
            // Negated so that the entry with the fewest calls ranks highest.
            let fewest_calls_first = -ent.calls.load(Ordering::SeqCst);
            (ent.priority, prefix_len, fewest_calls_first)
        });

    let proxy_to = match best_match {
        Some(entry) => entry,
        None => {
            debug!("{}{} -> NOT FOUND", host, path);
            info!("{} 404 {}", method, uri);
            return Ok(Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body(Body::from("No matching proxy entry"))?);
        }
    };

    proxy_to.calls.fetch_add(1, Ordering::SeqCst);
    debug!("{}{} -> {}", host, path, proxy_to);
    trace!("Request: {:?}", req);

    let mut response = if proxy_to.https_target {
        let to_addr = format!("https://{}", proxy_to.target_addr);
        reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await?
    } else {
        let to_addr = format!("http://{}", proxy_to.target_addr);
        reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?
    };

    // Headers configured on the entry replace same-named upstream headers.
    for (name, value) in proxy_to.add_headers.iter() {
        let name = HeaderName::from_bytes(name.as_bytes())?;
        let value = HeaderValue::from_str(value)?;
        response.headers_mut().insert(name, value);
    }
    trace!("Response: {:?}", response);
    info!("{} {} {}", method, response.status().as_u16(), uri);

    if !https_config.enable_compression {
        return Ok(response);
    }
    try_compress(response, accept_encoding, &https_config).await
}
async fn try_compress(
response: Response<Body>,
accept_encoding: Vec<(Option<Encoding>, f32)>,
https_config: &HttpsConfig,
) -> Result<Response<Body>> {
let max_q: f32 = accept_encoding
.iter()
.max_by_key(|(_, q)| (q * 10000f32) as i64)
.unwrap_or(&(None, 1.))
.1;
let preference = [
Encoding::Zstd,
Encoding::Brotli,
Encoding::Deflate,
Encoding::Gzip,
];
let encoding_opt = accept_encoding
.iter()
.filter(|(_, q)| *q == max_q)
.filter_map(|(enc, _)| *enc)
.filter(|enc| preference.contains(enc))
.min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());
// If preferred encoding is none, return as is
let encoding = match encoding_opt {
None | Some(Encoding::Identity) => return Ok(response),
Some(enc) => enc,
};
// If already compressed, return as is
if response.headers().get(header::CONTENT_ENCODING).is_some() {
return Ok(response);
}
// If content type not in mime types for which to compress, return as is
match response.headers().get(header::CONTENT_TYPE) {
Some(ct) => {
let ct_str = ct.to_str()?;
let mime_type = match ct_str.split_once(';') {
Some((mime_type, _params)) => mime_type,
None => ct_str,
};
if !https_config
.compress_mime_types
.iter()
.any(|x| x == mime_type)
{
return Ok(response);
}
}
None => return Ok(response),
};
let (mut head, mut body) = response.into_parts();
// ---- If body is smaller than 1400 bytes, don't compress ----
let mut chunks = vec![];
let mut sum_lengths = 0;
while sum_lengths < 1400 {
match body.next().await {
Some(chunk) => {
let chunk = chunk?;
sum_lengths += chunk.len();
chunks.push(chunk);
}
None => {
return Ok(Response::from_parts(head, Body::from(chunks.concat())));
}
}
}
// put beginning chunks back into body
let body = futures::stream::iter(chunks.into_iter().map(|c| Ok(c))).chain(body);
// make an async reader from that for compressor
let body_rd =
StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
debug!("Compressing response body as {:?} (at least {} bytes)", encoding, sum_lengths);
head.headers.remove(header::CONTENT_LENGTH);
let compressed_body = match encoding {
Encoding::Gzip => {
head.headers
.insert(header::CONTENT_ENCODING, "gzip".parse()?);
Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd)))
}
Encoding::Brotli => {
head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
}
Encoding::Deflate => {
head.headers
.insert(header::CONTENT_ENCODING, "deflate".parse()?);
Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd)))
}
Encoding::Zstd => {
head.headers
.insert(header::CONTENT_ENCODING, "zstd".parse()?);
Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd)))
}
_ => unreachable!(),
};
Ok(Response::from_parts(head, compressed_body))
}