aboutsummaryrefslogtreecommitdiff
path: root/src/reverse_proxy.rs
blob: 046808f67e46413a4df426690df1e84c3f062128 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See that repository for the original copyright notice.

use anyhow::Result;

use hyper::header::{HeaderMap, HeaderValue};
use hyper::{Body, Client, Request, Response, Uri};
use lazy_static::lazy_static;
use std::net::IpAddr;
use std::str::FromStr;

/// Returns `true` if `name` is a hop-by-hop header that a proxy must not
/// forward. The comparison is ASCII case-insensitive.
///
/// The list is the RFC 2616 §13.5.1 set with two additions:
/// * `Trailer` is the actual header name. The RFC text says `Trailers`, which is
///   a known erratum (RFC errata ID 4522); that spelling is kept for compatibility.
/// * `Proxy-Connection` is non-standard but still sent by some clients.
fn is_hop_header(name: &str) -> bool {
	// A plain constant slice: nothing is allocated, so no lazy initialisation
	// is needed.
	const HOP_HEADERS: &[&str] = &[
		"Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Proxy-Connection",
		"Te",
		"Trailer",
		"Trailers",
		"Transfer-Encoding",
		"Upgrade",
	];

	HOP_HEADERS.iter().any(|hop| hop.eq_ignore_ascii_case(name))
}

/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
	let mut result = HeaderMap::new();
	for (k, v) in headers.iter() {
		if !is_hop_header(k.as_str()) {
			result.insert(k.clone(), v.clone());
		}
	}
	result
}

fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
	*response.headers_mut() = remove_hop_headers(response.headers());
	response
}

/// Builds the upstream URI by concatenating `forward_url`, the incoming
/// request's path and, when present, `?` followed by its query string.
///
/// No separator is inserted between `forward_url` and the path. Request paths
/// normally start with `/`, so a `forward_url` ending in `/` produces a doubled
/// slash.
///
/// # Errors
/// Fails if the concatenated string is not a valid [`Uri`].
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
	let incoming = req.uri();

	let mut target = String::from(forward_url);
	target.push_str(incoming.path());
	if let Some(query) = incoming.query() {
		target.push('?');
		target.push_str(query);
	}

	Uri::from_str(&target).map_err(Into::into)
}

fn create_proxied_request<B>(
	client_ip: IpAddr,
	forward_url: &str,
	request: Request<B>,
) -> Result<Request<B>> {
	let mut builder = Request::builder().uri(forward_uri(forward_url, &request)?);

	*builder.headers_mut().unwrap() = remove_hop_headers(request.headers());

	let host_header_name = "host";
	let x_forwarded_for_header_name = "x-forwarded-for";

	// If request does not have host header, add it from original URI authority
	if let Some(authority) = request.uri().authority() {
		if let hyper::header::Entry::Vacant(entry) =
			builder.headers_mut().unwrap().entry(host_header_name)
		{
			entry.insert(authority.as_str().parse()?);
		}
	}

	// Add forwarding information in the headers
	match builder
		.headers_mut()
		.unwrap()
		.entry(x_forwarded_for_header_name)
	{
		hyper::header::Entry::Vacant(entry) => {
			entry.insert(client_ip.to_string().parse()?);
		}

		hyper::header::Entry::Occupied(mut entry) => {
			let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
			entry.insert(addr.parse()?);
		}
	}

	Ok(builder.body(request.into_body())?)
}

/// Forwards `request` to the upstream at `forward_uri` and returns the
/// upstream response, ready to be sent back to the client.
///
/// Hop-by-hop headers are stripped from both the request and the response,
/// and `client_ip` is appended to `X-Forwarded-For`. A new [`Client`] is
/// created on every call.
///
/// # Errors
/// Fails if the proxied request cannot be built or the upstream request fails.
pub async fn call(
	client_ip: IpAddr,
	forward_uri: &str,
	request: Request<Body>,
) -> Result<Response<Body>> {
	let outgoing = create_proxied_request(client_ip, forward_uri, request)?;
	let upstream = Client::new().request(outgoing).await?;
	Ok(create_proxied_response(upstream))
}