aboutsummaryrefslogblamecommitdiff
path: root/src/https.rs
blob: 1b467c0b36ea1c5c3e21c201a40d8e8cbe182df8 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                             
                         
                                       



                   


                                         
                                            

                               
                                                         


                              
                                                 




                                                  





                                             
                         
                            
                                   
                                                           
                 


                                                         



                                                                         

                                                                            
 
                                                                    
 
                                                             


                                                                
                                                              
                                                        
                                            








                                                                                              



                                                                                                                          



                                                                            
                                                                                  

                                         
                                                                                  




                         

















                                                                          




                                                           
                                       

                                            


                                          








                                                                       
                                                                                         
 
                                     


                               
                                              





                                                                               
                                   

                                             
                                               

                                                                

                                                                  


                         

                                                              
 
                                                           
                                             
 






                                                                                         

                                                                    



                                                                           
                 
                                                   
                                                                           
 




                                                                              
                

                                                        





                                                                      

























































                                                                                                       
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};

use anyhow::Result;
use log::*;

use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::TryStreamExt;
use http::header::{HeaderName, HeaderValue};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;

/// Settings for the HTTPS front-end served by `serve_https`.
#[derive(Debug, Clone)]
pub struct HttpsConfig {
	/// Address the TLS listener binds to.
	pub bind_addr: SocketAddr,
	/// When `true`, proxied responses are passed through `try_compress`.
	pub enable_compression: bool,
	/// Media types (e.g. `text/html`) whose responses may be compressed.
	pub compress_mime_types: Vec<String>,
}

/// Accepts TCP connections on `config.bind_addr`, terminates TLS on them and serves HTTP.
///
/// Certificate selection during the handshake is delegated to `StoreResolver(cert_store)`.
/// ALPN advertises `h2` and `http/1.1`. Every accepted connection is handled on its own
/// spawned task. Each request takes a fresh snapshot of the current `ProxyConfig` from
/// `rx_proxy_config` when it arrives, so config updates also reach already-open connections.
///
/// # Errors
/// Returns an error if binding the listener fails or if `accept` fails; otherwise this
/// function loops forever. TLS handshake failures and per-connection HTTP errors are only
/// logged and do not stop the listener.
pub async fn serve_https(
	config: HttpsConfig,
	cert_store: Arc<CertStore>,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
) -> Result<()> {
	let config = Arc::new(config);

	let mut tls_cfg = rustls::ServerConfig::builder()
		.with_safe_defaults()
		.with_no_client_auth()
		.with_cert_resolver(Arc::new(StoreResolver(cert_store)));

	// ALPN protocols offered to clients: HTTP/2, then HTTP/1.1.
	tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
	let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));

	info!("Starting to serve on https://{}.", config.bind_addr);

	let tcp = TcpListener::bind(config.bind_addr).await?;
	loop {
		// An accept error is propagated and ends the listener.
		let (socket, remote_addr) = tcp.accept().await?;

		// Cheap per-connection handles moved into the spawned task.
		let rx_proxy_config = rx_proxy_config.clone();
		let tls_acceptor = tls_acceptor.clone();
		let config = config.clone();

		tokio::spawn(async move {
			match tls_acceptor.accept(socket).await {
				Ok(stream) => {
					debug!("TLS handshake was successful");
					let http_result = Http::new()
						.serve_connection(
							stream,
							service_fn(move |req: Request<Body>| {
								let https_config = config.clone();
								// Snapshot the latest proxy config for this request only.
								let proxy_config: Arc<ProxyConfig> =
									rx_proxy_config.borrow().clone();
								handle_outer(remote_addr, req, https_config, proxy_config)
							}),
						)
						.await;
					if let Err(http_err) = http_result {
						warn!("HTTP error: {}", http_err);
					}
				}
				Err(e) => warn!("Error in TLS connection: {}", e),
			}
		});
	}
}

