aboutsummaryrefslogblamecommitdiff
path: root/src/reverse_proxy.rs
blob: 23131c9cf95669d796d9f35bc084fd1b34436ef4 (plain) (tree)
1
2
3
4
5
6
7
8
9


                                                                  
                          

                      
                   
                                      
 
                   
               
 
                                           
                                            
                                                                                 

                                                             

                                                
 

                                                            

                                     
                                                           









                                             







                                                                                   
                                      
                                                            




                 


                                                 
                   


                                                                                                     
                                   
                                                                 

                                                                               


                                                                                              
                                                  


                         
                      










                                                                                        
                                                    


                            
                                               

                                            

                                                         
 

                                                         
 
                                                       
 
                                                                                   
                                                                               
                                                                    



                                                                  
                                                        
                                                                                                
                                       
                                                           





                                                                   
                                                                                                     
                                                                              

         
                                                    
                                                            
                                                              
                                                 


                                                                     
                                                       




                                                                                       
                           


                                                              
 
                                          
                                                                         
 




                                                                

 



                                                                                   
                                             
 
                                                              

                                                             


























                                                                                                                           
                                                                                      




                                                      

 




                               

                                                                         
 

                                                         


                                                           
                                                                                                
 
                                                              


                                                 
                                                                                         

                            





                               

                                                                         


                                                                 



                                                                                 




                                                                                             
                                                                                                



                                                              
                                                                                         

                            




                                                  







                                                        


                                                   
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See that repository for the original copyright notice.

use std::convert::TryInto;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use anyhow::Result;
use tracing::*;

use http::{header::HeaderName, StatusCode};
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, ServerName};

use crate::tls_util::HttpsConnectorFixedDnsname;

/// Maximum time allowed to establish a TCP connection to a backend (60 seconds).
pub const PROXY_TIMEOUT: Duration = Duration::from_millis(60_000);

/// Hop-by-hop headers that have a predefined constant in `http::header`.
///
/// These headers describe a single transport-level connection and must not be
/// forwarded by a proxy (RFC 2616 §13.5.1, RFC 7230 §6.1).
const HOP_HEADERS: &[HeaderName] = &[
	header::CONNECTION,
	header::PROXY_AUTHENTICATE,
	header::PROXY_AUTHORIZATION,
	header::TE,
	header::TRAILER,
	header::TRANSFER_ENCODING,
	header::UPGRADE,
];

/// Hop-by-hop headers that have no constant in `http::header`, as lowercase names.
///
/// `Keep-Alive` is listed in RFC 2616 §13.5.1. `Proxy-Connection` is a widespread
/// non-standard variant of `Connection`. Both only concern the current connection.
const HOP_HEADERS_EXTRA: &[&str] = &["keep-alive", "proxy-connection"];

/// Returns `true` if `name` is a well-known hop-by-hop header that a proxy must not forward.
fn is_hop_header(name: &HeaderName) -> bool {
	// `HeaderName::as_str` is always lowercase, so a plain comparison is case-insensitive.
	HOP_HEADERS.iter().any(|h| h == name) || HOP_HEADERS_EXTRA.contains(&name.as_str())
}

/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
	let mut result = HeaderMap::new();
	for (k, v) in headers.iter() {
		if !is_hop_header(k) {
			result.append(k.clone(), v.clone());
		}
	}
	result
}

fn copy_upgrade_headers(
	old_headers: &HeaderMap<HeaderValue>,
	new_headers: &mut HeaderMap<HeaderValue>,
) -> Result<bool> {
	// The Connection header is stripped as it is a hop header that we are not supposed to proxy.
	// However, it might also contain an Upgrade directive, e.g. for Websockets:
	// when that happen, we do want to preserve that directive.
	let mut is_upgrade = false;
	if let Some(conn) = old_headers.get(header::CONNECTION) {
		let conn_str = conn.to_str()?.to_lowercase();
		if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") {
			if let Some(upgrade) = old_headers.get(header::UPGRADE) {
				new_headers.insert(header::CONNECTION, "Upgrade".try_into()?);
				new_headers.insert(header::UPGRADE, upgrade.clone());
				is_upgrade = true;
			}
		}
	}
	Ok(is_upgrade)
}

