use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use std::time::{Duration, Instant};
use anyhow::Result;
use log::*;
use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::stream::FuturesUnordered;
use futures::{StreamExt, TryStreamExt};
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};
use opentelemetry::{metrics, KeyValue};
use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::{ProxyConfig, ProxyEntry};
use crate::reverse_proxy;
/// Hard upper bound on how long a single client connection is served (24 hours).
/// `serve_https` drops any connection that is still open when this elapses.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 60 * 60);
/// Settings for the HTTPS front end started by [`serve_https`].
#[derive(Debug, Clone)]
pub struct HttpsConfig {
    /// Address the HTTPS listener binds to.
    pub bind_addr: SocketAddr,
    /// Whether backend responses may be compressed on the fly according to the
    /// client's `Accept-Encoding`.
    pub enable_compression: bool,
    /// Media types (without parameters, e.g. `text/html`) whose responses may be
    /// compressed when `enable_compression` is set.
    pub compress_mime_types: Vec<String>,
    /// Reference point used internally to store `Instant`s as `i64` milliseconds
    /// elapsed since this origin (e.g. a proxy entry's `last_call`).
    pub time_origin: Instant,
}
/// OpenTelemetry instruments shared (through an `Arc`) by every connection that
/// `serve_https` spawns.
struct HttpsMetrics {
    /// `https_requests_received`: +1 per incoming request, tagged with `host` and
    /// `method`.
    requests_received: metrics::Counter<u64>,
    /// `https_requests_served`: +1 per response, tagged with `method`, `status_code`
    /// and, when a backend was selected, the backend tags.
    requests_served: metrics::Counter<u64>,
    /// `https_request_proxy_duration`: seconds from request receipt until the backend
    /// returned its status and headers (only recorded for proxied requests).
    request_proxy_duration: metrics::ValueRecorder<f64>,
}
/// Accepts TLS connections on `config.bind_addr` and serves HTTP/1.1 and HTTP/2 on
/// them, routing every request through `handle_request`.
///
/// Certificates are resolved per connection from `cert_store`. The routing table is
/// read from `rx_proxy_config` for every request, so configuration updates also apply
/// to already-open connections. A connection still open after
/// [`MAX_CONNECTION_LIFETIME`] is dropped.
///
/// When `must_exit` becomes `true`, the listener is closed and every open connection
/// is asked to shut down gracefully. The function then returns once all of them have
/// finished.
///
/// # Errors
///
/// Fails if the listening socket cannot be bound or if accepting a connection fails.
pub async fn serve_https(
    config: HttpsConfig,
    cert_store: Arc<CertStore>,
    rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
    mut must_exit: watch::Receiver<bool>,
) -> Result<()> {
    let config = Arc::new(config);

    let meter = opentelemetry::global::meter("tricot");
    let metrics = Arc::new(HttpsMetrics {
        requests_received: meter
            .u64_counter("https_requests_received")
            .with_description("Total number of requests received over HTTPS")
            .init(),
        requests_served: meter
            .u64_counter("https_requests_served")
            .with_description("Total number of requests served over HTTPS")
            .init(),
        request_proxy_duration: meter
            .f64_value_recorder("https_request_proxy_duration")
            .with_description("Duration between time when request was received, and time when backend returned status code and headers")
            .init(),
    });

    let mut tls_cfg = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(StoreResolver(cert_store)));
    tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

    info!("Starting to serve on https://{}.", config.bind_addr);

    let tcp = TcpListener::bind(config.bind_addr).await?;
    let mut connections = FuturesUnordered::new();

    // Once the sender side of `must_exit` is dropped, `changed()` resolves immediately
    // with `Err` on every call; stop polling it then, or this loop would spin.
    let mut exit_signal_alive = true;

    while !*must_exit.borrow() {
        // Reap finished connection tasks so `connections` does not grow forever.
        let wait_conn_finished = async {
            if connections.is_empty() {
                futures::future::pending().await
            } else {
                connections.next().await
            }
        };
        let (socket, remote_addr) = select! {
            a = tcp.accept() => a?,
            _ = wait_conn_finished => continue,
            changed = must_exit.changed(), if exit_signal_alive => {
                exit_signal_alive = changed.is_ok();
                continue
            }
        };

        let rx_proxy_config = rx_proxy_config.clone();
        let tls_acceptor = tls_acceptor.clone();
        let config = config.clone();
        let metrics = metrics.clone();
        let mut must_exit_2 = must_exit.clone();

        let conn = tokio::spawn(async move {
            match tls_acceptor.accept(socket).await {
                Ok(stream) => {
                    debug!("TLS handshake was successfull");
                    let http_conn = Http::new()
                        .serve_connection(
                            stream,
                            service_fn(move |req: Request<Body>| {
                                let https_config = config.clone();
                                // Snapshot the routing table for this request only.
                                let proxy_config: Arc<ProxyConfig> =
                                    rx_proxy_config.borrow().clone();
                                let metrics = metrics.clone();
                                handle_request(
                                    remote_addr,
                                    req,
                                    https_config,
                                    proxy_config,
                                    metrics,
                                )
                            }),
                        )
                        .with_upgrades();

                    let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
                    tokio::pin!(http_conn, timeout);

                    // Same reasoning as for the accept loop: a dropped sender must not
                    // turn this loop into a busy spin.
                    let mut exit_signal_alive = true;
                    let http_result = loop {
                        select! {
                            r = &mut http_conn => break r.map_err(Into::into),
                            _ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
                            changed = must_exit_2.changed(), if exit_signal_alive => {
                                if changed.is_err() {
                                    exit_signal_alive = false;
                                } else if *must_exit_2.borrow() {
                                    // Let in-flight requests complete, then close.
                                    http_conn.as_mut().graceful_shutdown();
                                }
                            }
                        }
                    };

                    if let Err(http_err) = http_result {
                        warn!("HTTP error: {}", http_err);
                    }
                }
                Err(e) => warn!("Error in TLS connection: {}", e),
            }
        });
        connections.push(conn);
    }

    drop(tcp);

    info!("HTTPS server shutting down, draining remaining connections...");
    while connections.next().await.is_some() {}

    Ok(())
}
/// Hyper service entry point for one request.
///
/// Counts the request in `requests_received` and delegates to
/// `select_target_and_proxy`. It then counts the resulting response in
/// `requests_served`, tagged with its status code. Never fails: problems are reported
/// to the client as HTTP error responses.
async fn handle_request(
    remote_addr: SocketAddr,
    req: Request<Body>,
    https_config: Arc<HttpsConfig>,
    proxy_config: Arc<ProxyConfig>,
    metrics: Arc<HttpsMetrics>,
) -> Result<Response<Body>, Infallible> {
    let method_tag = KeyValue::new("method", req.method().to_string());

    // Prefer the URI authority, falling back to the Host header. Only
    // requests_received carries this tag: on the other metrics it could easily lead
    // to cardinality explosions.
    let host = match req.uri().authority() {
        Some(authority) => authority.to_string(),
        None => req
            .headers()
            .get("host")
            .map(|value| value.to_str().unwrap_or_default().to_string())
            .unwrap_or_default(),
    };
    metrics
        .requests_received
        .add(1, &[KeyValue::new("host", host), method_tag.clone()]);

    // select_target_and_proxy appends backend tags to this list.
    let mut tags = vec![method_tag];
    let response = select_target_and_proxy(
        &https_config,
        &proxy_config,
        &metrics,
        remote_addr,
        req,
        &mut tags,
    )
    .await;

    let status_code = i64::from(response.status().as_u16());
    tags.push(KeyValue::new("status_code", status_code));
    metrics.requests_served.add(1, &tags);

    Ok(response)
}
// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn select_target_and_proxy(
https_config: &HttpsConfig,
proxy_config: &ProxyConfig,
metrics: &HttpsMetrics,
remote_addr: SocketAddr,
req: Request<Body>,
tags: &mut Vec<KeyValue>,
) -> Response<Body> {
let received_time = Instant::now();
let method = req.method().clone();
let uri = req.uri().to_string();
let host = if let Some(auth) = req.uri().authority() {
auth.as_str()
} else {
match req.headers().get("host").and_then(|x| x.to_str().ok()) {
Some(host) => host,
None => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Missing Host header"))
.unwrap();
}
}
};
let path = req.uri().path();
let best_match = proxy_config
.entries
.iter()
.filter(|ent| {
ent.host.matches(host)
&& ent
.path_prefix
.as_ref()
.map(|prefix| path.starts_with(prefix))
.unwrap_or(true)
})
.max_by_key(|ent| {
(
ent.priority,
ent.path_prefix
.as_ref()
.map(|x| x.len() as i32)
.unwrap_or(0),
ent.same_node,
ent.same_site,
-ent.calls_in_progress.load(Ordering::SeqCst),
-ent.last_call.load(Ordering::SeqCst),
)
});
if let Some(proxy_to) = best_match {
tags.push(KeyValue::new("service", proxy_to.service_name.clone()));
tags.push(KeyValue::new(
"target_addr",
proxy_to.target_addr.to_string(),
));
tags.push(KeyValue::new("same_node", proxy_to.same_node));
tags.push(KeyValue::new("same_site", proxy_to.same_site));
proxy_to.last_call.fetch_max(
(received_time - https_config.time_origin).as_millis() as i64,
Ordering::Relaxed,
);
proxy_to.calls_in_progress.fetch_add(1, Ordering::SeqCst);
debug!("{}{} -> {}", host, path, proxy_to);
trace!("Request: {:?}", req);
let response = match do_proxy(https_config, remote_addr, req, proxy_to).await {
Ok(resp) => resp,
Err(e) => Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Proxy error: {}", e)))
.unwrap(),
};
proxy_to.calls_in_progress.fetch_sub(1, Ordering::SeqCst);
metrics
.request_proxy_duration
.record(received_time.elapsed().as_secs_f64(), tags);
trace!("Final response: {:?}", response);
info!("{} {} {}", method, response.status().as_u16(), uri);
response
} else {
debug!("{}{} -> NOT FOUND", host, path);
info!("{} 404 {}", method, uri);
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("No matching proxy entry"))
.unwrap()
}
}
/// Forwards `req` to `proxy_to`'s backend and post-processes the backend response.
///
/// The backend is reached over HTTPS when `proxy_to.https_target` is set, and over
/// plain HTTP otherwise; `remote_addr`'s IP is handed to `reverse_proxy`. On a `2xx`
/// response, the entry's `add_headers` are inserted, replacing same-named headers.
/// Finally, if `https_config.enable_compression` is set, the response goes through
/// `try_compress`.
///
/// # Errors
///
/// Fails if the backend call fails, if an `add_headers` entry is not a valid header
/// name or value, or if compression fails.
async fn do_proxy(
    https_config: &HttpsConfig,
    remote_addr: SocketAddr,
    req: Request<Body>,
    proxy_to: &ProxyEntry,
) -> Result<Response<Body>> {
    // `req` is moved into the backend call, so grab what compression needs first.
    // An unparsable Accept-Encoding is treated as if none had been sent.
    let method = req.method().clone();
    let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_default();

    let client_ip = remote_addr.ip();
    let mut response = if proxy_to.https_target {
        let backend_url = format!("https://{}", proxy_to.target_addr);
        reverse_proxy::call_https(client_ip, &backend_url, req).await?
    } else {
        let backend_url = format!("http://{}", proxy_to.target_addr);
        reverse_proxy::call(client_ip, &backend_url, req).await?
    };

    // TODO: maybe these headers should be added even if it's not a success?
    if response.status().is_success() {
        let headers = response.headers_mut();
        for (name, value) in proxy_to.add_headers.iter() {
            let name = HeaderName::from_bytes(name.as_bytes())?;
            let value = HeaderValue::from_str(value)?;
            headers.insert(name, value);
        }
    }

    if !https_config.enable_compression {
        return Ok(response);
    }
    try_compress(response, method, accept_encoding, https_config).await
}
/// Responses whose whole body is shorter than this many bytes are sent uncompressed.
const MIN_COMPRESSIBLE_LEN: usize = 1400;

