aboutsummaryrefslogblamecommitdiff
path: root/src/https.rs
blob: f31caef050f3d94d12fc27e48f174ab45d44f76b (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                             
                         
                                       
                                   

                   
               
 

                                         
                                      
                                       
                                            
                         

                               
                                                         
                            
                  

                              
                                                 
 

                                       
                                                  
                                                   

                         

                                                                         



                                             

                                                     
                                 

 


                                                 
                                                            

 
                         
                            
                                   
                                                           
                                             
                 

                                      









                                                                                         



                                                                                                                                                    

           
                                                         



                                                                         

                                                                            
 
                                                                    
 
                                                             


                                                      






                                                                

                                                     
                                                           

                                                            
 
                                                              
                                                        
                                            
                                              
 

                                                        


                                                                                






                                                                                                         
                                                                                              






                                                                                     


                                                                 

                                                                                                  

                                                                

                                                                                                                                              




                                                                                                               
                                                 
                                          
                                                                            
                                                                                  

                                         
                                                                                  

                         


                                       

                  
                                                                               
                                                   

              

 
                        



                                       
                                   
                                         





















                                                                                                  



                                           

                            
                          
         
               
 
                                                                               


                                              

 

                                                           



                                   

                                
                                 
                     

                                           


                                          


                                                              
                                                                               







                                                                                


                                    
                                     


                               

                                                         





                                                                               
                                   

                                             
                                               

                                                                
                                                      

                                                                                                  

                                                                              


                         
                                            
                                                                                   



                                                         

                                                                                
 




                                                                                      
 
                                                           
                                             
 
                                                                                               




                                                                                
                  
 
                                                                          

                                               
                                                                             
 

                                                                           
                        
                

                                                        
 
                                   
                                                      

                                                                    

         
 
























                                                                                                        
         

                                            
                                                                                               


                    

 
                      
                                 
                       
                                                      

                                   
                                                                                                           



                                                               

                                                                                                

                                                                                                           





                                                                             






                                                           
                                   


                                  
                                   








                                                                                     



                                                                       



                                                                                



                                                                        




                                                        


                                                    
                                                                                   

          
                                                         
 
















                                                                                                   
                                                                                 

                                                        

                                                                                                       
 
               
                                                                        

                           
          

                                                                            

                                                    




                                                                                        



                                                                                         







                                                                                           

                                    

                                                                     


                                                       
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use std::time::{Duration, Instant};

use anyhow::Result;
use tracing::*;

use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::stream::FuturesUnordered;
use futures::{StreamExt, TryStreamExt};
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};

use opentelemetry::{metrics, KeyValue};

use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::{ProxyConfig, ProxyEntry};
use crate::reverse_proxy;

/// Upper bound on how long a single client connection is kept open (24 hours);
/// a connection still open when this elapses is dropped.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 60 * 60);

/// Settings for the HTTPS frontend started by `serve_https`.
#[derive(Debug, Clone)]
pub struct HttpsConfig {
	/// Address the HTTPS listener binds to.
	pub bind_addr: SocketAddr,
	/// Whether proxied responses may be compressed according to the client's
	/// `Accept-Encoding` header.
	pub enable_compression: bool,
	/// MIME types (without parameters, e.g. `text/html`) whose responses may be compressed.
	pub compress_mime_types: Vec<String>,

	/// Reference instant, used internally to turn `Instant`s into integer timestamps
	/// (milliseconds elapsed since this origin, stored as `i64`).
	pub time_origin: Instant,
}

/// OpenTelemetry instruments recorded by the HTTPS frontend.
struct HttpsMetrics {
	/// Incremented once per incoming request, tagged with `host` and `method`.
	requests_received: metrics::Counter<u64>,
	/// Incremented once per response sent, tagged with `method`, any routing tags
	/// (service, target address, locality flags) and `status_code`.
	requests_served: metrics::Counter<u64>,
	/// Seconds between receiving a request and the backend call returning
	/// (status code and headers available, body possibly still streaming).
	request_proxy_duration: metrics::ValueRecorder<f64>,
}