/// Builds the backend URI by appending the path and query string of `req`'s URI to
/// `forward_url`.
///
/// `forward_url` is used verbatim as a prefix. It is presumably scheme + authority
/// with no trailing slash (TODO: confirm against callers). The query string, when
/// present, is appended unchanged after a `?`.
///
/// # Errors
/// Fails if the concatenated string is not a valid URI.
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
	let source = req.uri();

	let mut target = String::from(forward_url);
	target.push_str(source.path());
	if let Some(query) = source.query() {
		target.push('?');
		target.push_str(query);
	}

	Ok(Uri::from_str(&target)?)
}

fn create_proxied_request<B: std::default::Default>(
	client_ip: IpAddr,
	forward_url: &str,
	request: Request<B>,
) -> Result<(Request<B>, Option<Request<B>>)> {
	let mut builder = Request::builder()
		.method(request.method())
		.uri(forward_uri(forward_url, &request)?)
		.version(hyper::Version::HTTP_11);

	let old_headers = request.headers();
	let new_headers = builder.headers_mut().unwrap();

	*new_headers = remove_hop_headers(old_headers);

	// If request does not have host header, add it from original URI authority
	if let header::Entry::Vacant(entry) = new_headers.entry(header::HOST) {
		if let Some(authority) = request.uri().authority() {
			entry.insert(authority.as_str().parse()?);
		}
	}

	// Concatenate cookie headers into single header
	// (HTTP/2 allows several cookie headers, but we are proxying to HTTP/1.1 that does not)
	let mut cookie_concat = vec![];
	for cookie in new_headers.get_all(header::COOKIE) {
		if !cookie_concat.is_empty() {
			cookie_concat.extend(b"; ");
		}
		cookie_concat.extend_from_slice(cookie.as_bytes());
	}
	if !cookie_concat.is_empty() {
		// insert clears the old value of COOKIE and inserts the concatenated version instead
		new_headers.insert(header::COOKIE, cookie_concat.try_into()?);
	}

	// Add forwarding information in the headers
	let x_forwarded_for_header_name = "x-forwarded-for";
	match new_headers.entry(x_forwarded_for_header_name) {
		header::Entry::Vacant(entry) => {
			entry.insert(client_ip.to_string().parse()?);
		}

		header::Entry::Occupied(mut entry) => {
			let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
			entry.insert(addr.parse()?);
		}
	}

	new_headers.insert(
		HeaderName::from_bytes(b"x-forwarded-proto")?,
		"https".try_into()?,
	);

	// Proxy upgrade requests properly
	let is_upgrade = copy_upgrade_headers(old_headers, new_headers)?;

	if is_upgrade {
		Ok((builder.body(B::default())?, Some(request)))
	} else {
		Ok((builder.body(request.into_body())?, None))
	}
}

/// Converts the backend's `response` into the response returned to the client.
///
/// Hop-by-hop headers are stripped. `Connection: Upgrade` / `Upgrade` are restored
/// when the backend signals an upgrade. Any status other than `101 Switching
/// Protocols` is passed through with only its headers replaced.
///
/// For a 101 answer to an upgrade request:
/// 1. The backend connection is upgraded immediately.
/// 2. A background task is spawned. It waits for the client-side upgrade, then copies
///    bytes in both directions. Failures in that task are only logged.
/// 3. A fresh 101 response is returned with the filtered headers and a `B::default()`
///    body.
///
/// # Errors
/// Fails in these cases:
/// - the backend replies 101 but `upgrade_request` is `None`;
/// - upgrading the backend connection fails;
/// - a backend `Connection` value is not valid visible ASCII;
/// - the 101 response cannot be built.
async fn create_proxied_response<B: std::default::Default + Send + Sync + 'static>(
	mut response: Response<B>,
	upgrade_request: Option<Request<B>>,
) -> Result<Response<B>> {
	let mut filtered = remove_hop_headers(response.headers());
	copy_upgrade_headers(response.headers(), &mut filtered)?;

	// Ordinary responses: only the header map is rewritten.
	if response.status() != StatusCode::SWITCHING_PROTOCOLS {
		*response.headers_mut() = filtered;
		return Ok(response);
	}

	let mut client_request = match upgrade_request {
		Some(req) => req,
		None => return Err(anyhow!("Switching protocols but not an upgrade request")),
	};

	// Take over the backend connection before answering the client.
	let mut backend_io = hyper::upgrade::on(response).await?;

	tokio::spawn(async move {
		let mut client_io = match hyper::upgrade::on(&mut client_request).await {
			Ok(io) => io,
			Err(e) => {
				warn!(
					"Could not upgrade client request when switching protocols: {}",
					e
				);
				return;
			}
		};

		if let Err(e) = tokio::io::copy_bidirectional(&mut client_io, &mut backend_io).await {
			warn!("Error copying data in upgraded request: {}", e);
		}
	});

	let mut switching = Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
	*switching.headers_mut().unwrap() = filtered;
	Ok(switching.body(B::default())?)
}

