aboutsummaryrefslogtreecommitdiff
path: root/src/web
diff options
context:
space:
mode:
authorAlex <alex@adnab.me>2023-10-03 16:23:02 +0000
committerAlex <alex@adnab.me>2023-10-03 16:23:02 +0000
commit1243db87f2090a3302c7c8beb386e68ddf9b66b5 (patch)
tree7280b64fd8ef63be32f27f44f26f0ee7c8ca44be /src/web
parent16aa418e473a5e9ef229060d20f6eb280df272a2 (diff)
parent6f8a87814be502ecaee49cd37616ec7fe4c5b588 (diff)
downloadgarage-1243db87f2090a3302c7c8beb386e68ddf9b66b5.tar.gz
garage-1243db87f2090a3302c7c8beb386e68ddf9b66b5.zip
Merge pull request 'Add support for binding to unix domain sockets' (#640) from networkException/garage:unix-sockets into main
Reviewed-on: https://git.deuxfleurs.fr/Deuxfleurs/garage/pulls/640
Diffstat (limited to 'src/web')
-rw-r--r--src/web/Cargo.toml3
-rw-r--r--src/web/web_server.rs58
2 files changed, 52 insertions, 9 deletions
diff --git a/src/web/Cargo.toml b/src/web/Cargo.toml
index 6d0eba3a..eec47bcd 100644
--- a/src/web/Cargo.toml
+++ b/src/web/Cargo.toml
@@ -27,5 +27,8 @@ futures = "0.3"
http = "0.2"
hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] }
+hyperlocal = { version = "0.8.0", default-features = false, features = ["server"] }
+
+tokio = { version = "1.0", default-features = false, features = ["net"] }
opentelemetry = "0.17"
diff --git a/src/web/web_server.rs b/src/web/web_server.rs
index 287aef1a..73780efb 100644
--- a/src/web/web_server.rs
+++ b/src/web/web_server.rs
@@ -1,4 +1,6 @@
-use std::{convert::Infallible, net::SocketAddr, sync::Arc};
+use std::fs::{self, Permissions};
+use std::os::unix::prelude::PermissionsExt;
+use std::{convert::Infallible, sync::Arc};
use futures::future::Future;
@@ -9,6 +11,10 @@ use hyper::{
Body, Method, Request, Response, Server, StatusCode,
};
+use hyperlocal::UnixServerExt;
+
+use tokio::net::UnixStream;
+
use opentelemetry::{
global,
metrics::{Counter, ValueRecorder},
@@ -32,6 +38,7 @@ use garage_util::data::Uuid;
use garage_util::error::Error as GarageError;
use garage_util::forwarded_headers;
use garage_util::metrics::{gen_trace_id, RecordDuration};
+use garage_util::socket_address::UnixOrTCPSocketAddress;
struct WebMetrics {
request_counter: Counter<u64>,
@@ -69,7 +76,7 @@ impl WebServer {
/// Run a web server
pub async fn run(
garage: Arc<Garage>,
- addr: SocketAddr,
+ addr: UnixOrTCPSocketAddress,
root_domain: String,
shutdown_signal: impl Future<Output = ()>,
) -> Result<(), GarageError> {
@@ -80,7 +87,7 @@ impl WebServer {
root_domain,
});
- let service = make_service_fn(|conn: &AddrStream| {
+ let tcp_service = make_service_fn(|conn: &AddrStream| {
let web_server = web_server.clone();
let client_addr = conn.remote_addr();
@@ -88,23 +95,56 @@ impl WebServer {
Ok::<_, Error>(service_fn(move |req: Request<Body>| {
let web_server = web_server.clone();
- web_server.handle_request(req, client_addr)
+ web_server.handle_request(req, client_addr.to_string())
+ }))
+ }
+ });
+
+ let unix_service = make_service_fn(|_: &UnixStream| {
+ let web_server = web_server.clone();
+
+ let path = addr.to_string();
+ async move {
+ Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+ let web_server = web_server.clone();
+
+ web_server.handle_request(req, path.clone())
}))
}
});
- let server = Server::bind(&addr).serve(service);
- let graceful = server.with_graceful_shutdown(shutdown_signal);
- info!("Web server listening on http://{}", addr);
+ info!("Web server listening on {}", addr);
+
+ match addr {
+ UnixOrTCPSocketAddress::TCPSocket(addr) => {
+ Server::bind(&addr)
+ .serve(tcp_service)
+ .with_graceful_shutdown(shutdown_signal)
+ .await?
+ }
+ UnixOrTCPSocketAddress::UnixSocket(ref path) => {
+ if path.exists() {
+ fs::remove_file(path)?
+ }
+
+ let bound = Server::bind_unix(path)?;
+
+ fs::set_permissions(path, Permissions::from_mode(0o222))?;
+
+ bound
+ .serve(unix_service)
+ .with_graceful_shutdown(shutdown_signal)
+ .await?;
+ }
+ };
- graceful.await?;
Ok(())
}
async fn handle_request(
self: Arc<Self>,
req: Request<Body>,
- addr: SocketAddr,
+ addr: String,
) -> Result<Response<Body>, Infallible> {
if let Ok(forwarded_for_ip_addr) =
forwarded_headers::handle_forwarded_for_headers(req.headers())