aboutsummaryrefslogblamecommitdiff
path: root/src/https.rs
blob: ba4365d8022478b1e5f5ef2a0cbce42af513fa51 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                             
                         
                                       
                        



                   

                                         
                                      
                                       
                                            
                         

                               
                                                         
                            
                  

                              
                                                 




                                                  

                                                                         





                                             
                         
                            
                                   
                                                           
                                             
                 


                                                         



                                                                         

                                                                            
 
                                                                    
 
                                                             


                                                      






                                                                

                                                     
                                                           

                                                            
 
                                                              
                                                        
                                            
 

                                                        


                                                                                










                                                                                                                          

                                                                                                  

                                                                

                                                                                                                                              




                                                                                                               
                                                 
                                          
                                                                            
                                                                                  

                                         
                                                                                  

                         


                                       

                  
                                                                               
                                                   

              

 

















                                                                          




                                                           
                                       

                                            


                                          








                                                                       
                                                                                                        
 
                                     


                               
                                              





                                                                               
                                   

                                             
                                               

                                                                
                                                      

                                              
                                                                  


                         

                                                              
 
                                                           
                                             
 

                                                                                  
                                                                                                      

                                                                                 
                                                                                                
                  
 







                                                                                                 

                 
                                                    






                                                                                                             
                

                                                        





                                                                      
 







                                                                        

 
                      
                                 
                       
                                                      

                                   
                                                                                                           



                                                               

                                                                                                

                                                                                                           





                                                                             






                                                           
                                   


                                  
                                   








                                                                                     



                                                                       



                                                                                



                                                                        




                                                        


                                                    
                                                                                   

          
                                                         
 
















                                                                                                   
                                                                                 

                                                        

                                                                                                       
 
               
                                                                        

                           
          

                                                                            

                                                    




                                                                                        



                                                                                         







                                                                                           

                                    

                                                                     


                                                       
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use std::time::Duration;

use anyhow::Result;
use log::*;

use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::stream::FuturesUnordered;
use futures::{StreamExt, TryStreamExt};
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;

/// Upper bound on how long a single client connection is served (24 hours).
/// `serve_https` stops driving a connection once it has been open this long.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 3600);

/// Static settings of the HTTPS front end, shared by all connections.
pub struct HttpsConfig {
	/// Socket address the HTTPS listener binds to.
	pub bind_addr: SocketAddr,
	/// When `true`, proxied responses are handed to the compression step
	/// before being sent back to the client.
	pub enable_compression: bool,
	/// Response content types for which compression is attempted; responses
	/// with any other (or no) content type are never compressed.
	pub compress_mime_types: Vec<String>,
}