pub async fn call(
	client_ip: IpAddr,
	forward_uri: &str,
	request: Request<Body>,
) -> Result<Response<Body>> {
	let (proxied_request, upgrade_request) =
		create_proxied_request(client_ip, forward_uri, request)?;

	trace!("Proxied request: {:?}", proxied_request);

	let mut connector = HttpConnector::new();
	connector.set_connect_timeout(Some(PROXY_TIMEOUT));

	let client: Client<_, hyper::Body> = Client::builder().set_host(false).build(connector);

	let response = client.request(proxied_request).await?;

	trace!("Inner response: {:?}", response);

	let proxied_response = create_proxied_response(response, upgrade_request).await?;
	Ok(proxied_response)
}

/// Forwards `request` to the HTTPS backend at `forward_uri` on behalf of `client_ip`,
/// and returns the response to send back to the client.
///
/// TLS settings:
/// - The backend certificate is NOT verified (see `DontVerifyServerCert`).
/// - The connector is given the fixed name `"dummy"`. Presumably this is the
///   DNS/SNI name used for every backend; see `tls_util`.
///
/// As in [`call`], a new client is built per call. [`PROXY_TIMEOUT`] bounds
/// connection establishment only, and hyper does not add a `Host` header.
///
/// # Errors
/// Propagates failures from request rewriting, from the TLS/HTTP round-trip, and from
/// response rewriting or upgrade handling.
pub async fn call_https(
	client_ip: IpAddr,
	forward_uri: &str,
	request: Request<Body>,
) -> Result<Response<Body>> {
	let (outbound, pending_upgrade) = create_proxied_request(client_ip, forward_uri, request)?;
	trace!("Proxied request (HTTPS): {:?}", outbound);

	// Any backend certificate is accepted: this guards against passive eavesdropping only.
	let tls_config = rustls::client::ClientConfig::builder()
		.with_safe_defaults()
		.with_custom_certificate_verifier(Arc::new(DontVerifyServerCert))
		.with_no_client_auth();

	let mut tcp = HttpConnector::new();
	tcp.set_connect_timeout(Some(PROXY_TIMEOUT));
	let https = HttpsConnectorFixedDnsname::new(tls_config, "dummy", tcp);
	let client: Client<_, Body> = Client::builder().set_host(false).build(https);

	let upstream = client.request(outbound).await?;
	trace!("Inner response (HTTPS): {:?}", upstream);

	create_proxied_response(upstream, pending_upgrade).await
}

/// A rustls certificate verifier that accepts every server certificate.
///
/// `verify_server_cert` reports success without looking at the certificate, the
/// intermediates, the server name, the SCTs, the OCSP response or the current time.
/// `call_https` installs it for backend connections. It therefore gives no
/// protection against an attacker able to impersonate a backend.
struct DontVerifyServerCert;

impl ServerCertVerifier for DontVerifyServerCert {
	fn verify_server_cert(
		&self,
		_leaf: &Certificate,
		_chain: &[Certificate],
		_name: &ServerName,
		_sct_list: &mut dyn Iterator<Item = &[u8]>,
		_ocsp: &[u8],
		_at: SystemTime,
	) -> Result<ServerCertVerified, rustls::Error> {
		let accepted = ServerCertVerified::assertion();
		Ok(accepted)
	}
}