pub async fn serve_https(
	config: HttpsConfig,
	cert_store: Arc<CertStore>,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
	mut must_exit: watch::Receiver<bool>,
) -> Result<()> {
	let config = Arc::new(config);

	let meter = opentelemetry::global::meter("tricot");
	let metrics = Arc::new(HttpsMetrics {
		requests_received: meter
			.u64_counter("https_requests_received")
			.with_description("Total number of requests received over HTTPS")
			.init(),
		requests_served: meter
			.u64_counter("https_requests_served")
			.with_description("Total number of requests served over HTTPS")
			.init(),
		request_proxy_duration: meter
			.f64_value_recorder("https_request_proxy_duration")
			.with_description("Duration between time when request was received, and time when backend returned status code and headers")
			.init(),
	});

	let mut tls_cfg = rustls::ServerConfig::builder()
		.with_safe_defaults()
		.with_no_client_auth()
		.with_cert_resolver(Arc::new(StoreResolver(cert_store)));

	tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
	let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

	info!("Starting to serve on https://{}.", config.bind_addr);

	let tcp = TcpListener::bind(config.bind_addr).await?;
	let mut connections = FuturesUnordered::new();

	while !*must_exit.borrow() {
		let wait_conn_finished = async {
			if connections.is_empty() {
				futures::future::pending().await
			} else {
				connections.next().await
			}
		};
		let (socket, remote_addr) = select! {
			a = tcp.accept() => a?,
			_ = wait_conn_finished => continue,
			_ = must_exit.changed() => continue,
		};

		let rx_proxy_config = rx_proxy_config.clone();
		let tls_acceptor = tls_acceptor.clone();
		let config = config.clone();
		let metrics = metrics.clone();

		let mut must_exit_2 = must_exit.clone();
		let conn = tokio::spawn(async move {
			match tls_acceptor.accept(socket).await {
				Ok(stream) => {
					debug!("TLS handshake was successfull");
					let http_conn = Http::new()
						.serve_connection(
							stream,
							service_fn(move |req: Request<Body>| {
								let https_config = config.clone();
								let proxy_config: Arc<ProxyConfig> =
									rx_proxy_config.borrow().clone();
								let metrics = metrics.clone();
								handle_request(
									remote_addr,
									req,
									https_config,
									proxy_config,
									metrics,
								)
							}),
						)
						.with_upgrades();
					let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
					tokio::pin!(http_conn, timeout);
					let http_result = loop {
						select! (
							r = &mut http_conn => break r.map_err(Into::into),
							_ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
							_ = must_exit_2.changed() => {
								if *must_exit_2.borrow() {
									http_conn.as_mut().graceful_shutdown();
								}
							}
						)
					};
					if let Err(http_err) = http_result {
						warn!("HTTP error: {}", http_err);
					}
				}
				Err(e) => warn!("Error in TLS connection: {}", e),
			}
		});
		connections.push(conn);
	}

	drop(tcp);

	info!("HTTPS server shutting down, draining remaining connections...");
	while connections.next().await.is_some() {}

	Ok(())
}

/// Entry point for every HTTPS request. It records request metrics around
/// `select_target_and_proxy` and never fails: problems are turned into HTTP responses.
///
/// The `requests_received` counter is tagged with the host and the method. The host
/// comes from the URI authority, falls back to the `Host` header, and is empty
/// otherwise. The `requests_served` counter is tagged with the method, the tags added
/// during routing, and the final status code.
async fn handle_request(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
	metrics: Arc<HttpsMetrics>,
) -> Result<Response<Body>, Infallible> {
	let method_tag = KeyValue::new("method", req.method().to_string());

	// Only `requests_received` carries the host: on the other metrics it could
	// easily cause a cardinality explosion.
	let host = match req.uri().authority() {
		Some(authority) => authority.to_string(),
		None => req
			.headers()
			.get("host")
			.map(|value| value.to_str().unwrap_or_default().to_string())
			.unwrap_or_default(),
	};
	let received_tags = [KeyValue::new("host", host), method_tag.clone()];
	metrics.requests_received.add(1, &received_tags);

	let mut served_tags = vec![method_tag];
	let response = select_target_and_proxy(
		&https_config,
		&proxy_config,
		&metrics,
		remote_addr,
		req,
		&mut served_tags,
	)
	.await;

	let status_code = response.status().as_u16() as i64;
	served_tags.push(KeyValue::new("status_code", status_code));
	metrics.requests_served.add(1, &served_tags);

	Ok(response)
}

// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn select_target_and_proxy(
	https_config: &HttpsConfig,
	proxy_config: &ProxyConfig,
	metrics: &HttpsMetrics,
	remote_addr: SocketAddr,
	req: Request<Body>,
	tags: &mut Vec<KeyValue>,
) -> Response<Body> {
	let received_time = Instant::now();

	let method = req.method().clone();
	let uri = req.uri().to_string();

	let host = if let Some(auth) = req.uri().authority() {
		auth.as_str()
	} else {
		match req.headers().get("host").and_then(|x| x.to_str().ok()) {
			Some(host) => host,
			None => {
				return Response::builder()
					.status(StatusCode::BAD_REQUEST)
					.body(Body::from("Missing Host header"))
					.unwrap();
			}
		}
	};
	let path = req.uri().path();

	let best_match = proxy_config
		.entries
		.iter()
		.filter(|ent| {
			ent.flags.healthy
				&& ent.host.matches(host)
				&& ent
					.path_prefix
					.as_ref()
					.map(|prefix| path.starts_with(prefix))
					.unwrap_or(true)
		})
		.max_by_key(|ent| {
			(
				ent.priority,
				ent.path_prefix
					.as_ref()
					.map(|x| x.len() as i32)
					.unwrap_or(0),
				(ent.flags.same_node || ent.flags.site_lb || ent.flags.global_lb),
				(ent.flags.same_site || ent.flags.global_lb),
				-ent.calls_in_progress.load(Ordering::SeqCst),
				-ent.last_call.load(Ordering::SeqCst),
			)
		});

	if let Some(proxy_to) = best_match {
		tags.push(KeyValue::new("service", proxy_to.service_name.clone()));
		tags.push(KeyValue::new(
			"target_addr",
			proxy_to.target_addr.to_string(),
		));
		tags.push(KeyValue::new("same_node", proxy_to.flags.same_node));
		tags.push(KeyValue::new("same_site", proxy_to.flags.same_site));

		proxy_to.last_call.fetch_max(
			(received_time - https_config.time_origin).as_millis() as i64,
			Ordering::Relaxed,
		);
		proxy_to.calls_in_progress.fetch_add(1, Ordering::SeqCst);

		debug!("{}{} -> {}", host, path, proxy_to);
		trace!("Request: {:?}", req);

		let response = match do_proxy(https_config, remote_addr, req, proxy_to).await {
			Ok(resp) => resp,
			Err(e) => Response::builder()
				.status(StatusCode::BAD_GATEWAY)
				.body(Body::from(format!("Proxy error: {}", e)))
				.unwrap(),
		};

		proxy_to.calls_in_progress.fetch_sub(1, Ordering::SeqCst);
		metrics
			.request_proxy_duration
			.record(received_time.elapsed().as_secs_f64(), tags);

		trace!("Final response: {:?}", response);
		info!("{} {} {}", method, response.status().as_u16(), uri);
		response
	} else {
		debug!("{}{} -> NOT FOUND", host, path);
		info!("{} 404 {}", method, uri);

		Response::builder()
			.status(StatusCode::NOT_FOUND)
			.body(Body::from("No matching proxy entry"))
			.unwrap()
	}
}

/// Forwards `req` to `proxy_to.target_addr` and post-processes the backend's response.
///
/// The backend is called over HTTPS when `proxy_to.https_target` is set, otherwise over
/// plain HTTP. On a 2xx response, the entry's `add_headers` are inserted, replacing any
/// header of the same name. Then, if `https_config.enable_compression` is set, the body
/// may be compressed according to the client's `Accept-Encoding`.
///
/// # Errors
/// Fails if:
/// - the backend call fails,
/// - an `add_headers` name or value is not a valid header,
/// - compression fails.
async fn do_proxy(
	https_config: &HttpsConfig,
	remote_addr: SocketAddr,
	req: Request<Body>,
	proxy_to: &ProxyEntry,
) -> Result<Response<Body>> {
	// Capture what compression needs before `req` is moved into the backend call.
	let method = req.method().clone();
	let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_default();

	let client_ip = remote_addr.ip();
	let mut response = if !proxy_to.https_target {
		let target = format!("http://{}", proxy_to.target_addr);
		reverse_proxy::call(client_ip, &target, req).await?
	} else {
		let target = format!("https://{}", proxy_to.target_addr);
		reverse_proxy::call_https(client_ip, &target, req).await?
	};

	if response.status().is_success() {
		// TODO: maybe these headers should be added to non-2xx responses too.
		let headers = response.headers_mut();
		for (name, value) in proxy_to.add_headers.iter() {
			let name = HeaderName::from_bytes(name.as_bytes())?;
			let value = HeaderValue::from_str(value)?;
			headers.insert(name, value);
		}
	}

	if !https_config.enable_compression {
		return Ok(response);
	}
	try_compress(response, method, accept_encoding, https_config).await
}

