aboutsummaryrefslogblamecommitdiff
path: root/src/reverse_proxy.rs
blob: c6e0bacc84c7676b2d27eda6b99a38813290903b (plain) (tree)
1
2
3
4
5
6
7
8
9


                                                                  
                          

                      
                   
                                      
 
                   
           
 
                             
                                            
                                                                                 

                                                             

                                                
 

                                                            

                                     
                                                           









                                             







                                                                                   
                                      
                                                            























                                                                                        

                                            

                                                         


                                                                                
                                                                                   
                                                                                                  
                                                                    




                                                                  
                                                            




                                                   
                                                 


                                                                     
                                                       




                                                                                       



                                                              
 
                                          
                                                                       
                                                               




                                                                                           


                                                      
                                                                                  



                         







                                              
                                                                                       
 

                                                         


                                                           
                                                                                                
 
                                                              


                                                 


                                                                 





                               
                                                                                       


                                                                 



                                                                                 




                                                                                             
                                                                                                






                                                                 




                                                  







                                                        


                                                   
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See there for original Copyright notice

use std::convert::TryInto;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use anyhow::Result;
use log::*;

use http::header::HeaderName;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, ServerName};

use crate::tls_util::HttpsConnectorFixedDnsname;

/// Timeout for establishing the connection to the backend.
///
/// In this module it is passed to `HttpConnector::set_connect_timeout`, so it bounds only
/// connection setup, not the duration of the whole proxied exchange.
pub const PROXY_TIMEOUT: Duration = Duration::from_secs(60);

/// Hop-by-hop headers that have a predefined constant in `http::header`.
///
/// These describe a single transport-level connection and must not be forwarded by a
/// proxy (RFC 9110 §7.6.1).
const HOP_HEADERS: &[HeaderName] = &[
	header::CONNECTION,
	header::PROXY_AUTHENTICATE,
	header::PROXY_AUTHORIZATION,
	header::TE,
	header::TRAILER,
	header::TRANSFER_ENCODING,
	header::UPGRADE,
];

/// Hop-by-hop headers with no constant in `http::header`.
///
/// Entries must be lowercase: they are compared against [`HeaderName::as_str`], which is
/// always lowercase.
const HOP_HEADERS_WITHOUT_CONSTANT: &[&str] = &["keep-alive"];

/// Returns `true` if `name` is one of the statically known hop-by-hop headers.
fn is_hop_header(name: &HeaderName) -> bool {
	HOP_HEADERS.contains(name) || HOP_HEADERS_WITHOUT_CONSTANT.contains(&name.as_str())
}

/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
	let mut result = HeaderMap::new();
	for (k, v) in headers.iter() {
		if !is_hop_header(k) {
			result.append(k.clone(), v.clone());
		}
	}
	result
}

fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
	*response.headers_mut() = remove_hop_headers(response.headers());
	response
}

/// Builds the backend URI for `req`.
///
/// The incoming request's path and, when present, its `?query` are appended verbatim to
/// `forward_url`. No slash normalisation is done, so the caller controls the exact prefix.
///
/// # Errors
///
/// Fails if the concatenated string is not a valid [`Uri`].
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
	let incoming = req.uri();

	let mut target = String::from(forward_url);
	target.push_str(incoming.path());
	if let Some(query) = incoming.query() {
		target.push('?');
		target.push_str(query);
	}

	Ok(Uri::from_str(&target)?)
}

fn create_proxied_request<B>(
	client_ip: IpAddr,
	forward_url: &str,
	request: Request<B>,
) -> Result<Request<B>> {
	let mut builder = Request::builder()
		.method(request.method())
		.uri(forward_uri(forward_url, &request)?)
		.version(hyper::Version::HTTP_11);

	*builder.headers_mut().unwrap() = remove_hop_headers(request.headers());

	// If request does not have host header, add it from original URI authority
	if let header::Entry::Vacant(entry) = builder.headers_mut().unwrap().entry(header::HOST) {
		if let Some(authority) = request.uri().authority() {
			entry.insert(authority.as_str().parse()?);
		}
	}

	// Add forwarding information in the headers
	let x_forwarded_for_header_name = "x-forwarded-for";
	match builder
		.headers_mut()
		.unwrap()
		.entry(x_forwarded_for_header_name)
	{
		header::Entry::Vacant(entry) => {
			entry.insert(client_ip.to_string().parse()?);
		}

		header::Entry::Occupied(mut entry) => {
			let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
			entry.insert(addr.parse()?);
		}
	}

	builder.headers_mut().unwrap().insert(
		HeaderName::from_bytes(b"x-forwarded-proto")?,
		"https".try_into()?,
	);

	// Proxy upgrade requests properly
	if let Some(conn) = request.headers().get(header::CONNECTION) {
		if conn.to_str()?.to_lowercase() == "upgrade" {
			if let Some(upgrade) = request.headers().get(header::UPGRADE) {
				builder
					.headers_mut()
					.unwrap()
					.insert(header::CONNECTION, "Upgrade".try_into()?);
				builder
					.headers_mut()
					.unwrap()
					.insert(header::UPGRADE, upgrade.clone());
			}
		}
	}

	Ok(builder.body(request.into_body())?)
}

