aboutsummaryrefslogblamecommitdiff
path: root/src/https.rs
blob: 34e3f85dd750dba017ad815037f38d6d08835002 (plain) (tree)
1
2
3
4
5
6
7
8
9
                             
                         
                                       



                   

                                         
                       
                          
                                            
                         

                               
                                                         


                              
                                                 




                                                  





                                             
                         
                            
                                   
                                                           
                 


                                                         



                                                                         

                                                                            
 
                                                                    
 
                                                             


                                                                
                                                              
                                                        
                                            








                                                                                              



                                                                                                                          



                                                                            
                                                                                  

                                         
                                                                                  




                         

















                                                                          




                                                           
                                       

                                            


                                          








                                                                       
                                                                                                        
 
                                     


                               
                                              





                                                                               
                                   

                                             
                                               

                                                                
                                                      

                                              
                                                                  


                         

                                                              
 
                                                           
                                             
 






                                                                                         

                                                                    



                                                                           
                 
                                                   
                                                                           
 
                                                    
                                                                                            


                                    
                

                                                        





                                                                      
 
                      
                                 
                       
                                                      

                                   










                                                                                                           






                                                           
                                   


                                  
                                   








                                                                                     



                                                                       



                                                                                



                                                                        




                                                        


                                                    
                                                                                   

          
                                                         
 
















                                                                                                   
                                                                                 

                                                        

                                                                                                       
 
               
                                                                        

                           
          

                                                                            

                                                    




                                                                                        



                                                                                         







                                                                                           

                                    

                                                                     


                                                       
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};

use anyhow::Result;
use log::*;

use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::StreamExt;
use futures::TryStreamExt;
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;

/// Static settings of the HTTPS front-end.
#[derive(Debug, Clone)]
pub struct HttpsConfig {
	/// Socket address the TLS listener binds to.
	pub bind_addr: SocketAddr,
	/// When `false`, proxied responses are returned as-is and never
	/// passed through the compression step.
	pub enable_compression: bool,
	/// MIME types whose responses may be compressed. The value is compared
	/// exactly against the part of the `Content-Type` header that precedes
	/// the first `;`.
	pub compress_mime_types: Vec<String>,
}

pub async fn serve_https(
	config: HttpsConfig,
	cert_store: Arc<CertStore>,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
) -> Result<()> {
	let config = Arc::new(config);

	let mut tls_cfg = rustls::ServerConfig::builder()
		.with_safe_defaults()
		.with_no_client_auth()
		.with_cert_resolver(Arc::new(StoreResolver(cert_store)));

	tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
	let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

	info!("Starting to serve on https://{}.", config.bind_addr);

	let tcp = TcpListener::bind(config.bind_addr).await?;
	loop {
		let (socket, remote_addr) = tcp.accept().await?;

		let rx_proxy_config = rx_proxy_config.clone();
		let tls_acceptor = tls_acceptor.clone();
		let config = config.clone();

		tokio::spawn(async move {
			match tls_acceptor.accept(socket).await {
				Ok(stream) => {
					debug!("TLS handshake was successfull");
					let http_result = Http::new()
						.serve_connection(
							stream,
							service_fn(move |req: Request<Body>| {
								let https_config = config.clone();
								let proxy_config: Arc<ProxyConfig> =
									rx_proxy_config.borrow().clone();
								handle_outer(remote_addr, req, https_config, proxy_config)
							}),
						)
						.await;
					if let Err(http_err) = http_result {
						warn!("HTTP error: {}", http_err);
					}
				}
				Err(e) => warn!("Error in TLS connection: {}", e),
			}
		});
	}
}