/// Runs the HTTPS reverse-proxy server until `must_exit` becomes `true`.
///
/// * `config` - listener address and compression settings.
/// * `cert_store` - source of TLS certificates, wrapped in a `StoreResolver`
///   and installed as the rustls certificate resolver.
/// * `rx_proxy_config` - watch channel holding the current routing table; the
///   value is re-read for every request.
/// * `must_exit` - watch channel used as the shutdown signal.
///
/// Every accepted TCP connection is handled in its own spawned task (TLS
/// handshake, then HTTP/1.1 or HTTP/2 as negotiated through ALPN). When the
/// shutdown signal is observed, the listener is dropped and the function waits
/// for all remaining connection tasks to finish before returning `Ok(())`.
///
/// # Errors
///
/// Returns an error if binding the listener fails or if `accept()` on the
/// listener fails.
pub async fn serve_https(
	config: HttpsConfig,
	cert_store: Arc<CertStore>,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
	mut must_exit: watch::Receiver<bool>,
) -> Result<()> {
	let config = Arc::new(config);

	// TLS: no client certificates; the server certificate is looked up
	// through the certificate store for each handshake.
	let mut tls_cfg = rustls::ServerConfig::builder()
		.with_safe_defaults()
		.with_no_client_auth()
		.with_cert_resolver(Arc::new(StoreResolver(cert_store)));

	// Advertise HTTP/2 first, then HTTP/1.1.
	tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
	let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

	info!("Starting to serve on https://{}.", config.bind_addr);

	let tcp = TcpListener::bind(config.bind_addr).await?;
	// Join handles of all connection tasks that are still running.
	let mut connections = FuturesUnordered::new();

	while !*must_exit.borrow() {
		// Reap finished connection tasks while waiting for new ones. An empty
		// `FuturesUnordered` yields `None` right away, which would turn the
		// `select!` below into a busy loop, hence the never-resolving future.
		let wait_conn_finished = async {
			if connections.is_empty() {
				futures::future::pending().await
			} else {
				connections.next().await
			}
		};
		// Only the `accept` branch produces a socket; the other two branches
		// just restart the loop (which re-checks the exit flag).
		// NOTE(review): an `accept()` error is propagated with `?` and ends
		// the whole server — confirm that this is intended for transient
		// errors as well.
		let (socket, remote_addr) = select! {
			a = tcp.accept() => a?,
			_ = wait_conn_finished => continue,
			_ = must_exit.changed() => continue,
		};

		// Per-connection clones moved into the spawned task.
		let rx_proxy_config = rx_proxy_config.clone();
		let tls_acceptor = tls_acceptor.clone();
		let config = config.clone();

		let mut must_exit_2 = must_exit.clone();
		let conn = tokio::spawn(async move {
			match tls_acceptor.accept(socket).await {
				Ok(stream) => {
					debug!("TLS handshake was successfull");
					let http_conn = Http::new()
						.serve_connection(
							stream,
							service_fn(move |req: Request<Body>| {
								let https_config = config.clone();
								// Snapshot the routing table for this request,
								// so configuration updates apply to
								// already-open connections too.
								let proxy_config: Arc<ProxyConfig> =
									rx_proxy_config.borrow().clone();
								handle_outer(remote_addr, req, https_config, proxy_config)
							}),
						)
						.with_upgrades();
					let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
					tokio::pin!(http_conn, timeout);
					// Drive the connection until it completes, exceeds the
					// maximum lifetime, or finishes after a graceful shutdown
					// was requested.
					// NOTE(review): the result of `changed()` is ignored; if
					// the sender side of `must_exit` is dropped, `changed()`
					// presumably resolves immediately on every iteration —
					// confirm the sender outlives the server.
					let http_result = loop {
						select! (
							r = &mut http_conn => break r.map_err(Into::into),
							_ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
							_ = must_exit_2.changed() => {
								if *must_exit_2.borrow() {
									// Keep polling afterwards: the loop ends
									// once the connection future resolves.
									http_conn.as_mut().graceful_shutdown();
								}
							}
						)
					};
					if let Err(http_err) = http_result {
						warn!("HTTP error: {}", http_err);
					}
				}
				Err(e) => warn!("Error in TLS connection: {}", e),
			}
		});
		connections.push(conn);
	}

	// Stop accepting new connections before draining the existing ones.
	drop(tcp);

	info!("HTTPS server shutting down, draining remaining connections...");
	while connections.next().await.is_some() {}

	Ok(())
}

