diff options
-rw-r--r-- | src/reverse_proxy.rs | 72 |
1 files changed, 57 insertions, 15 deletions
diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs index dc45869..74c43e3 100644 --- a/src/reverse_proxy.rs +++ b/src/reverse_proxy.rs @@ -10,7 +10,7 @@ use std::time::{Duration, SystemTime}; use anyhow::Result; use log::*; -use http::header::HeaderName; +use http::{header::HeaderName, StatusCode}; use hyper::header::{HeaderMap, HeaderValue}; use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri}; use rustls::client::{ServerCertVerified, ServerCertVerifier}; @@ -51,20 +51,22 @@ fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue fn copy_upgrade_headers( old_headers: &HeaderMap<HeaderValue>, new_headers: &mut HeaderMap<HeaderValue>, -) -> Result<()> { +) -> Result<bool> { // The Connection header is stripped as it is a hop header that we are not supposed to proxy. // However, it might also contain an Upgrade directive, e.g. for Websockets: // when that happen, we do want to preserve that directive. + let mut is_upgrade = false; if let Some(conn) = old_headers.get(header::CONNECTION) { let conn_str = conn.to_str()?.to_lowercase(); if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") { if let Some(upgrade) = old_headers.get(header::UPGRADE) { new_headers.insert(header::CONNECTION, "Upgrade".try_into()?); new_headers.insert(header::UPGRADE, upgrade.clone()); + is_upgrade = true; } } } - Ok(()) + Ok(is_upgrade) } fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> { @@ -76,11 +78,11 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> { Ok(Uri::from_str(forward_uri.as_str())?) } -fn create_proxied_request<B>( +fn create_proxied_request<B: std::default::Default>( client_ip: IpAddr, forward_url: &str, request: Request<B>, -) -> Result<Request<B>> { +) -> Result<(Request<B>, Option<Request<B>>)> { let mut builder = Request::builder() .method(request.method()) .uri(forward_uri(forward_url, &request)?) @@ -131,19 +133,57 @@ fn create_proxied_request<B>( ); // Proxy upgrade requests properly - copy_upgrade_headers(old_headers, new_headers)?; + let is_upgrade = copy_upgrade_headers(old_headers, new_headers)?; - Ok(builder.body(request.into_body())?) + if is_upgrade { + Ok((builder.body(B::default())?, Some(request))) + } else { + Ok((builder.body(request.into_body())?, None)) + } } -fn create_proxied_response<B>(mut response: Response<B>) -> Result<Response<B>> { +async fn create_proxied_response<B: std::default::Default + Send + Sync + 'static>( + mut response: Response<B>, + upgrade_request: Option<Request<B>>, +) -> Result<Response<B>> { let old_headers = response.headers(); - let mut new_headers = remove_hop_headers(old_headers); + let mut new_headers = remove_hop_headers(old_headers); copy_upgrade_headers(old_headers, &mut new_headers)?; - *response.headers_mut() = new_headers; - Ok(response) + if response.status() == StatusCode::SWITCHING_PROTOCOLS { + if let Some(mut req) = upgrade_request { + let mut res_upgraded = hyper::upgrade::on(response).await?; + + tokio::spawn(async move { + match hyper::upgrade::on(&mut req).await { + Ok(mut req_upgraded) => { + if let Err(e) = + tokio::io::copy_bidirectional(&mut req_upgraded, &mut res_upgraded) + .await + { + warn!("Error copying data in upgraded request: {}", e); + } + } + Err(e) => { + warn!( + "Could not upgrade client request when switching protocols: {}", + e + ); + } + } + }); + + let mut new_res = Response::builder().status(StatusCode::SWITCHING_PROTOCOLS); + *new_res.headers_mut().unwrap() = new_headers; + Ok(new_res.body(B::default())?) + } else { + return Err(anyhow!("Switching protocols but not an upgrade request")); + } + } else { + *response.headers_mut() = new_headers; + Ok(response) + } } pub async fn call( @@ -151,7 +191,8 @@ pub async fn call( forward_uri: &str, request: Request<Body>, ) -> Result<Response<Body>> { - let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; + let (proxied_request, upgrade_request) = + create_proxied_request(client_ip, forward_uri, request)?; trace!("Proxied request: {:?}", proxied_request); @@ -164,7 +205,7 @@ pub async fn call( trace!("Inner response: {:?}", response); - let proxied_response = create_proxied_response(response)?; + let proxied_response = create_proxied_response(response, upgrade_request).await?; Ok(proxied_response) } @@ -173,7 +214,8 @@ pub async fn call_https( forward_uri: &str, request: Request<Body>, ) -> Result<Response<Body>> { - let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; + let (proxied_request, upgrade_request) = + create_proxied_request(client_ip, forward_uri, request)?; trace!("Proxied request (HTTPS): {:?}", proxied_request); @@ -191,7 +233,7 @@ pub async fn call_https( trace!("Inner response (HTTPS): {:?}", response); - let proxied_response = create_proxied_response(response)?; + let proxied_response = create_proxied_response(response, upgrade_request).await?; Ok(proxied_response) } |