diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/api/generic_server.rs | 13 | ||||
-rw-r--r-- | src/util/forwarded_headers.rs | 63 | ||||
-rw-r--r-- | src/util/lib.rs | 1 | ||||
-rw-r--r-- | src/web/web_server.rs | 15 |
4 files changed, 82 insertions, 10 deletions
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index aa90868a..d0354d28 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -19,6 +19,7 @@ use opentelemetry::{ }; use garage_util::error::Error as GarageError; +use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; pub(crate) trait ApiEndpoint: Send + Sync + 'static { @@ -126,15 +127,9 @@ impl<A: ApiHandler> ApiServer<A> { ) -> Result<Response<Body>, GarageError> { let uri = req.uri().clone(); - let has_forwarded_for_header = req.headers().contains_key("x-forwarded-for"); - if has_forwarded_for_header { - let forwarded_for_ip_addr = &req - .headers() - .get("x-forwarded-for") - .expect("Could not parse X-Forwarded-For header") - .to_str() - .unwrap_or_default(); - + if let Ok(forwarded_for_ip_addr) = + forwarded_headers::handle_forwarded_for_headers(&req.headers()) + { info!( "{} (via {}) {} {}", forwarded_for_ip_addr, diff --git a/src/util/forwarded_headers.rs b/src/util/forwarded_headers.rs new file mode 100644 index 00000000..6ae275aa --- /dev/null +++ b/src/util/forwarded_headers.rs @@ -0,0 +1,63 @@ +use http::{HeaderMap, HeaderValue}; +use std::net::IpAddr; +use std::str::FromStr; + +use crate::error::{Error, OkOrMessage}; + +pub fn handle_forwarded_for_headers(headers: &HeaderMap<HeaderValue>) -> Result<String, Error> { + let forwarded_for_header = headers + .get("x-forwarded-for") + .ok_or_message("X-Forwarded-For header not provided")?; + + let forwarded_for_ip_str = forwarded_for_header + .to_str() + .ok_or_message("Error parsing X-Forwarded-For header")?; + + let client_ip = IpAddr::from_str(&forwarded_for_ip_str) + .ok_or_message("Valid IP address not found in X-Forwarded-For header")?; + + Ok(client_ip.to_string()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_handle_forwarded_for_headers_ipv4_client() { + let mut test_headers = HeaderMap::new(); + test_headers.insert("X-Forwarded-For", "192.0.2.100".parse().unwrap()); + + if let Ok(forwarded_ip) = handle_forwarded_for_headers(&test_headers) { + assert_eq!(forwarded_ip, "192.0.2.100"); + } + } + + #[test] + fn test_handle_forwarded_for_headers_ipv6_client() { + let mut test_headers = HeaderMap::new(); + test_headers.insert("X-Forwarded-For", "2001:db8::f00d:cafe".parse().unwrap()); + + if let Ok(forwarded_ip) = handle_forwarded_for_headers(&test_headers) { + assert_eq!(forwarded_ip, "2001:db8::f00d:cafe"); + } + } + + #[test] + fn test_handle_forwarded_for_headers_invalid_ip() { + let mut test_headers = HeaderMap::new(); + test_headers.insert("X-Forwarded-For", "www.example.com".parse().unwrap()); + + let result = handle_forwarded_for_headers(&test_headers); + assert!(result.is_err()); + } + + #[test] + fn test_handle_forwarded_for_headers_missing() { + let mut test_headers = HeaderMap::new(); + test_headers.insert("Host", "www.deuxfleurs.fr".parse().unwrap()); + + let result = handle_forwarded_for_headers(&test_headers); + assert!(result.is_err()); + } +} diff --git a/src/util/lib.rs b/src/util/lib.rs index d35ca72f..c9110fb2 100644 --- a/src/util/lib.rs +++ b/src/util/lib.rs @@ -11,6 +11,7 @@ pub mod data; pub mod encode; pub mod error; pub mod formater; +pub mod forwarded_headers; pub mod metrics; pub mod migrate; pub mod persister; diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 5719da54..0c7edf23 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -29,6 +29,7 @@ use garage_model::garage::Garage; use garage_table::*; use garage_util::error::Error as GarageError; +use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; struct WebMetrics { @@ -104,7 +105,19 @@ impl WebServer { req: Request<Body>, addr: SocketAddr, ) -> Result<Response<Body>, Infallible> { - info!("{} {} {}", addr, req.method(), req.uri()); + if let Ok(forwarded_for_ip_addr) = + forwarded_headers::handle_forwarded_for_headers(&req.headers()) + { + info!( + "{} (via {}) {} {}", + forwarded_for_ip_addr, + addr, + req.method(), + req.uri() + ); + } else { + info!("{} {} {}", addr, req.method(), req.uri()); + } // Lots of instrumentation let tracer = opentelemetry::global::tracer("garage"); |