aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/reverse_proxy.rs72
1 files changed, 57 insertions, 15 deletions
diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs
index dc45869..74c43e3 100644
--- a/src/reverse_proxy.rs
+++ b/src/reverse_proxy.rs
@@ -10,7 +10,7 @@ use std::time::{Duration, SystemTime};
use anyhow::Result;
use log::*;
-use http::header::HeaderName;
+use http::{header::HeaderName, StatusCode};
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
use rustls::client::{ServerCertVerified, ServerCertVerifier};
@@ -51,20 +51,22 @@ fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue
fn copy_upgrade_headers(
old_headers: &HeaderMap<HeaderValue>,
new_headers: &mut HeaderMap<HeaderValue>,
-) -> Result<()> {
+) -> Result<bool> {
// The Connection header is stripped as it is a hop header that we are not supposed to proxy.
// However, it might also contain an Upgrade directive, e.g. for Websockets:
// when that happen, we do want to preserve that directive.
+ let mut is_upgrade = false;
if let Some(conn) = old_headers.get(header::CONNECTION) {
let conn_str = conn.to_str()?.to_lowercase();
if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") {
if let Some(upgrade) = old_headers.get(header::UPGRADE) {
new_headers.insert(header::CONNECTION, "Upgrade".try_into()?);
new_headers.insert(header::UPGRADE, upgrade.clone());
+ is_upgrade = true;
}
}
}
- Ok(())
+ Ok(is_upgrade)
}
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
@@ -76,11 +78,11 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
Ok(Uri::from_str(forward_uri.as_str())?)
}
-fn create_proxied_request<B>(
+fn create_proxied_request<B: std::default::Default>(
client_ip: IpAddr,
forward_url: &str,
request: Request<B>,
-) -> Result<Request<B>> {
+) -> Result<(Request<B>, Option<Request<B>>)> {
let mut builder = Request::builder()
.method(request.method())
.uri(forward_uri(forward_url, &request)?)
@@ -131,19 +133,57 @@ fn create_proxied_request<B>(
);
// Proxy upgrade requests properly
- copy_upgrade_headers(old_headers, new_headers)?;
+ let is_upgrade = copy_upgrade_headers(old_headers, new_headers)?;
- Ok(builder.body(request.into_body())?)
+ if is_upgrade {
+ Ok((builder.body(B::default())?, Some(request)))
+ } else {
+ Ok((builder.body(request.into_body())?, None))
+ }
}
-fn create_proxied_response<B>(mut response: Response<B>) -> Result<Response<B>> {
+async fn create_proxied_response<B: std::default::Default + Send + Sync + 'static>(
+ mut response: Response<B>,
+ upgrade_request: Option<Request<B>>,
+) -> Result<Response<B>> {
let old_headers = response.headers();
- let mut new_headers = remove_hop_headers(old_headers);
+ let mut new_headers = remove_hop_headers(old_headers);
copy_upgrade_headers(old_headers, &mut new_headers)?;
- *response.headers_mut() = new_headers;
- Ok(response)
+ if response.status() == StatusCode::SWITCHING_PROTOCOLS {
+ if let Some(mut req) = upgrade_request {
+ let mut res_upgraded = hyper::upgrade::on(response).await?;
+
+ tokio::spawn(async move {
+ match hyper::upgrade::on(&mut req).await {
+ Ok(mut req_upgraded) => {
+ if let Err(e) =
+ tokio::io::copy_bidirectional(&mut req_upgraded, &mut res_upgraded)
+ .await
+ {
+ warn!("Error copying data in upgraded request: {}", e);
+ }
+ }
+ Err(e) => {
+ warn!(
+ "Could not upgrade client request when switching protocols: {}",
+ e
+ );
+ }
+ }
+ });
+
+ let mut new_res = Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
+ *new_res.headers_mut().unwrap() = new_headers;
+ Ok(new_res.body(B::default())?)
+ } else {
+ return Err(anyhow!("Switching protocols but not an upgrade request"));
+ }
+ } else {
+ *response.headers_mut() = new_headers;
+ Ok(response)
+ }
}
pub async fn call(
@@ -151,7 +191,8 @@ pub async fn call(
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>> {
- let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
+ let (proxied_request, upgrade_request) =
+ create_proxied_request(client_ip, forward_uri, request)?;
trace!("Proxied request: {:?}", proxied_request);
@@ -164,7 +205,7 @@ pub async fn call(
trace!("Inner response: {:?}", response);
- let proxied_response = create_proxied_response(response)?;
+ let proxied_response = create_proxied_response(response, upgrade_request).await?;
Ok(proxied_response)
}
@@ -173,7 +214,8 @@ pub async fn call_https(
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>> {
- let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
+ let (proxied_request, upgrade_request) =
+ create_proxied_request(client_ip, forward_uri, request)?;
trace!("Proxied request (HTTPS): {:?}", proxied_request);
@@ -191,7 +233,7 @@ pub async fn call_https(
trace!("Inner response (HTTPS): {:?}", response);
- let proxied_response = create_proxied_response(response)?;
+ let proxied_response = create_proxied_response(response, upgrade_request).await?;
Ok(proxied_response)
}