/// Compresses `response`'s body on the fly when the request, the response and the
/// configuration all allow it; otherwise returns `response` with an unchanged body.
///
/// The body is left as is when:
/// - the response is not a `2xx`, is `206 Partial Content`, answers a `HEAD` or
///   `PUT`, is a connection upgrade, or already has a `Content-Encoding`;
/// - the client accepts none of the supported encodings (see [`select_encoding`]);
/// - the `Content-Type` is missing, is not visible ASCII, or is not listed in
///   `https_config.compress_mime_types`;
/// - the whole body is shorter than [`MIN_COMPRESSIBLE_LEN`] bytes. In that case it
///   has already been buffered and is returned in one piece.
///
/// When compressing, three header changes are made:
/// - `Content-Length` is removed, since the compressed size is not known up front;
/// - `Content-Encoding` is set;
/// - `Vary: Accept-Encoding` is appended, so that caches do not hand the compressed
///   body to clients that did not ask for it.
///
/// # Errors
///
/// Fails if reading the first chunks of the backend body fails.
async fn try_compress(
    response: Response<Body>,
    method: Method,
    accept_encoding: Vec<(Option<Encoding>, f32)>,
    https_config: &HttpsConfig,
) -> Result<Response<Body>> {
    let status = response.status();
    // Skip compression in these cases:
    // - successful HEAD/PUT responses should have an empty body;
    // - partial content causes issues when compressed;
    // - non-2xx results are not worth it;
    // - upgraded connections (e.g. websockets) must not be touched;
    // - already-encoded bodies must not be encoded twice.
    if (status.is_success() && (method == Method::HEAD || method == Method::PUT))
        || status == StatusCode::PARTIAL_CONTENT
        || !status.is_success()
        || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
        || response.headers().contains_key(header::CONTENT_ENCODING)
    {
        return Ok(response);
    }

    let encoding = match select_encoding(&accept_encoding) {
        None | Some(Encoding::Identity) => return Ok(response),
        Some(enc) => enc,
    };

    let compressible = match response
        .headers()
        .get(header::CONTENT_TYPE)
        .map(HeaderValue::to_str)
    {
        Some(Ok(ct)) => is_compressible_mime_type(ct, &https_config.compress_mime_types),
        // Unknown type, or one that is not visible ASCII and so cannot match the
        // configured list: leave the response alone rather than failing the request.
        _ => false,
    };
    if !compressible {
        return Ok(response);
    }

    let (mut head, mut body) = response.into_parts();

    // Buffer the beginning of the body. If it ends before MIN_COMPRESSIBLE_LEN bytes,
    // compression is not worth it and the buffered bytes are returned as is.
    let mut chunks = vec![];
    let mut sum_lengths = 0;
    while sum_lengths < MIN_COMPRESSIBLE_LEN {
        match body.next().await {
            Some(chunk) => {
                let chunk = chunk?;
                sum_lengths += chunk.len();
                chunks.push(chunk);
            }
            None => {
                return Ok(Response::from_parts(head, Body::from(chunks.concat())));
            }
        }
    }

    // Put the buffered chunks back in front of the rest of the body, and turn the
    // whole stream into an async reader for the encoder.
    let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);
    let body_rd =
        StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

    trace!(
        "Compressing response body as {:?} (at least {} bytes)",
        encoding,
        sum_lengths
    );

    // The compressed content-length is not known up front.
    head.headers.remove(header::CONTENT_LENGTH);

    let (encoding, compressed_body) = match encoding {
        Encoding::Gzip => (
            "gzip",
            Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
        ),
        // Encoding::Brotli => (
        //     "br",
        //     Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))),
        // ),
        Encoding::Deflate => (
            "deflate",
            Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
        ),
        Encoding::Zstd => (
            "zstd",
            Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
        ),
        // select_encoding only ever returns one of the encodings handled above.
        _ => unreachable!(),
    };
    head.headers
        .insert(header::CONTENT_ENCODING, HeaderValue::from_static(encoding));
    // The body now depends on the request's Accept-Encoding (RFC 7231 §7.1.4).
    head.headers
        .append(header::VARY, HeaderValue::from_static("Accept-Encoding"));

    Ok(Response::from_parts(head, compressed_body))
}

