diff options
author | Alex <alex@adnab.me> | 2023-10-03 16:23:02 +0000 |
---|---|---|
committer | Alex <alex@adnab.me> | 2023-10-03 16:23:02 +0000 |
commit | 1243db87f2090a3302c7c8beb386e68ddf9b66b5 (patch) | |
tree | 7280b64fd8ef63be32f27f44f26f0ee7c8ca44be /src/web | |
parent | 16aa418e473a5e9ef229060d20f6eb280df272a2 (diff) | |
parent | 6f8a87814be502ecaee49cd37616ec7fe4c5b588 (diff) | |
download | garage-1243db87f2090a3302c7c8beb386e68ddf9b66b5.tar.gz garage-1243db87f2090a3302c7c8beb386e68ddf9b66b5.zip |
Merge pull request 'Add support for binding to unix domain sockets' (#640) from networkException/garage:unix-sockets into main
Reviewed-on: https://git.deuxfleurs.fr/Deuxfleurs/garage/pulls/640
Diffstat (limited to 'src/web')
-rw-r--r-- | src/web/Cargo.toml | 3 | ||||
-rw-r--r-- | src/web/web_server.rs | 58 |
2 files changed, 52 insertions, 9 deletions
diff --git a/src/web/Cargo.toml b/src/web/Cargo.toml index 6d0eba3a..eec47bcd 100644 --- a/src/web/Cargo.toml +++ b/src/web/Cargo.toml @@ -27,5 +27,8 @@ futures = "0.3" http = "0.2" hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] } +hyperlocal = { version = "0.8.0", default-features = false, features = ["server"] } + +tokio = { version = "1.0", default-features = false, features = ["net"] } opentelemetry = "0.17" diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 287aef1a..73780efb 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -1,4 +1,6 @@ -use std::{convert::Infallible, net::SocketAddr, sync::Arc}; +use std::fs::{self, Permissions}; +use std::os::unix::prelude::PermissionsExt; +use std::{convert::Infallible, sync::Arc}; use futures::future::Future; @@ -9,6 +11,10 @@ use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; +use hyperlocal::UnixServerExt; + +use tokio::net::UnixStream; + use opentelemetry::{ global, metrics::{Counter, ValueRecorder}, @@ -32,6 +38,7 @@ use garage_util::data::Uuid; use garage_util::error::Error as GarageError; use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; +use garage_util::socket_address::UnixOrTCPSocketAddress; struct WebMetrics { request_counter: Counter<u64>, @@ -69,7 +76,7 @@ impl WebServer { /// Run a web server pub async fn run( garage: Arc<Garage>, - addr: SocketAddr, + addr: UnixOrTCPSocketAddress, root_domain: String, shutdown_signal: impl Future<Output = ()>, ) -> Result<(), GarageError> { @@ -80,7 +87,7 @@ impl WebServer { root_domain, }); - let service = make_service_fn(|conn: &AddrStream| { + let tcp_service = make_service_fn(|conn: &AddrStream| { let web_server = web_server.clone(); let client_addr = conn.remote_addr(); @@ -88,23 +95,56 @@ impl WebServer { Ok::<_, Error>(service_fn(move |req: Request<Body>| { let web_server = web_server.clone(); - web_server.handle_request(req, client_addr) + web_server.handle_request(req, client_addr.to_string()) + })) + } + }); + + let unix_service = make_service_fn(|_: &UnixStream| { + let web_server = web_server.clone(); + + let path = addr.to_string(); + async move { + Ok::<_, Error>(service_fn(move |req: Request<Body>| { + let web_server = web_server.clone(); + + web_server.handle_request(req, path.clone()) })) } }); - let server = Server::bind(&addr).serve(service); - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("Web server listening on http://{}", addr); + info!("Web server listening on {}", addr); + + match addr { + UnixOrTCPSocketAddress::TCPSocket(addr) => { + Server::bind(&addr) + .serve(tcp_service) + .with_graceful_shutdown(shutdown_signal) + .await? + } + UnixOrTCPSocketAddress::UnixSocket(ref path) => { + if path.exists() { + fs::remove_file(path)? + } + + let bound = Server::bind_unix(path)?; + + fs::set_permissions(path, Permissions::from_mode(0o222))?; + + bound + .serve(unix_service) + .with_graceful_shutdown(shutdown_signal) + .await?; + } + }; - graceful.await?; Ok(()) } async fn handle_request( self: Arc<Self>, req: Request<Body>, - addr: SocketAddr, + addr: String, ) -> Result<Response<Body>, Infallible> { if let Ok(forwarded_for_ip_addr) = forwarded_headers::handle_forwarded_for_headers(req.headers()) |