async fn handle_outer(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, Infallible> {
	match handle(remote_addr, req, https_config, proxy_config).await {
		Err(e) => {
			warn!("Handler error: {}", e);
			Ok(Response::builder()
				.status(StatusCode::INTERNAL_SERVER_ERROR)
				.body(Body::from(format!("{}", e)))
				.unwrap())
		}
		Ok(r) => Ok(r),
	}
}

/// Routes one request to the best matching proxy entry and relays the
/// backend's response.
///
/// The target host is taken from the request URI's authority when present,
/// otherwise from the `Host` header. Among the entries of `proxy_config` whose
/// host pattern matches and whose path prefix (if any) is a prefix of the
/// request path, the entry with the greatest key
/// `(priority, prefix length, same_node, same_site, -calls)` is selected, i.e.
/// ties are ultimately broken in favour of the entry with the fewest calls so
/// far.
///
/// Returns:
/// * the backend's response (or a `502` built by `handle_error` when the
///   backend call fails), with the entry's extra headers added on success and
///   optionally compressed;
/// * `404 Not Found` when no entry matches;
/// * `400 Bad Request` when the request carries neither a URI authority nor a
///   usable `Host` header.
///
/// # Errors
///
/// Fails if a configured extra header name/value is invalid, or if compressing
/// or building the response fails.
async fn handle(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
	let method = req.method().clone();
	let uri = req.uri().to_string();

	let host = match req.uri().authority() {
		Some(auth) => auth.as_str(),
		None => match req.headers().get("host").map(|h| h.to_str()) {
			Some(Ok(host)) => host,
			// A missing or non-ASCII Host header is a client error, not a
			// server failure: answer 400 instead of going through the
			// generic error path (which answers 500).
			Some(Err(_)) | None => {
				debug!("{} -> missing or invalid Host header", uri);
				info!("{} 400 {}", method, uri);

				return Ok(Response::builder()
					.status(StatusCode::BAD_REQUEST)
					.body(Body::from("Missing or invalid Host header"))?);
			}
		},
	};
	let path = req.uri().path();
	// An unparsable Accept-Encoding header is treated as "nothing accepted".
	let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_default();

	let best_match = proxy_config
		.entries
		.iter()
		.filter(|ent| {
			ent.host.matches(host)
				&& ent
					.path_prefix
					.as_ref()
					.map(|prefix| path.starts_with(prefix))
					.unwrap_or(true)
		})
		.max_by_key(|ent| {
			(
				ent.priority,
				// Longer (more specific) prefixes win; no prefix counts as 0.
				ent.path_prefix.as_ref().map_or(0, |x| x.len()),
				ent.same_node,
				ent.same_site,
				// Negated so that the least-called entry ranks highest.
				-ent.calls.load(Ordering::SeqCst),
			)
		});

	if let Some(proxy_to) = best_match {
		proxy_to.calls.fetch_add(1, Ordering::SeqCst);

		debug!("{}{} -> {}", host, path, proxy_to);
		trace!("Request: {:?}", req);

		let mut response = if proxy_to.https_target {
			let to_addr = format!("https://{}", proxy_to.target_addr);
			handle_error(reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await)
		} else {
			let to_addr = format!("http://{}", proxy_to.target_addr);
			handle_error(reverse_proxy::call(remote_addr.ip(), &to_addr, req).await)
		};

		if response.status().is_success() {
			// (TODO: maybe we want to add these headers even if it's not a success?)
			for (header, value) in proxy_to.add_headers.iter() {
				response.headers_mut().insert(
					HeaderName::from_bytes(header.as_bytes())?,
					HeaderValue::from_str(value)?,
				);
			}
		}

		if https_config.enable_compression {
			response =
				try_compress(response, method.clone(), accept_encoding, &https_config).await?
		};

		trace!("Final response: {:?}", response);
		info!("{} {} {}", method, response.status().as_u16(), uri);
		Ok(response)
	} else {
		debug!("{}{} -> NOT FOUND", host, path);
		info!("{} 404 {}", method, uri);

		Ok(Response::builder()
			.status(StatusCode::NOT_FOUND)
			.body(Body::from("No matching proxy entry"))?)
	}
}

/// Turns the outcome of a backend call into a response that can be sent to
/// the client.
///
/// A successful result is passed through unchanged; an error becomes a
/// `502 Bad Gateway` response whose body is `Proxy error: <error>`.
fn handle_error(resp: Result<Response<Body>>) -> Response<Body> {
	resp.unwrap_or_else(|err| {
		let message = format!("Proxy error: {}", err);
		// A fixed status with a string body is not expected to fail to build.
		Response::builder()
			.status(StatusCode::BAD_GATEWAY)
			.body(Body::from(message))
			.unwrap()
	})
}

/// Compresses the body of `response` when that is both worthwhile and
/// acceptable to the client; otherwise returns the response unchanged.
///
/// * `response` - the response obtained from the backend.
/// * `method` - method of the originating request.
/// * `accept_encoding` - parsed `Accept-Encoding` entries as
///   `(encoding, q-value)` pairs.
/// * `https_config` - provides the list of compressible content types.
///
/// The response is returned as-is when any of the following holds:
/// * the status is not 2xx, is `206 Partial Content`, or the request was a
///   `HEAD`/`PUT`;
/// * it is an `Upgrade` response or already carries a `Content-Encoding`;
/// * none of the supported encodings (zstd, deflate, gzip — in that order of
///   preference) is among the client's most preferred encodings, or the client
///   only lists encodings with a q-value of 0;
/// * its `Content-Type` is absent, not readable as text, or not listed in
///   `https_config.compress_mime_types`;
/// * its body is shorter than 1400 bytes.
///
/// When compression is applied, `Content-Length` is removed (the compressed
/// size is not known in advance) and `Content-Encoding` is set.
///
/// # Errors
///
/// Fails if reading the first chunks of the backend body fails.
async fn try_compress(
	response: Response<Body>,
	method: Method,
	accept_encoding: Vec<(Option<Encoding>, f32)>,
	https_config: &HttpsConfig,
) -> Result<Response<Body>> {
	// Don't bother compressing successful responses for HEAD and PUT (they should have an empty body)
	// Don't compress partial content as it causes issues
	// Don't bother compressing non-2xx results
	// Don't compress Upgrade responses (e.g. websockets)
	// Don't compress responses that are already compressed
	if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT))
		|| response.status() == StatusCode::PARTIAL_CONTENT
		|| !response.status().is_success()
		|| response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
		|| response.headers().get(header::CONTENT_ENCODING).is_some()
	{
		return Ok(response);
	}

	// Select preferred encoding among those proposed in accept_encoding:
	// first find the highest q-value the client sent...
	let max_q: f32 = accept_encoding
		.iter()
		.max_by_key(|(_, q)| (q * 10000f32) as i64)
		.unwrap_or(&(None, 1.))
		.1;
	// ...then, among the encodings carrying that q-value, take the one that
	// comes first in our own order of preference. Brotli is currently not
	// offered.
	let preference = [Encoding::Zstd, Encoding::Deflate, Encoding::Gzip];
	// Exact float comparison is intended: `max_q` is copied verbatim from one
	// of the entries being compared.
	#[allow(clippy::float_cmp)]
	let encoding_opt = accept_encoding
		.iter()
		// A q-value of 0 means "not acceptable": never pick such an encoding,
		// even when 0 happens to be the highest q-value sent.
		.filter(|(_, q)| *q > 0. && *q == max_q)
		.filter_map(|(enc, _)| *enc)
		.filter(|enc| preference.contains(enc))
		.min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());

	// If preferred encoding is none, return as is
	let encoding = match encoding_opt {
		None | Some(Encoding::Identity) => return Ok(response),
		Some(enc) => enc,
	};

	// If content type not in mime types for which to compress, return as is.
	// A missing or non-textual Content-Type header also means "don't
	// compress" rather than failing the whole response.
	let compressible = response
		.headers()
		.get(header::CONTENT_TYPE)
		.and_then(|ct| ct.to_str().ok())
		.map(|ct_str| {
			// Drop parameters such as `; charset=utf-8`; media types are
			// case-insensitive.
			let mime_type = ct_str
				.split_once(';')
				.map_or(ct_str, |(mime_type, _params)| mime_type)
				.trim();
			https_config
				.compress_mime_types
				.iter()
				.any(|x| x.eq_ignore_ascii_case(mime_type))
		})
		.unwrap_or(false);
	if !compressible {
		return Ok(response);
	}

	let (mut head, mut body) = response.into_parts();

	// ---- If body is smaller than 1400 bytes, don't compress ----
	// Buffer chunks until at least 1400 bytes were seen or the body ended.
	let mut chunks = vec![];
	let mut sum_lengths = 0;
	while sum_lengths < 1400 {
		match body.next().await {
			Some(chunk) => {
				let chunk = chunk?;
				sum_lengths += chunk.len();
				chunks.push(chunk);
			}
			None => {
				// Body ended early: send the buffered bytes uncompressed.
				return Ok(Response::from_parts(head, Body::from(chunks.concat())));
			}
		}
	}

	// put beginning chunks back into body
	let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);

	// make an async reader from that for compressor
	let body_rd =
		StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

	trace!(
		"Compressing response body as {:?} (at least {} bytes)",
		encoding,
		sum_lengths
	);

	// we don't know the compressed content-length so remove that header
	head.headers.remove(header::CONTENT_LENGTH);

	let (encoding_name, compressed_body) = match encoding {
		Encoding::Gzip => (
			"gzip",
			Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
		),
		Encoding::Deflate => (
			"deflate",
			Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
		),
		Encoding::Zstd => (
			"zstd",
			Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
		),
		// `encoding` was selected from `preference`, which only contains the
		// three variants handled above.
		_ => unreachable!(),
	};
	head.headers.insert(
		header::CONTENT_ENCODING,
		HeaderValue::from_static(encoding_name),
	);

	Ok(Response::from_parts(head, compressed_body))
}