/// Chooses the content-coding for a response from a parsed `Accept-Encoding` list.
///
/// Only entries carrying the highest q-value are considered. Among those, the first
/// supported encoding in the order zstd, deflate, gzip wins. Entries with `q=0` are
/// ignored, because they mean "not acceptable" (RFC 7231 §5.3.1); `gzip;q=0` alone
/// therefore never selects gzip. Returns `None` when nothing suitable is found.
// Exact float comparison is intended: `max_q` is one of the compared values.
#[allow(clippy::float_cmp)]
fn select_encoding(accept_encoding: &[(Option<Encoding>, f32)]) -> Option<Encoding> {
    let preference = [
        Encoding::Zstd,
        //Encoding::Brotli,
        Encoding::Deflate,
        Encoding::Gzip,
    ];
    let acceptable = || accept_encoding.iter().filter(|(_, q)| *q > 0.);
    let max_q: f32 = acceptable()
        .max_by_key(|(_, q)| (q * 10000f32) as i64)
        .map_or(1., |(_, q)| *q);
    acceptable()
        .filter(|(_, q)| *q == max_q)
        .filter_map(|(enc, _)| *enc)
        .filter(|enc| preference.contains(enc))
        .min_by_key(|enc| {
            preference
                .iter()
                .position(|x| x == enc)
                .expect("filtered on preference.contains above")
        })
}

/// Returns whether the media type of a `Content-Type` value (with parameters such as
/// `charset` stripped) is listed in `compress_mime_types`. Surrounding whitespace is
/// ignored and the comparison is case-insensitive, since media types are
/// case-insensitive in HTTP (RFC 7231 §3.1.1.1).
fn is_compressible_mime_type(content_type: &str, compress_mime_types: &[String]) -> bool {
    let mime_type = content_type
        .split_once(';')
        .map_or(content_type, |(mime_type, _params)| mime_type)
        .trim();
    compress_mime_types
        .iter()
        .any(|candidate| candidate.trim().eq_ignore_ascii_case(mime_type))
}