/// Forwards `request` to a plain-HTTP backend and returns the backend's response.
///
/// `forward_uri` is the backend base URL; the client's path and query are appended to it.
/// `client_ip` is recorded in `X-Forwarded-For`. Hop-by-hop headers are stripped in both
/// directions.
///
/// A new `HttpConnector` and `Client` are built on every call, using [`PROXY_TIMEOUT`] as
/// the connect timeout. The client is built with `set_host(false)`, so it does not add a
/// `Host` header of its own.
///
/// # Errors
///
/// Fails if the outgoing request cannot be built or if the backend request fails.
pub async fn call(
	client_ip: IpAddr,
	forward_uri: &str,
	request: Request<Body>,
) -> Result<Response<Body>> {
	let outgoing = create_proxied_request(client_ip, forward_uri, request)?;
	trace!("Proxied request: {:?}", outgoing);

	let mut http = HttpConnector::new();
	http.set_connect_timeout(Some(PROXY_TIMEOUT));
	let client: Client<_, Body> = Client::builder().set_host(false).build(http);

	let upstream = client.request(outgoing).await?;
	trace!("Inner response: {:?}", upstream);

	Ok(create_proxied_response(upstream))
}

/// Forwards `request` to an HTTPS backend and returns the backend's response.
///
/// Behaves like [`call`], but connects over TLS using `HttpsConnectorFixedDnsname` with the
/// fixed name `"dummy"`. That name is presumably used as the TLS server name; confirm this
/// in `tls_util`.
///
/// SECURITY: the backend certificate is NOT verified ([`DontVerifyServerCert`] accepts any
/// certificate). This path is only acceptable for backends reached over a trusted network.
///
/// # Errors
///
/// Fails if the outgoing request cannot be built or if the backend request fails.
pub async fn call_https(
	client_ip: IpAddr,
	forward_uri: &str,
	request: Request<Body>,
) -> Result<Response<Body>> {
	let outgoing = create_proxied_request(client_ip, forward_uri, request)?;
	trace!("Proxied request (HTTPS): {:?}", outgoing);

	let insecure_verifier = Arc::new(DontVerifyServerCert);
	let tls = rustls::client::ClientConfig::builder()
		.with_safe_defaults()
		.with_custom_certificate_verifier(insecure_verifier)
		.with_no_client_auth();

	let mut tcp = HttpConnector::new();
	tcp.set_connect_timeout(Some(PROXY_TIMEOUT));
	let https = HttpsConnectorFixedDnsname::new(tls, "dummy", tcp);

	let client: Client<_, Body> = Client::builder().set_host(false).build(https);
	let upstream = client.request(outgoing).await?;
	trace!("Inner response (HTTPS): {:?}", upstream);

	Ok(create_proxied_response(upstream))
}

/// A rustls [`ServerCertVerifier`] that accepts every server certificate.
///
/// SECURITY: the certificate, its chain, the server name, SCTs, OCSP data and validity time
/// are all ignored. Any peer able to intercept the backend connection can impersonate the
/// backend. Used by `call_https` for backend connections.
struct DontVerifyServerCert;

impl ServerCertVerifier for DontVerifyServerCert {
	/// Unconditionally reports the presented certificate as verified.
	fn verify_server_cert(
		&self,
		_leaf: &Certificate,
		_chain: &[Certificate],
		_name: &ServerName,
		_sct_list: &mut dyn Iterator<Item = &[u8]>,
		_ocsp: &[u8],
		_at: SystemTime,
	) -> Result<ServerCertVerified, rustls::Error> {
		let accepted = ServerCertVerified::assertion();
		Ok(accepted)
	}
}