diff options
Diffstat (limited to 'src/api/generic_server.rs')
-rw-r--r-- | src/api/generic_server.rs | 113 |
1 files changed, 63 insertions, 50 deletions
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index fa346f48..832f2da3 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -6,15 +6,16 @@ use async_trait::async_trait; use futures::future::Future; +use http_body_util::BodyExt; use hyper::header::HeaderValue; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{body::Incoming as IncomingBody, Request, Response}; use hyper::{HeaderMap, StatusCode}; +use hyper_util::rt::TokioIo; -use hyperlocal::UnixServerExt; - -use tokio::net::UnixStream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, UnixListener}; use opentelemetry::{ global, @@ -28,6 +29,8 @@ use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; use garage_util::socket_address::UnixOrTCPSocketAddress; +use crate::helpers::{BoxBody, BytesBody}; + pub(crate) trait ApiEndpoint: Send + Sync + 'static { fn name(&self) -> &'static str; fn add_span_attributes(&self, span: SpanRef<'_>); @@ -36,7 +39,7 @@ pub(crate) trait ApiEndpoint: Send + Sync + 'static { pub trait ApiError: std::error::Error + Send + Sync + 'static { fn http_status_code(&self) -> StatusCode; fn add_http_headers(&self, header_map: &mut HeaderMap<HeaderValue>); - fn http_body(&self, garage_region: &str, path: &str) -> Body; + fn http_body(&self, garage_region: &str, path: &str) -> BytesBody; } #[async_trait] @@ -47,12 +50,12 @@ pub(crate) trait ApiHandler: Send + Sync + 'static { type Endpoint: ApiEndpoint; type Error: ApiError; - fn parse_endpoint(&self, r: &Request<Body>) -> Result<Self::Endpoint, Self::Error>; + fn parse_endpoint(&self, r: &Request<IncomingBody>) -> Result<Self::Endpoint, Self::Error>; async fn handle( &self, - req: Request<Body>, + req: Request<IncomingBody>, endpoint: Self::Endpoint, - ) -> Result<Response<Body>, Self::Error>; + ) -> Result<Response<BoxBody<Self::Error>>, Self::Error>; } pub(crate) struct ApiServer<A: ApiHandler> { @@ -101,72 +104,79 @@ impl<A: ApiHandler> ApiServer<A> { unix_bind_addr_mode: Option<u32>, shutdown_signal: impl Future<Output = ()>, ) -> Result<(), GarageError> { - let tcp_service = make_service_fn(|conn: &AddrStream| { - let this = self.clone(); - - let client_addr = conn.remote_addr(); - async move { - Ok::<_, GarageError>(service_fn(move |req: Request<Body>| { - let this = this.clone(); - - this.handler(req, client_addr.to_string()) - })) - } - }); - - let unix_service = make_service_fn(|_: &UnixStream| { - let this = self.clone(); - - let path = bind_addr.to_string(); - async move { - Ok::<_, GarageError>(service_fn(move |req: Request<Body>| { - let this = this.clone(); - - this.handler(req, path.clone()) - })) - } - }); - info!( "{} API server listening on {}", A::API_NAME_DISPLAY, bind_addr ); + tokio::pin!(shutdown_signal); + match bind_addr { UnixOrTCPSocketAddress::TCPSocket(addr) => { - Server::bind(&addr) - .serve(tcp_service) - .with_graceful_shutdown(shutdown_signal) - .await? + let listener = TcpListener::bind(addr).await?; + + loop { + let (stream, client_addr) = tokio::select! { + acc = listener.accept() => acc?, + _ = &mut shutdown_signal => break, + }; + + self.launch_handler(stream, client_addr.to_string()); + } } UnixOrTCPSocketAddress::UnixSocket(ref path) => { if path.exists() { fs::remove_file(path)? } - let bound = Server::bind_unix(path)?; + let listener = UnixListener::bind(path)?; fs::set_permissions( path, Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)), )?; - bound - .serve(unix_service) - .with_graceful_shutdown(shutdown_signal) - .await?; + loop { + let (stream, _) = tokio::select! { + acc = listener.accept() => acc?, + _ = &mut shutdown_signal => break, + }; + + self.launch_handler(stream, path.display().to_string()); + } } }; Ok(()) } + fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String) + where + S: AsyncRead + AsyncWrite + Send + Sync + 'static, + { + let this = self.clone(); + let io = TokioIo::new(stream); + + let serve = + move |req: Request<IncomingBody>| this.clone().handler(req, client_addr.to_string()); + + tokio::task::spawn(async move { + let io = Box::pin(io); + if let Err(e) = http1::Builder::new() + .serve_connection(io, service_fn(serve)) + .await + { + debug!("Error handling HTTP connection: {}", e); + } + }); + } + async fn handler( self: Arc<Self>, - req: Request<Body>, + req: Request<IncomingBody>, addr: String, - ) -> Result<Response<Body>, GarageError> { + ) -> Result<Response<BoxBody<A::Error>>, GarageError> { let uri = req.uri().clone(); if let Ok(forwarded_for_ip_addr) = @@ -205,7 +215,7 @@ impl<A: ApiHandler> ApiServer<A> { Ok(x) } Err(e) => { - let body: Body = e.http_body(&self.region, uri.path()); + let body = e.http_body(&self.region, uri.path()); let mut http_error_builder = Response::builder().status(e.http_status_code()); if let Some(header_map) = http_error_builder.headers_mut() { @@ -219,12 +229,15 @@ impl<A: ApiHandler> ApiServer<A> { } else { info!("Response: error {}, {}", e.http_status_code(), e); } - Ok(http_error) + Ok(http_error.map(|body| BoxBody::new(body.map_err(|_| unreachable!())))) } } } - async fn handler_stage2(&self, req: Request<Body>) -> Result<Response<Body>, A::Error> { + async fn handler_stage2( + &self, + req: Request<IncomingBody>, + ) -> Result<Response<BoxBody<A::Error>>, A::Error> { let endpoint = self.api_handler.parse_endpoint(&req)?; debug!("Endpoint: {}", endpoint.name()); |