async fn handle_outer(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, Infallible> {
	match handle(remote_addr, req, https_config, proxy_config).await {
		Err(e) => {
			warn!("Handler error: {}", e);
			Ok(Response::builder()
				.status(StatusCode::INTERNAL_SERVER_ERROR)
				.body(Body::from(format!("{}", e)))
				.unwrap())
		}
		Ok(r) => Ok(r),
	}
}

// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn handle(
	remote_addr: SocketAddr,
	req: Request<Body>,
	https_config: Arc<HttpsConfig>,
	proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
	let method = req.method().clone();
	let uri = req.uri().to_string();

	let host = if let Some(auth) = req.uri().authority() {
		auth.as_str()
	} else {
		req.headers()
			.get("host")
			.ok_or_else(|| anyhow!("Missing host header"))?
			.to_str()?
	};
	let path = req.uri().path();
	let accept_encoding = accept_encoding_fork::parse(req.headers()).unwrap_or(None);

	let best_match = proxy_config
		.entries
		.iter()
		.filter(|ent| {
			ent.host.matches(host)
				&& ent
					.path_prefix
					.as_ref()
					.map(|prefix| path.starts_with(prefix))
					.unwrap_or(true)
		})
		.max_by_key(|ent| {
			(
				ent.priority,
				ent.path_prefix
					.as_ref()
					.map(|x| x.len() as i32)
					.unwrap_or(0),
				-ent.calls.load(Ordering::SeqCst),
			)
		});

	if let Some(proxy_to) = best_match {
		proxy_to.calls.fetch_add(1, Ordering::SeqCst);

		debug!("{}{} -> {}", host, path, proxy_to);
		trace!("Request: {:?}", req);

		let mut response = if proxy_to.https_target {
			let to_addr = format!("https://{}", proxy_to.target_addr);
			reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await?
		} else {
			let to_addr = format!("http://{}", proxy_to.target_addr);
			reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?
		};

		for (header, value) in proxy_to.add_headers.iter() {
			response.headers_mut().insert(
				HeaderName::from_bytes(header.as_bytes())?,
				HeaderValue::from_str(value)?,
			);
		}
		trace!("Response: {:?}", response);
		info!("{} {} {}", method, response.status().as_u16(), uri);

		if https_config.enable_compression {
			try_compress(response, accept_encoding, &https_config)
		} else {
			Ok(response)
		}
	} else {
		debug!("{}{} -> NOT FOUND", host, path);
		info!("{} 404 {}", method, uri);

		Ok(Response::builder()
			.status(StatusCode::NOT_FOUND)
			.body(Body::from("No matching proxy entry"))?)
	}
}

fn try_compress(
	response: Response<Body>,
	accept_encoding: Option<Encoding>,
	https_config: &HttpsConfig,
) -> Result<Response<Body>> {
	// Check if a compression encoding is accepted
	let encoding = match accept_encoding {
		None | Some(Encoding::Identity) => return Ok(response),
		Some(enc) => enc,
	};

	// If already compressed, return as is
	if response.headers().get(header::CONTENT_ENCODING).is_some() {
		return Ok(response);
	}

	// If content type not in mime types for which to compress, return as is
	match response.headers().get(header::CONTENT_TYPE) {
		Some(ct) => {
			let ct_str = ct.to_str()?;
			if !https_config.compress_mime_types.iter().any(|x| x == ct_str) {
				return Ok(response);
			}
		}
		None => return Ok(response),
	};

	debug!("Compressing response body as {:?}", encoding);

	let (mut head, body) = response.into_parts();
	let body_rd =
		StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
	let compressed_body = match encoding {
		Encoding::Gzip => {
			head.headers
				.insert(header::CONTENT_ENCODING, "gzip".parse()?);
			Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd)))
		}
		Encoding::Brotli => {
			head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
			Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
		}
		Encoding::Deflate => {
			head.headers
				.insert(header::CONTENT_ENCODING, "deflate".parse()?);
			Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd)))
		}
		Encoding::Zstd => {
			head.headers
				.insert(header::CONTENT_ENCODING, "zstd".parse()?);
			Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd)))
		}
		_ => unreachable!(),
	};

	Ok(Response::from_parts(head, compressed_body))
}