use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use std::time::{Duration, Instant}; use anyhow::Result; use tracing::*; use accept_encoding_fork::Encoding; use async_compression::tokio::bufread::*; use futures::stream::FuturesUnordered; use futures::{StreamExt, TryStreamExt}; use http::header::{HeaderName, HeaderValue}; use http::method::Method; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::select; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; use tokio_util::io::{ReaderStream, StreamReader}; use opentelemetry::{metrics, KeyValue}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::{ProxyConfig, ProxyEntry}; use crate::reverse_proxy; const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 3600); pub struct HttpsConfig { pub bind_addr: SocketAddr, pub enable_compression: bool, pub compress_mime_types: Vec, // used internally to convert Instants to u64 pub time_origin: Instant, } struct HttpsMetrics { requests_received: metrics::Counter, requests_served: metrics::Counter, request_proxy_duration: metrics::ValueRecorder, } pub async fn serve_https( config: HttpsConfig, cert_store: Arc, rx_proxy_config: watch::Receiver>, mut must_exit: watch::Receiver, ) -> Result<()> { let config = Arc::new(config); let meter = opentelemetry::global::meter("tricot"); let metrics = Arc::new(HttpsMetrics { requests_received: meter .u64_counter("https_requests_received") .with_description("Total number of requests received over HTTPS") .init(), requests_served: meter .u64_counter("https_requests_served") .with_description("Total number of requests served over HTTPS") .init(), request_proxy_duration: meter .f64_value_recorder("https_request_proxy_duration") .with_description("Duration between time when request was received, and time when backend returned status code and headers") .init(), }); let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); info!("Starting to serve on https://{}.", config.bind_addr); let tcp = TcpListener::bind(config.bind_addr).await?; let mut connections = FuturesUnordered::new(); while !*must_exit.borrow() { let wait_conn_finished = async { if connections.is_empty() { futures::future::pending().await } else { connections.next().await } }; let (socket, remote_addr) = select! { a = tcp.accept() => a?, _ = wait_conn_finished => continue, _ = must_exit.changed() => continue, }; let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); let config = config.clone(); let metrics = metrics.clone(); let mut must_exit_2 = must_exit.clone(); let conn = tokio::spawn(async move { match tls_acceptor.accept(socket).await { Ok(stream) => { debug!("TLS handshake was successfull"); let http_conn = Http::new() .serve_connection( stream, service_fn(move |req: Request| { let https_config = config.clone(); let proxy_config: Arc = rx_proxy_config.borrow().clone(); let metrics = metrics.clone(); handle_request( remote_addr, req, https_config, proxy_config, metrics, ) }), ) .with_upgrades(); let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME); tokio::pin!(http_conn, timeout); let http_result = loop { select! ( r = &mut http_conn => break r.map_err(Into::into), _ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")), _ = must_exit_2.changed() => { if *must_exit_2.borrow() { http_conn.as_mut().graceful_shutdown(); } } ) }; if let Err(http_err) = http_result { warn!("HTTP error: {}", http_err); } } Err(e) => warn!("Error in TLS connection: {}", e), } }); connections.push(conn); } drop(tcp); info!("HTTPS server shutting down, draining remaining connections..."); while connections.next().await.is_some() {} Ok(()) } async fn handle_request( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, metrics: Arc, ) -> Result, Infallible> { let method_tag = KeyValue::new("method", req.method().to_string()); // The host tag is only included in the requests_received metric, // as for other metrics it can easily lead to cardinality explosions. let host_tag = KeyValue::new( "host", req.uri() .authority() .map(|auth| auth.to_string()) .or_else(|| { req.headers() .get("host") .map(|host| host.to_str().unwrap_or_default().to_string()) }) .unwrap_or_default(), ); metrics .requests_received .add(1, &[host_tag, method_tag.clone()]); let mut tags = vec![method_tag]; let resp = select_target_and_proxy( &https_config, &proxy_config, &metrics, remote_addr, req, &mut tags, ) .await; tags.push(KeyValue::new("status_code", resp.status().as_u16() as i64)); metrics.requests_served.add(1, &tags); Ok(resp) } // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn select_target_and_proxy( https_config: &HttpsConfig, proxy_config: &ProxyConfig, metrics: &HttpsMetrics, remote_addr: SocketAddr, req: Request, tags: &mut Vec, ) -> Response { let received_time = Instant::now(); let method = req.method().clone(); let uri = req.uri().to_string(); let host = if let Some(auth) = req.uri().authority() { auth.as_str() } else { match req.headers().get("host").and_then(|x| x.to_str().ok()) { Some(host) => host, None => { return Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from("Missing Host header")) .unwrap(); } } }; let path = req.uri().path(); let best_match = proxy_config .entries .iter() .filter(|ent| { ent.flags.healthy && ent.host.matches(host) && ent .path_prefix .as_ref() .map(|prefix| path.starts_with(prefix)) .unwrap_or(true) }) .max_by_key(|ent| { ( ent.priority, ent.path_prefix .as_ref() .map(|x| x.len() as i32) .unwrap_or(0), (ent.flags.same_node || ent.flags.site_lb || ent.flags.global_lb), (ent.flags.same_site || ent.flags.global_lb), -ent.calls_in_progress.load(Ordering::SeqCst), -ent.last_call.load(Ordering::SeqCst), ) }); if let Some(proxy_to) = best_match { tags.push(KeyValue::new("service", proxy_to.service_name.clone())); tags.push(KeyValue::new( "target_addr", proxy_to.target_addr.to_string(), )); tags.push(KeyValue::new("same_node", proxy_to.flags.same_node)); tags.push(KeyValue::new("same_site", proxy_to.flags.same_site)); proxy_to.last_call.fetch_max( (received_time - https_config.time_origin).as_millis() as i64, Ordering::Relaxed, ); proxy_to.calls_in_progress.fetch_add(1, Ordering::SeqCst); debug!("{}{} -> {}", host, path, proxy_to); trace!("Request: {:?}", req); let response = match do_proxy(https_config, remote_addr, req, proxy_to).await { Ok(resp) => resp, Err(e) => Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from(format!("Proxy error: {}", e))) .unwrap(), }; proxy_to.calls_in_progress.fetch_sub(1, Ordering::SeqCst); metrics .request_proxy_duration .record(received_time.elapsed().as_secs_f64(), tags); trace!("Final response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); response } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("No matching proxy entry")) .unwrap() } } async fn do_proxy( https_config: &HttpsConfig, remote_addr: SocketAddr, req: Request, proxy_to: &ProxyEntry, ) -> Result> { let method = req.method().clone(); let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]); let mut response = if proxy_to.https_target { let to_addr = format!("https://{}", proxy_to.target_addr); reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await? } else { let to_addr = format!("http://{}", proxy_to.target_addr); reverse_proxy::call(remote_addr.ip(), &to_addr, req).await? }; if response.status().is_success() { // (TODO: maybe we want to add these headers even if it's not a success?) for (header, value) in proxy_to.add_headers.iter() { response.headers_mut().insert( HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?, ); } } if https_config.enable_compression { response = try_compress(response, method, accept_encoding, https_config).await? }; Ok(response) } async fn try_compress( response: Response, method: Method, accept_encoding: Vec<(Option, f32)>, https_config: &HttpsConfig, ) -> Result> { // Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body) // Don't compress partial content as it causes issues // Don't bother compressing non-2xx results // Don't compress Upgrade responses (e.g. websockets) // Don't compress responses that are already compressed if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT)) || response.status() == StatusCode::PARTIAL_CONTENT || !response.status().is_success() || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade")) || response.headers().get(header::CONTENT_ENCODING).is_some() { return Ok(response); } // Select preferred encoding among those proposed in accept_encoding let max_q: f32 = accept_encoding .iter() .max_by_key(|(_, q)| (q * 10000f32) as i64) .unwrap_or(&(None, 1.)) .1; let preference = [ Encoding::Zstd, //Encoding::Brotli, Encoding::Deflate, Encoding::Gzip, ]; #[allow(clippy::float_cmp)] let encoding_opt = accept_encoding .iter() .filter(|(_, q)| *q == max_q) .filter_map(|(enc, _)| *enc) .filter(|enc| preference.contains(enc)) .min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap()); // If preferred encoding is none, return as is let encoding = match encoding_opt { None | Some(Encoding::Identity) => return Ok(response), Some(enc) => enc, }; // If content type not in mime types for which to compress, return as is match response.headers().get(header::CONTENT_TYPE) { Some(ct) => { let ct_str = ct.to_str()?; let mime_type = match ct_str.split_once(';') { Some((mime_type, _params)) => mime_type, None => ct_str, }; if !https_config .compress_mime_types .iter() .any(|x| x == mime_type) { return Ok(response); } } None => return Ok(response), // don't compress if unknown mime type }; let (mut head, mut body) = response.into_parts(); // ---- If body is smaller than 1400 bytes, don't compress ---- let mut chunks = vec![]; let mut sum_lengths = 0; while sum_lengths < 1400 { match body.next().await { Some(chunk) => { let chunk = chunk?; sum_lengths += chunk.len(); chunks.push(chunk); } None => { return Ok(Response::from_parts(head, Body::from(chunks.concat()))); } } } // put beginning chunks back into body let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body); // make an async reader from that for compressor let body_rd = StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); trace!( "Compressing response body as {:?} (at least {} bytes)", encoding, sum_lengths ); // we don't know the compressed content-length so remove that header head.headers.remove(header::CONTENT_LENGTH); let (encoding, compressed_body) = match encoding { Encoding::Gzip => ( "gzip", Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))), ), // Encoding::Brotli => { // head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); // Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) // } Encoding::Deflate => ( "deflate", Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))), ), Encoding::Zstd => ( "zstd", Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))), ), _ => unreachable!(), }; head.headers .insert(header::CONTENT_ENCODING, encoding.parse()?); Ok(Response::from_parts(head, compressed_body)) }