aboutsummaryrefslogtreecommitdiff
path: root/src/api/generic_server.rs
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2024-02-13 11:24:56 +0100
committerAlex Auvolat <alex@adnab.me>2024-02-13 11:36:28 +0100
commitcf2af186fcc0c8f581a966454b6cd4720d3821f0 (patch)
tree37a978ba9ffb780fc828cff7b8ec93662d50884f /src/api/generic_server.rs
parentdb48dd3d6c1f9e86a62e9b8edfce2c1620bcd5f3 (diff)
parent823078b4cdaf93e09de0847c5eaa75beb7b26b7f (diff)
downloadgarage-cf2af186fcc0c8f581a966454b6cd4720d3821f0.tar.gz
garage-cf2af186fcc0c8f581a966454b6cd4720d3821f0.zip
Merge branch 'main' into next-0.10
Diffstat (limited to 'src/api/generic_server.rs')
-rw-r--r--src/api/generic_server.rs216
1 files changed, 157 insertions, 59 deletions
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs
index fa346f48..9c49fdf3 100644
--- a/src/api/generic_server.rs
+++ b/src/api/generic_server.rs
@@ -1,3 +1,4 @@
+use std::convert::Infallible;
use std::fs::{self, Permissions};
use std::os::unix::fs::PermissionsExt;
use std::sync::Arc;
@@ -5,16 +6,19 @@ use std::sync::Arc;
use async_trait::async_trait;
use futures::future::Future;
+use futures::stream::{futures_unordered::FuturesUnordered, StreamExt};
+use http_body_util::BodyExt;
use hyper::header::HeaderValue;
-use hyper::server::conn::AddrStream;
-use hyper::service::{make_service_fn, service_fn};
-use hyper::{Body, Request, Response, Server};
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper::{body::Incoming as IncomingBody, Request, Response};
use hyper::{HeaderMap, StatusCode};
+use hyper_util::rt::TokioIo;
-use hyperlocal::UnixServerExt;
-
-use tokio::net::UnixStream;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
+use tokio::sync::watch;
use opentelemetry::{
global,
@@ -28,6 +32,8 @@ use garage_util::forwarded_headers;
use garage_util::metrics::{gen_trace_id, RecordDuration};
use garage_util::socket_address::UnixOrTCPSocketAddress;
+use crate::helpers::{BoxBody, ErrorBody};
+
pub(crate) trait ApiEndpoint: Send + Sync + 'static {
fn name(&self) -> &'static str;
fn add_span_attributes(&self, span: SpanRef<'_>);
@@ -36,7 +42,7 @@ pub(crate) trait ApiEndpoint: Send + Sync + 'static {
pub trait ApiError: std::error::Error + Send + Sync + 'static {
fn http_status_code(&self) -> StatusCode;
fn add_http_headers(&self, header_map: &mut HeaderMap<HeaderValue>);
- fn http_body(&self, garage_region: &str, path: &str) -> Body;
+ fn http_body(&self, garage_region: &str, path: &str) -> ErrorBody;
}
#[async_trait]
@@ -47,12 +53,12 @@ pub(crate) trait ApiHandler: Send + Sync + 'static {
type Endpoint: ApiEndpoint;
type Error: ApiError;
- fn parse_endpoint(&self, r: &Request<Body>) -> Result<Self::Endpoint, Self::Error>;
+ fn parse_endpoint(&self, r: &Request<IncomingBody>) -> Result<Self::Endpoint, Self::Error>;
async fn handle(
&self,
- req: Request<Body>,
+ req: Request<IncomingBody>,
endpoint: Self::Endpoint,
- ) -> Result<Response<Body>, Self::Error>;
+ ) -> Result<Response<BoxBody<Self::Error>>, Self::Error>;
}
pub(crate) struct ApiServer<A: ApiHandler> {
@@ -99,74 +105,42 @@ impl<A: ApiHandler> ApiServer<A> {
self: Arc<Self>,
bind_addr: UnixOrTCPSocketAddress,
unix_bind_addr_mode: Option<u32>,
- shutdown_signal: impl Future<Output = ()>,
+ must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> {
- let tcp_service = make_service_fn(|conn: &AddrStream| {
- let this = self.clone();
-
- let client_addr = conn.remote_addr();
- async move {
- Ok::<_, GarageError>(service_fn(move |req: Request<Body>| {
- let this = this.clone();
-
- this.handler(req, client_addr.to_string())
- }))
- }
- });
-
- let unix_service = make_service_fn(|_: &UnixStream| {
- let this = self.clone();
-
- let path = bind_addr.to_string();
- async move {
- Ok::<_, GarageError>(service_fn(move |req: Request<Body>| {
- let this = this.clone();
-
- this.handler(req, path.clone())
- }))
- }
- });
-
- info!(
- "{} API server listening on {}",
- A::API_NAME_DISPLAY,
- bind_addr
- );
+ let server_name = format!("{} API", A::API_NAME_DISPLAY);
+ info!("{} server listening on {}", server_name, bind_addr);
match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => {
- Server::bind(&addr)
- .serve(tcp_service)
- .with_graceful_shutdown(shutdown_signal)
- .await?
+ let listener = TcpListener::bind(addr).await?;
+
+ let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
+ server_loop(server_name, listener, handler, must_exit).await
}
UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() {
fs::remove_file(path)?
}
- let bound = Server::bind_unix(path)?;
+ let listener = UnixListener::bind(path)?;
+ let listener = UnixListenerOn(listener, path.display().to_string());
fs::set_permissions(
path,
Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)),
)?;
- bound
- .serve(unix_service)
- .with_graceful_shutdown(shutdown_signal)
- .await?;
+ let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
+ server_loop(server_name, listener, handler, must_exit).await
}
- };
-
- Ok(())
+ }
}
async fn handler(
self: Arc<Self>,
- req: Request<Body>,
+ req: Request<IncomingBody>,
addr: String,
- ) -> Result<Response<Body>, GarageError> {
+ ) -> Result<Response<BoxBody<A::Error>>, http::Error> {
let uri = req.uri().clone();
if let Ok(forwarded_for_ip_addr) =
@@ -205,7 +179,7 @@ impl<A: ApiHandler> ApiServer<A> {
Ok(x)
}
Err(e) => {
- let body: Body = e.http_body(&self.region, uri.path());
+ let body = e.http_body(&self.region, uri.path());
let mut http_error_builder = Response::builder().status(e.http_status_code());
if let Some(header_map) = http_error_builder.headers_mut() {
@@ -219,12 +193,16 @@ impl<A: ApiHandler> ApiServer<A> {
} else {
info!("Response: error {}, {}", e.http_status_code(), e);
}
- Ok(http_error)
+ Ok(http_error
+ .map(|body| BoxBody::new(body.map_err(|_: Infallible| unreachable!()))))
}
}
}
- async fn handler_stage2(&self, req: Request<Body>) -> Result<Response<Body>, A::Error> {
+ async fn handler_stage2(
+ &self,
+ req: Request<IncomingBody>,
+ ) -> Result<Response<BoxBody<A::Error>>, A::Error> {
let endpoint = self.api_handler.parse_endpoint(&req)?;
debug!("Endpoint: {}", endpoint.name());
@@ -265,3 +243,123 @@ impl<A: ApiHandler> ApiServer<A> {
res
}
}
+
+// ==== helper functions ====
+
+#[async_trait]
+pub trait Accept: Send + Sync + 'static {
+ type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)>;
+}
+
+#[async_trait]
+impl Accept for TcpListener {
+ type Stream = TcpStream;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
+ self.accept()
+ .await
+ .map(|(stream, addr)| (stream, addr.to_string()))
+ }
+}
+
+pub struct UnixListenerOn(pub UnixListener, pub String);
+
+#[async_trait]
+impl Accept for UnixListenerOn {
+ type Stream = UnixStream;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
+ self.0
+ .accept()
+ .await
+ .map(|(stream, _addr)| (stream, self.1.clone()))
+ }
+}
+
+pub async fn server_loop<A, H, F, E>(
+ server_name: String,
+ listener: A,
+ handler: H,
+ mut must_exit: watch::Receiver<bool>,
+) -> Result<(), GarageError>
+where
+ A: Accept,
+ H: Fn(Request<IncomingBody>, String) -> F + Send + Sync + Clone + 'static,
+ F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static,
+ E: Send + Sync + std::error::Error + 'static,
+{
+ let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel();
+ let connection_collector = tokio::spawn({
+ let server_name = server_name.clone();
+ async move {
+ let mut connections = FuturesUnordered::new();
+ loop {
+ let collect_next = async {
+ if connections.is_empty() {
+ futures::future::pending().await
+ } else {
+ connections.next().await
+ }
+ };
+ tokio::select! {
+ result = collect_next => {
+ trace!("{} server: HTTP connection finished: {:?}", server_name, result);
+ }
+ new_fut = conn_out.recv() => {
+ match new_fut {
+ Some(f) => connections.push(f),
+ None => break,
+ }
+ }
+ }
+ }
+ if !connections.is_empty() {
+ info!(
+ "{} server: {} connections still open",
+ server_name,
+ connections.len()
+ );
+ while let Some(conn_res) = connections.next().await {
+ trace!(
+ "{} server: HTTP connection finished: {:?}",
+ server_name,
+ conn_res
+ );
+ info!(
+ "{} server: {} connections still open",
+ server_name,
+ connections.len()
+ );
+ }
+ }
+ }
+ });
+
+ while !*must_exit.borrow() {
+ let (stream, client_addr) = tokio::select! {
+ acc = listener.accept() => acc?,
+ _ = must_exit.changed() => continue,
+ };
+
+ let io = TokioIo::new(stream);
+
+ let handler = handler.clone();
+ let serve = move |req: Request<IncomingBody>| handler(req, client_addr.clone());
+
+ let fut = tokio::task::spawn(async move {
+ let io = Box::pin(io);
+ if let Err(e) = http1::Builder::new()
+ .serve_connection(io, service_fn(serve))
+ .await
+ {
+ debug!("Error handling HTTP connection: {}", e);
+ }
+ });
+ conn_in.send(fut)?;
+ }
+
+ info!("{} server exiting", server_name);
+ drop(conn_in);
+ connection_collector.await?;
+
+ Ok(())
+}