use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use std::time::Duration;
use anyhow::Result;
use log::*;
use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::stream::FuturesUnordered;
use futures::{StreamExt, TryStreamExt};
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};
use opentelemetry::{metrics, KeyValue};
use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;
/// Upper bound on how long a single HTTPS connection is kept open (24 hours);
/// connections still alive after that are dropped by `serve_https`.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60 * 60 * 24);
/// Settings for the HTTPS frontend.
#[derive(Debug, Clone)]
pub struct HttpsConfig {
    /// Address the TLS listener binds to.
    pub bind_addr: SocketAddr,
    /// Whether responses may be compressed according to the client's
    /// `Accept-Encoding` header.
    pub enable_compression: bool,
    /// MIME types (without parameters, e.g. `text/html`) whose responses are
    /// eligible for compression.
    pub compress_mime_types: Vec<String>,
}
/// OpenTelemetry counters maintained by the HTTPS frontend.
struct HttpsMetrics {
    /// Incremented once per incoming request, before it is handled.
    requests_received: metrics::Counter<u64>,
    /// Incremented once per request after a response has been produced; its
    /// attributes additionally carry the response code.
    requests_served: metrics::Counter<u64>,
}
/// Accepts TLS connections on `config.bind_addr` and serves each of them with
/// hyper, routing requests according to the latest value of `rx_proxy_config`.
///
/// Runs until `must_exit` holds `true`. Then the listener is closed, every
/// open connection is asked to shut down gracefully, and the function returns
/// once all connection tasks have finished. Connections are also dropped after
/// `MAX_CONNECTION_LIFETIME`.
///
/// # Errors
/// Returns an error if the listener cannot be bound or if accepting a TCP
/// connection fails.
pub async fn serve_https(
    config: HttpsConfig,
    cert_store: Arc<CertStore>,
    rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
    mut must_exit: watch::Receiver<bool>,
) -> Result<()> {
    let config = Arc::new(config);

    let meter = opentelemetry::global::meter("tricot");
    let metrics = Arc::new(HttpsMetrics {
        requests_received: meter
            .u64_counter("https_requests_received")
            .with_description("Total number of requests received over HTTPS")
            .init(),
        requests_served: meter
            .u64_counter("https_requests_served")
            .with_description("Total number of requests served over HTTPS")
            .init(),
    });

    // Certificates are picked by `StoreResolver` (backed by `cert_store`);
    // ALPN offers HTTP/2 before HTTP/1.1.
    let mut tls_cfg = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(StoreResolver(cert_store)));
    tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

    info!("Starting to serve on https://{}.", config.bind_addr);

    let tcp = TcpListener::bind(config.bind_addr).await?;
    let mut connections = FuturesUnordered::new();
    // Once the exit sender is dropped, `changed()` fails immediately on every
    // call. Stop polling it then; otherwise this loop would spin at full CPU.
    let mut exit_watch_open = true;

    while !*must_exit.borrow() {
        // Pull finished connection tasks out of `connections` so it does not
        // grow without bound.
        let wait_conn_finished = async {
            if connections.is_empty() {
                futures::future::pending().await
            } else {
                connections.next().await
            }
        };
        let (socket, remote_addr) = select! {
            a = tcp.accept() => a?,
            _ = wait_conn_finished => continue,
            changed = must_exit.changed(), if exit_watch_open => {
                if changed.is_err() {
                    exit_watch_open = false;
                }
                continue
            }
        };

        let rx_proxy_config = rx_proxy_config.clone();
        let tls_acceptor = tls_acceptor.clone();
        let config = config.clone();
        let metrics = metrics.clone();
        let mut must_exit_2 = must_exit.clone();

        // NOTE(review): the TLS handshake below is covered neither by the
        // lifetime timeout nor by the exit signal, so a stalled handshake can
        // delay the final drain of `connections`.
        let conn = tokio::spawn(async move {
            match tls_acceptor.accept(socket).await {
                Ok(stream) => {
                    debug!("TLS handshake was successfull");
                    let http_conn = Http::new()
                        .serve_connection(
                            stream,
                            service_fn(move |req: Request<Body>| {
                                let https_config = config.clone();
                                // Each request uses the proxy configuration
                                // current at the time it arrives.
                                let proxy_config: Arc<ProxyConfig> =
                                    rx_proxy_config.borrow().clone();
                                let metrics = metrics.clone();
                                handle_outer(remote_addr, req, https_config, proxy_config, metrics)
                            }),
                        )
                        .with_upgrades();
                    let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
                    tokio::pin!(http_conn, timeout);
                    // Stop watching the exit channel once a graceful shutdown
                    // has been requested, or once the sender is gone (then
                    // `changed()` would fail immediately on every poll).
                    let mut watch_exit = true;
                    let http_result = loop {
                        select! {
                            r = &mut http_conn => break r.map_err(Into::into),
                            _ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
                            changed = must_exit_2.changed(), if watch_exit => {
                                if *must_exit_2.borrow() {
                                    http_conn.as_mut().graceful_shutdown();
                                    watch_exit = false;
                                } else if changed.is_err() {
                                    watch_exit = false;
                                }
                            }
                        }
                    };
                    if let Err(http_err) = http_result {
                        warn!("HTTP error: {}", http_err);
                    }
                }
                Err(e) => warn!("Error in TLS connection: {}", e),
            }
        });
        connections.push(conn);
    }

    drop(tcp);

    info!("HTTPS server shutting down, draining remaining connections...");
    while connections.next().await.is_some() {}

    Ok(())
}
async fn handle_outer(
remote_addr: SocketAddr,
req: Request<Body>,
https_config: Arc<HttpsConfig>,
proxy_config: Arc<ProxyConfig>,
metrics: Arc<HttpsMetrics>,
) -> Result<Response<Body>, Infallible> {
let mut tags = vec![
KeyValue::new("method", req.method().to_string()),
KeyValue::new(
"host",
req.uri()
.authority()
.map(|auth| auth.to_string())
.or_else(|| {
req.headers()
.get("host")
.map(|host| host.to_str().unwrap_or_default().to_string())
})
.unwrap_or_default(),
),
];
metrics.requests_received.add(1, &tags);
let resp = match handle(remote_addr, req, https_config, proxy_config, &mut tags).await {
Err(e) => {
warn!("Handler error: {}", e);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("{}", e)))
.unwrap()
}
Ok(r) => r,
};
tags.push(KeyValue::new(
"response_code",
resp.status().as_u16().to_string(),
));
metrics.requests_served.add(1, &tags);
Ok(resp)
}
/// Proxies one request to the best-matching entry of `proxy_config`.
///
/// The host is taken from the request URI's authority, or from the `Host`
/// header when the URI has none. Candidate entries must match that host and,
/// when they define a path prefix, the request path must start with it. The
/// winner is the candidate with the greatest
/// `(priority, prefix length, same_node, same_site, -calls)` tuple, so on
/// otherwise equal entries the one with the fewest recorded calls is used.
///
/// Details of the chosen entry are appended to `tags`. The backend's response
/// (or a 502 when the call fails) receives the entry's extra headers if it is
/// a success, and is then compressed when enabled. If no entry matches, a 404
/// is returned.
///
/// # Errors
/// Fails when no host can be determined, when a configured extra header is
/// not a valid header name or value, or when compression fails.
async fn handle(
    remote_addr: SocketAddr,
    req: Request<Body>,
    https_config: Arc<HttpsConfig>,
    proxy_config: Arc<ProxyConfig>,
    tags: &mut Vec<KeyValue>,
) -> Result<Response<Body>, anyhow::Error> {
    let method = req.method().clone();
    let uri = req.uri().to_string();
    let host = match req.uri().authority() {
        Some(authority) => authority.as_str(),
        None => req
            .headers()
            .get("host")
            .ok_or_else(|| anyhow!("Missing host header"))?
            .to_str()?,
    };
    let path = req.uri().path();
    let accept_encoding =
        accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]);

    let best_match = proxy_config
        .entries
        .iter()
        .filter(|entry| {
            entry.host.matches(host)
                && entry
                    .path_prefix
                    .as_ref()
                    .map_or(true, |prefix| path.starts_with(prefix))
        })
        .max_by_key(|entry| {
            let prefix_len = entry
                .path_prefix
                .as_ref()
                .map_or(0, |prefix| prefix.len() as i32);
            (
                entry.priority,
                prefix_len,
                entry.same_node,
                entry.same_site,
                -entry.calls.load(Ordering::SeqCst),
            )
        });

    let proxy_to = match best_match {
        Some(entry) => entry,
        None => {
            debug!("{}{} -> NOT FOUND", host, path);
            info!("{} 404 {}", method, uri);
            return Ok(Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body(Body::from("No matching proxy entry"))?);
        }
    };

    tags.extend([
        KeyValue::new("service_name", proxy_to.service_name.clone()),
        KeyValue::new("target_addr", proxy_to.target_addr.to_string()),
        KeyValue::new("https_target", proxy_to.https_target.to_string()),
        KeyValue::new("same_node", proxy_to.same_node.to_string()),
        KeyValue::new("same_site", proxy_to.same_site.to_string()),
    ]);

    proxy_to.calls.fetch_add(1, Ordering::SeqCst);

    debug!("{}{} -> {}", host, path, proxy_to);
    trace!("Request: {:?}", req);

    let use_https = proxy_to.https_target;
    let scheme = if use_https { "https" } else { "http" };
    let to_addr = format!("{}://{}", scheme, proxy_to.target_addr);
    let backend_result = if use_https {
        reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await
    } else {
        reverse_proxy::call(remote_addr.ip(), &to_addr, req).await
    };
    let mut response = handle_error(backend_result);

    // Extra headers are only attached to successful responses.
    // (TODO: maybe we want to add these headers even if it's not a success?)
    if response.status().is_success() {
        for (name, value) in proxy_to.add_headers.iter() {
            let header_name = HeaderName::from_bytes(name.as_bytes())?;
            let header_value = HeaderValue::from_str(value)?;
            response.headers_mut().insert(header_name, header_value);
        }
    }

    if https_config.enable_compression {
        response = try_compress(response, method.clone(), accept_encoding, &https_config).await?;
    }

    trace!("Final response: {:?}", response);
    info!("{} {} {}", method, response.status().as_u16(), uri);
    Ok(response)
}
fn handle_error(resp: Result<Response<Body>>) -> Response<Body> {
match resp {
Ok(resp) => resp,
Err(e) => Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Proxy error: {}", e)))
.unwrap(),
}
}
/// Compresses `response` with the client's preferred supported encoding when
/// that is worthwhile.
///
/// The response is returned untouched when:
/// - it answers a HEAD or PUT request successfully (its body should be empty),
/// - it is `206 Partial Content` or any non-2xx response,
/// - it carries `Connection: Upgrade` (e.g. websockets),
/// - it already has a `Content-Encoding`,
/// - the client accepts none of zstd, deflate or gzip at its highest q value
///   (entries with `q=0` mean "not acceptable" and are ignored),
/// - its `Content-Type` is missing, not valid header text, or its media type
///   (parameters stripped, ASCII case ignored) is not in `compress_mime_types`,
/// - its body is shorter than 1400 bytes.
///
/// Otherwise the body is streamed through the compressor. `Content-Length` is
/// removed, `Content-Encoding` is set, and `Accept-Encoding` is added to
/// `Vary`. The `Vary` entry keeps shared caches from serving the compressed
/// body to clients that did not ask for it.
///
/// # Errors
/// Fails if reading the first body chunks from the backend fails.
async fn try_compress(
    response: Response<Body>,
    method: Method,
    accept_encoding: Vec<(Option<Encoding>, f32)>,
    https_config: &HttpsConfig,
) -> Result<Response<Body>> {
    // Bodies smaller than this are not worth compressing.
    const MIN_COMPRESS_SIZE: usize = 1400;

    // Don't bother compressing successful responses for HEAD and PUT (they should have an empty body)
    // Don't compress partial content as it causes issues
    // Don't bother compressing non-2xx results
    // Don't compress Upgrade responses (e.g. websockets)
    // Don't compress responses that are already compressed
    let status = response.status();
    if (status.is_success() && (method == Method::HEAD || method == Method::PUT))
        || status == StatusCode::PARTIAL_CONTENT
        || !status.is_success()
        || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
        || response.headers().get(header::CONTENT_ENCODING).is_some()
    {
        return Ok(response);
    }

    // Select the encoding: among the entries with the highest q value, take the
    // first one in `preference` order. Brotli is intentionally not offered.
    let preference = [Encoding::Zstd, Encoding::Deflate, Encoding::Gzip];
    // q=0 marks an encoding as explicitly not acceptable (RFC 9110 §12.5.3),
    // so such entries must never be selected.
    let acceptable: Vec<(Option<Encoding>, f32)> = accept_encoding
        .into_iter()
        .filter(|(_, q)| *q > 0.0)
        .collect();
    let max_q = acceptable
        .iter()
        .map(|(_, q)| *q)
        .fold(f32::NEG_INFINITY, f32::max);
    // max_q is one of the listed q values, so exact float equality is intended.
    #[allow(clippy::float_cmp)]
    let encoding_opt = acceptable
        .iter()
        .filter(|(_, q)| *q == max_q)
        .filter_map(|(enc, _)| *enc)
        .filter(|enc| preference.contains(enc))
        .min_by_key(|enc| preference.iter().position(|x| x == enc));
    let encoding = match encoding_opt {
        Some(enc) => enc,
        None => return Ok(response),
    };

    // Only compress media types listed in the configuration. A Content-Type that
    // is absent or not valid header text means "don't compress", not an error.
    let compressible_type = match response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|ct| ct.to_str().ok())
    {
        Some(ct) => {
            let mime_type = ct.split(';').next().unwrap_or(ct).trim();
            https_config
                .compress_mime_types
                .iter()
                .any(|x| x.eq_ignore_ascii_case(mime_type))
        }
        None => false,
    };
    if !compressible_type {
        return Ok(response);
    }

    let (mut head, mut body) = response.into_parts();

    // ---- If body is smaller than MIN_COMPRESS_SIZE bytes, don't compress ----
    let mut chunks = vec![];
    let mut sum_lengths = 0;
    while sum_lengths < MIN_COMPRESS_SIZE {
        match body.next().await {
            Some(chunk) => {
                let chunk = chunk?;
                sum_lengths += chunk.len();
                chunks.push(chunk);
            }
            None => {
                return Ok(Response::from_parts(head, Body::from(chunks.concat())));
            }
        }
    }

    // Put the already-read chunks back in front of the remaining body and make
    // an async reader from that for the compressor.
    let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);
    let body_rd =
        StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

    trace!(
        "Compressing response body as {:?} (at least {} bytes)",
        encoding,
        sum_lengths
    );

    // The compressed length is unknown up front, so drop Content-Length.
    head.headers.remove(header::CONTENT_LENGTH);

    let (encoding_name, compressed_body) = match encoding {
        Encoding::Gzip => (
            "gzip",
            Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
        ),
        Encoding::Deflate => (
            "deflate",
            Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
        ),
        Encoding::Zstd => (
            "zstd",
            Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
        ),
        _ => unreachable!("encoding is restricted to the preference list"),
    };

    head.headers
        .insert(header::CONTENT_ENCODING, HeaderValue::from_static(encoding_name));

    // The body now depends on the request's Accept-Encoding; tell caches so
    // (RFC 9110 §12.5.5), unless Vary already covers it.
    let varies_on_encoding = head
        .headers
        .get_all(header::VARY)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|v| v.split(','))
        .map(str::trim)
        .any(|v| v == "*" || v.eq_ignore_ascii_case("accept-encoding"));
    if !varies_on_encoding {
        head.headers
            .append(header::VARY, HeaderValue::from_static("Accept-Encoding"));
    }

    Ok(Response::from_parts(head, compressed_body))
}