/// Infallible wrapper around [`handle`] suitable for use as a hyper service.
///
/// A successful response is forwarded unchanged. When the handler fails, the
/// error is logged and turned into a `500 Internal Server Error` response
/// whose body is the error's display text.
async fn handle_outer(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, Infallible> {
	let response = handle(remote_addr, req, https_config, proxy_config)
		.await
		.unwrap_or_else(|e| {
			warn!("Handler error: {}", e);
			// Only a constant status is set on the builder, so building
			// the response is not expected to fail.
			Response::builder()
				.status(StatusCode::INTERNAL_SERVER_ERROR)
				.body(Body::from(e.to_string()))
				.unwrap()
		});
	Ok(response)
}

/// Routes one request to the best matching proxy entry and relays the answer.
///
/// The host is taken from the request URI's authority when present, otherwise
/// from the `host` header. Among the entries of `proxy_config` whose host
/// matches and whose path prefix (if any) is a prefix of the request path,
/// the one with the greatest key is chosen, comparing in order: `priority`,
/// length of the path prefix, `same_node`, `same_site`, and finally the
/// negated call counter (so entries called less often win ties).
///
/// The chosen entry's call counter is incremented, the request is forwarded
/// to its target over HTTP or HTTPS, the entry's `add_headers` are inserted
/// into the response, and the response is compressed if compression is
/// enabled in `https_config`. Without a matching entry a `404 Not Found`
/// response is returned.
///
/// # Errors
///
/// Fails when no host can be determined, when the upstream call fails, when
/// a configured header name or value is invalid, or when compression fails.
async fn handle(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
	// Captured up front for the access log: `req` is consumed by the proxy call.
	let method = req.method().clone();
	let uri = req.uri().to_string();

	let host = match req.uri().authority() {
		Some(auth) => auth.as_str(),
		None => req
			.headers()
			.get("host")
			.ok_or_else(|| anyhow!("Missing host header"))?
			.to_str()?,
	};
	let path = req.uri().path();
	// An unparsable Accept-Encoding header is treated as if none was sent.
	let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_default();

	let best_match = proxy_config
		.entries
		.iter()
		.filter(|ent| {
			let prefix_ok = ent
				.path_prefix
				.as_ref()
				.map_or(true, |prefix| path.starts_with(prefix));
			ent.host.matches(host) && prefix_ok
		})
		.max_by_key(|ent| {
			let prefix_len = ent.path_prefix.as_ref().map_or(0, |x| x.len() as i32);
			(
				ent.priority,
				prefix_len,
				ent.same_node,
				ent.same_site,
				-ent.calls.load(Ordering::SeqCst),
			)
		});

	let proxy_to = match best_match {
		Some(entry) => entry,
		None => {
			debug!("{}{} -> NOT FOUND", host, path);
			info!("{} 404 {}", method, uri);

			return Ok(Response::builder()
				.status(StatusCode::NOT_FOUND)
				.body(Body::from("No matching proxy entry"))?);
		}
	};

	proxy_to.calls.fetch_add(1, Ordering::SeqCst);

	debug!("{}{} -> {}", host, path, proxy_to);
	trace!("Request: {:?}", req);

	let client_ip = remote_addr.ip();
	let mut response = if proxy_to.https_target {
		let target_url = format!("https://{}", proxy_to.target_addr);
		reverse_proxy::call_https(client_ip, &target_url, req).await?
	} else {
		let target_url = format!("http://{}", proxy_to.target_addr);
		reverse_proxy::call(client_ip, &target_url, req).await?
	};

	// Extra headers configured on the entry replace upstream headers of the same name.
	for (name, value) in proxy_to.add_headers.iter() {
		let header_name = HeaderName::from_bytes(name.as_bytes())?;
		let header_value = HeaderValue::from_str(value)?;
		response.headers_mut().insert(header_name, header_value);
	}
	trace!("Response: {:?}", response);
	info!("{} {} {}", method, response.status().as_u16(), uri);

	if !https_config.enable_compression {
		return Ok(response);
	}
	try_compress(response, method, accept_encoding, &https_config).await
}

async fn try_compress(
	response: Response<Body>,
	method: Method,
	accept_encoding: Vec<(Option<Encoding>, f32)>,
	https_config: &HttpsConfig,
) -> Result<Response<Body>> {
	// Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body)
	// Don't compress partial content, that would be wierd
	// If already compressed, return as is
	if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT))
		|| response.status() == StatusCode::PARTIAL_CONTENT
		|| response.headers().get(header::CONTENT_ENCODING).is_some()
	{
		return Ok(response);
	}

	// Select preferred encoding among those proposed in accept_encoding
	let max_q: f32 = accept_encoding
		.iter()
		.max_by_key(|(_, q)| (q * 10000f32) as i64)
		.unwrap_or(&(None, 1.))
		.1;
	let preference = [
		Encoding::Zstd,
		//Encoding::Brotli,
		Encoding::Deflate,
		Encoding::Gzip,
	];
	#[allow(clippy::float_cmp)]
	let encoding_opt = accept_encoding
		.iter()
		.filter(|(_, q)| *q == max_q)
		.filter_map(|(enc, _)| *enc)
		.filter(|enc| preference.contains(enc))
		.min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());

	// If preferred encoding is none, return as is
	let encoding = match encoding_opt {
		None | Some(Encoding::Identity) => return Ok(response),
		Some(enc) => enc,
	};

	// If content type not in mime types for which to compress, return as is
	match response.headers().get(header::CONTENT_TYPE) {
		Some(ct) => {
			let ct_str = ct.to_str()?;
			let mime_type = match ct_str.split_once(';') {
				Some((mime_type, _params)) => mime_type,
				None => ct_str,
			};
			if !https_config
				.compress_mime_types
				.iter()
				.any(|x| x == mime_type)
			{
				return Ok(response);
			}
		}
		None => return Ok(response), // don't compress if unknown mime type
	};

	let (mut head, mut body) = response.into_parts();

	// ---- If body is smaller than 1400 bytes, don't compress ----
	let mut chunks = vec![];
	let mut sum_lengths = 0;
	while sum_lengths < 1400 {
		match body.next().await {
			Some(chunk) => {
				let chunk = chunk?;
				sum_lengths += chunk.len();
				chunks.push(chunk);
			}
			None => {
				return Ok(Response::from_parts(head, Body::from(chunks.concat())));
			}
		}
	}

	// put beginning chunks back into body
	let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);

	// make an async reader from that for compressor
	let body_rd =
		StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

	trace!(
		"Compressing response body as {:?} (at least {} bytes)",
		encoding,
		sum_lengths
	);

	// we don't know the compressed content-length so remove that header
	head.headers.remove(header::CONTENT_LENGTH);

	let (encoding, compressed_body) = match encoding {
		Encoding::Gzip => (
			"gzip",
			Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
		),
		// Encoding::Brotli => {
		// 	head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
		// 	Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
		// }
		Encoding::Deflate => (
			"deflate",
			Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
		),
		Encoding::Zstd => (
			"zstd",
			Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
		),
		_ => unreachable!(),
	};
	head.headers
		.insert(header::CONTENT_ENCODING, encoding.parse()?);

	Ok(Response::from_parts(head, compressed_body))
}