/// Compresses the body of `response` when the client accepts a supported encoding and
/// the response is worth compressing; otherwise returns the response unchanged.
///
/// The response is passed through untouched when:
/// - it is a 2xx answer to `HEAD` or `PUT` (its body should be empty);
/// - it is `206 Partial Content` (compressing it causes issues);
/// - it is not 2xx;
/// - it carries `Connection: Upgrade`, or already has a `Content-Encoding`;
/// - no supported encoding has the client's highest positive q-value;
/// - its `Content-Type` is missing, is not valid visible ASCII, or has a MIME type
///   (compared case-insensitively, ignoring parameters) not listed in
///   `https_config.compress_mime_types`;
/// - its body turns out to be shorter than 1400 bytes.
///
/// Among the encodings sharing the client's highest q-value, zstd is preferred, then
/// deflate, then gzip. Encodings offered with `q=0` are treated as refused
/// (RFC 9110 §12.4.2).
///
/// A compressed response loses its `Content-Length`, gains `Content-Encoding`, and is
/// marked `Vary: Accept-Encoding`. The Vary header keeps shared caches from handing the
/// compressed body to clients that cannot decode it.
///
/// # Errors
/// Fails if reading the beginning of the body from the backend fails.
async fn try_compress(
	response: Response<Body>,
	method: Method,
	accept_encoding: Vec<(Option<Encoding>, f32)>,
	https_config: &HttpsConfig,
) -> Result<Response<Body>> {
	// Don't bother compressing successful responses for HEAD and PUT (they should have an empty body)
	// Don't compress partial content as it causes issues
	// Don't bother compressing non-2xx results
	// Don't compress Upgrade responses (e.g. websockets)
	// Don't compress responses that are already compressed
	let status = response.status();
	if (status.is_success() && (method == Method::HEAD || method == Method::PUT))
		|| status == StatusCode::PARTIAL_CONTENT
		|| !status.is_success()
		|| response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
		|| response.headers().get(header::CONTENT_ENCODING).is_some()
	{
		return Ok(response);
	}

	// Select the preferred encoding among those the client ranks highest.
	let max_q: f32 = accept_encoding
		.iter()
		.max_by_key(|(_, q)| (q * 10000f32) as i64)
		.unwrap_or(&(None, 1.))
		.1;
	let preference = [
		Encoding::Zstd,
		//Encoding::Brotli,
		Encoding::Deflate,
		Encoding::Gzip,
	];
	// Exact float equality is intended: `max_q` is copied from one of these q-values.
	#[allow(clippy::float_cmp)]
	let encoding_opt = accept_encoding
		.iter()
		// q=0 means "not acceptable". When every listed encoding has q=0, max_q is 0
		// as well, and nothing must be selected.
		.filter(|(_, q)| *q > 0.0 && *q == max_q)
		.filter_map(|(enc, _)| *enc)
		.filter(|enc| preference.contains(enc))
		.min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());

	// If there is no preferred encoding, return as is.
	let encoding = match encoding_opt {
		None | Some(Encoding::Identity) => return Ok(response),
		Some(enc) => enc,
	};

	// Only compress MIME types listed in the configuration.
	let mime_is_compressible = match response
		.headers()
		.get(header::CONTENT_TYPE)
		.map(|ct| ct.to_str())
	{
		Some(Ok(ct_str)) => {
			let mime_type = ct_str
				.split_once(';')
				.map_or(ct_str, |(mime_type, _params)| mime_type)
				.trim();
			https_config
				.compress_mime_types
				.iter()
				.any(|x| x.eq_ignore_ascii_case(mime_type))
		}
		// Missing or non-ASCII content type: leave the response alone. Failing here
		// would turn a perfectly good backend response into a 502.
		_ => false,
	};
	if !mime_is_compressible {
		return Ok(response);
	}

	let (mut head, mut body) = response.into_parts();

	// ---- If body is smaller than 1400 bytes, don't compress ----
	let mut chunks = vec![];
	let mut sum_lengths = 0;
	while sum_lengths < 1400 {
		match body.next().await {
			Some(chunk) => {
				let chunk = chunk?;
				sum_lengths += chunk.len();
				chunks.push(chunk);
			}
			None => {
				return Ok(Response::from_parts(head, Body::from(chunks.concat())));
			}
		}
	}

	// Put the chunks already read back in front of the rest of the body.
	let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);

	// Make an async reader from that for the compressor.
	let body_rd =
		StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

	trace!(
		"Compressing response body as {:?} (at least {} bytes)",
		encoding,
		sum_lengths
	);

	let (encoding_name, compressed_body) = match encoding {
		Encoding::Gzip => (
			"gzip",
			Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
		),
		// Encoding::Brotli => (
		// 	"br",
		// 	Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))),
		// ),
		Encoding::Deflate => (
			"deflate",
			Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
		),
		Encoding::Zstd => (
			"zstd",
			Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
		),
		_ => unreachable!(),
	};

	// The compressed length is not known in advance.
	head.headers.remove(header::CONTENT_LENGTH);
	head.headers
		.insert(header::CONTENT_ENCODING, HeaderValue::from_static(encoding_name));
	// The body now depends on the request's Accept-Encoding; tell caches about it
	// without discarding any Vary header the backend already set.
	head.headers
		.append(header::VARY, HeaderValue::from_static("Accept-Encoding"));

	Ok(Response::from_parts(head, compressed_body))
}