aboutsummaryrefslogtreecommitdiff
path: root/aero-proto
diff options
context:
space:
mode:
Diffstat (limited to 'aero-proto')
-rw-r--r--aero-proto/src/dav/codec.rs80
-rw-r--r--aero-proto/src/dav/middleware.rs70
-rw-r--r--aero-proto/src/dav/mod.rs (renamed from aero-proto/src/dav.rs)135
3 files changed, 161 insertions, 124 deletions
diff --git a/aero-proto/src/dav/codec.rs b/aero-proto/src/dav/codec.rs
new file mode 100644
index 0000000..08af2fe
--- /dev/null
+++ b/aero-proto/src/dav/codec.rs
@@ -0,0 +1,80 @@
+use anyhow::Result;
+use hyper::{Request, Response, body::Bytes};
+use hyper::body::Incoming;
+use http_body_util::Full;
+use futures::stream::StreamExt;
+use futures::stream::TryStreamExt;
+use http_body_util::BodyStream;
+use http_body_util::StreamBody;
+use http_body_util::combinators::BoxBody;
+use hyper::body::Frame;
+use tokio_util::sync::PollSender;
+use std::io::{Error, ErrorKind};
+use futures::sink::SinkExt;
+use tokio_util::io::{SinkWriter, CopyToBytes};
+use http_body_util::BodyExt;
+
+use aero_dav::types as dav;
+use aero_dav::xml as dxml;
+
+pub(crate) fn depth(req: &Request<impl hyper::body::Body>) -> dav::Depth {
+ match req.headers().get("Depth").map(hyper::header::HeaderValue::to_str) {
+ Some(Ok("0")) => dav::Depth::Zero,
+ Some(Ok("1")) => dav::Depth::One,
+ Some(Ok("Infinity")) => dav::Depth::Infinity,
+ _ => dav::Depth::Zero,
+ }
+}
+
+pub(crate) fn text_body(txt: &'static str) -> BoxBody<Bytes, std::io::Error> {
+ BoxBody::new(Full::new(Bytes::from(txt)).map_err(|e| match e {}))
+}
+
+pub(crate) fn serialize<T: dxml::QWrite + Send + 'static>(status_ok: hyper::StatusCode, elem: T) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let (tx, rx) = tokio::sync::mpsc::channel::<Bytes>(1);
+
+ // Build the writer
+ tokio::task::spawn(async move {
+ let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
+ let mut writer = SinkWriter::new(CopyToBytes::new(sink));
+ let q = quick_xml::writer::Writer::new_with_indent(&mut writer, b' ', 4);
+ let ns_to_apply = vec![ ("xmlns:D".into(), "DAV:".into()), ("xmlns:C".into(), "urn:ietf:params:xml:ns:caldav".into()) ];
+ let mut qwriter = dxml::Writer { q, ns_to_apply };
+ let decl = quick_xml::events::BytesDecl::from_start(quick_xml::events::BytesStart::from_content("xml version=\"1.0\" encoding=\"utf-8\"", 0));
+ match qwriter.q.write_event_async(quick_xml::events::Event::Decl(decl)).await {
+ Ok(_) => (),
+ Err(e) => tracing::error!(err=?e, "unable to write XML declaration <?xml ... >"),
+ }
+ match elem.qwrite(&mut qwriter).await {
+ Ok(_) => tracing::debug!("fully serialized object"),
+ Err(e) => tracing::error!(err=?e, "failed to serialize object"),
+ }
+ });
+
+
+ // Build the reader
+ let recv = tokio_stream::wrappers::ReceiverStream::new(rx);
+ let stream = StreamBody::new(recv.map(|v| Ok(Frame::data(v))));
+ let boxed_body = BoxBody::new(stream);
+
+ let response = Response::builder()
+ .status(status_ok)
+ .header("content-type", "application/xml; charset=\"utf-8\"")
+ .body(boxed_body)?;
+
+ Ok(response)
+}
+
+
+/// Deserialize a request body to an XML request
+pub(crate) async fn deserialize<T: dxml::Node<T>>(req: Request<Incoming>) -> Result<T> {
+ let stream_of_frames = BodyStream::new(req.into_body());
+ let stream_of_bytes = stream_of_frames
+ .try_filter_map(|frame| async move { Ok(frame.into_data().ok()) })
+ .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err));
+ let async_read = tokio_util::io::StreamReader::new(stream_of_bytes);
+ let async_read = std::pin::pin!(async_read);
+ let mut rdr = dxml::Reader::new(quick_xml::reader::NsReader::from_reader(async_read)).await?;
+ let parsed = rdr.find::<T>().await?;
+ Ok(parsed)
+}
diff --git a/aero-proto/src/dav/middleware.rs b/aero-proto/src/dav/middleware.rs
new file mode 100644
index 0000000..c4edbd8
--- /dev/null
+++ b/aero-proto/src/dav/middleware.rs
@@ -0,0 +1,70 @@
+use anyhow::{anyhow, Result};
+use base64::Engine;
+use hyper::{Request, Response, body::Bytes};
+use hyper::body::Incoming;
+use http_body_util::combinators::BoxBody;
+
+use aero_user::login::ArcLoginProvider;
+use aero_collections::user::User;
+
+use super::codec::text_body;
+
+type ArcUser = std::sync::Arc<User>;
+
+pub(super) async fn auth<'a>(
+ login: ArcLoginProvider,
+ req: Request<Incoming>,
+ next: impl Fn(ArcUser, Request<Incoming>) -> futures::future::BoxFuture<'a, Result<Response<BoxBody<Bytes, std::io::Error>>>>,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) {
+ Some(hv) => hv.to_str()?,
+ None => {
+ tracing::info!("Missing authorization field");
+ return Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
+ .body(text_body("Missing Authorization field"))?)
+ },
+ };
+
+ let b64_creds_maybe_padded = match auth_val.split_once(" ") {
+ Some(("Basic", b64)) => b64,
+ _ => {
+ tracing::info!("Unsupported authorization field");
+ return Ok(Response::builder()
+ .status(400)
+ .body(text_body("Unsupported Authorization field"))?)
+ },
+ };
+
+ // base64urlencoded may have trailing equals, base64urlsafe has not
+ // theoretically authorization is padded but "be liberal in what you accept"
+ let b64_creds_clean = b64_creds_maybe_padded.trim_end_matches('=');
+
+ // Decode base64
+ let creds = base64::engine::general_purpose::STANDARD_NO_PAD.decode(b64_creds_clean)?;
+ let str_creds = std::str::from_utf8(&creds)?;
+
+ // Split username and password
+ let (username, password) = str_creds
+ .split_once(':')
+ .ok_or(anyhow!("Missing colon in Authorization, can't split decoded value into a username/password pair"))?;
+
+ // Call login provider
+ let creds = match login.login(username, password).await {
+ Ok(c) => c,
+ Err(_) => {
+ tracing::info!(user=username, "Wrong credentials");
+ return Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
+ .body(text_body("Wrong credentials"))?)
+ },
+ };
+
+ // Build a user
+ let user = User::new(username.into(), creds).await?;
+
+ // Call router with user
+ next(user, req).await
+}
diff --git a/aero-proto/src/dav.rs b/aero-proto/src/dav/mod.rs
index 424d4be..379e210 100644
--- a/aero-proto/src/dav.rs
+++ b/aero-proto/src/dav/mod.rs
@@ -1,15 +1,16 @@
+mod middleware;
+mod codec;
+
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, bail, Result};
-use base64::Engine;
use hyper::service::service_fn;
use hyper::{Request, Response, body::Bytes};
use hyper::server::conn::http1 as http;
use hyper::rt::{Read, Write};
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
-use http_body_util::Full;
use futures::stream::{FuturesUnordered, StreamExt};
use tokio::net::TcpListener;
use tokio::sync::watch;
@@ -25,7 +26,8 @@ use aero_dav::types as dav;
use aero_dav::caltypes as cal;
use aero_dav::acltypes as acl;
use aero_dav::realization::{All, self as all};
-use aero_dav::xml as dxml;
+
+use crate::dav::codec::{serialize, deserialize, depth, text_body};
type ArcUser = std::sync::Arc<User>;
@@ -106,13 +108,13 @@ impl Server {
let login = login.clone();
tracing::info!("{:?} {:?}", req.method(), req.uri());
async {
- match auth(login, req).await {
+ match middleware::auth(login, req, |user, request| async { router(user, request).await }.boxed()).await {
Ok(v) => Ok(v),
Err(e) => {
tracing::error!(err=?e, "internal error");
Response::builder()
.status(500)
- .body(text_body("Internal error"))
+ .body(codec::text_body("Internal error"))
},
}
}
@@ -145,62 +147,7 @@ impl Server {
use http_body_util::BodyExt;
//@FIXME We should not support only BasicAuth
-async fn auth(
- login: ArcLoginProvider,
- req: Request<Incoming>,
-) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
- let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) {
- Some(hv) => hv.to_str()?,
- None => {
- tracing::info!("Missing authorization field");
- return Ok(Response::builder()
- .status(401)
- .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
- .body(text_body("Missing Authorization field"))?)
- },
- };
- let b64_creds_maybe_padded = match auth_val.split_once(" ") {
- Some(("Basic", b64)) => b64,
- _ => {
- tracing::info!("Unsupported authorization field");
- return Ok(Response::builder()
- .status(400)
- .body(text_body("Unsupported Authorization field"))?)
- },
- };
-
- // base64urlencoded may have trailing equals, base64urlsafe has not
- // theoretically authorization is padded but "be liberal in what you accept"
- let b64_creds_clean = b64_creds_maybe_padded.trim_end_matches('=');
-
- // Decode base64
- let creds = base64::engine::general_purpose::STANDARD_NO_PAD.decode(b64_creds_clean)?;
- let str_creds = std::str::from_utf8(&creds)?;
-
- // Split username and password
- let (username, password) = str_creds
- .split_once(':')
- .ok_or(anyhow!("Missing colon in Authorization, can't split decoded value into a username/password pair"))?;
-
- // Call login provider
- let creds = match login.login(username, password).await {
- Ok(c) => c,
- Err(_) => {
- tracing::info!(user=username, "Wrong credentials");
- return Ok(Response::builder()
- .status(401)
- .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
- .body(text_body("Wrong credentials"))?)
- },
- };
-
- // Build a user
- let user = User::new(username.into(), creds).await?;
-
- // Call router with user
- router(user, req).await
-}
/// Path is a voluntarily feature limited
/// compared to the expressiveness of a UNIX path
@@ -237,7 +184,7 @@ async fn router(user: std::sync::Arc<User>, req: Request<Incoming>) -> Result<Re
tracing::warn!(err=?e, "dav node fetch failed");
return Ok(Response::builder()
.status(404)
- .body(text_body("Resource not found"))?)
+ .body(codec::text_body("Resource not found"))?)
}
};
let response = DavResponse { node, user, req };
@@ -247,12 +194,12 @@ async fn router(user: std::sync::Arc<User>, req: Request<Incoming>) -> Result<Re
.status(200)
.header("DAV", "1")
.header("Allow", "HEAD,GET,PUT,OPTIONS,DELETE,PROPFIND,PROPPATCH,MKCOL,COPY,MOVE,LOCK,UNLOCK,MKCALENDAR,REPORT")
- .body(text_body(""))?),
+ .body(codec::text_body(""))?),
"HEAD" | "GET" => {
tracing::warn!("HEAD+GET not correctly implemented");
return Ok(Response::builder()
.status(404)
- .body(text_body(""))?)
+ .body(codec::text_body(""))?)
},
"PUT" => {
todo!();
@@ -264,7 +211,7 @@ async fn router(user: std::sync::Arc<User>, req: Request<Incoming>) -> Result<Re
"REPORT" => response.report().await,
_ => return Ok(Response::builder()
.status(501)
- .body(text_body("HTTP Method not implemented"))?),
+ .body(codec::text_body("HTTP Method not implemented"))?),
}
}
@@ -294,68 +241,8 @@ use std::io::{Error, ErrorKind};
use futures::sink::SinkExt;
use tokio_util::io::{SinkWriter, CopyToBytes};
-fn depth(req: &Request<impl hyper::body::Body>) -> dav::Depth {
- match req.headers().get("Depth").map(hyper::header::HeaderValue::to_str) {
- Some(Ok("0")) => dav::Depth::Zero,
- Some(Ok("1")) => dav::Depth::One,
- Some(Ok("Infinity")) => dav::Depth::Infinity,
- _ => dav::Depth::Zero,
- }
-}
-
-fn text_body(txt: &'static str) -> BoxBody<Bytes, std::io::Error> {
- BoxBody::new(Full::new(Bytes::from(txt)).map_err(|e| match e {}))
-}
-
-fn serialize<T: dxml::QWrite + Send + 'static>(status_ok: hyper::StatusCode, elem: T) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
- let (tx, rx) = tokio::sync::mpsc::channel::<Bytes>(1);
-
- // Build the writer
- tokio::task::spawn(async move {
- let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
- let mut writer = SinkWriter::new(CopyToBytes::new(sink));
- let q = quick_xml::writer::Writer::new_with_indent(&mut writer, b' ', 4);
- let ns_to_apply = vec![ ("xmlns:D".into(), "DAV:".into()), ("xmlns:C".into(), "urn:ietf:params:xml:ns:caldav".into()) ];
- let mut qwriter = dxml::Writer { q, ns_to_apply };
- let decl = quick_xml::events::BytesDecl::from_start(quick_xml::events::BytesStart::from_content("xml version=\"1.0\" encoding=\"utf-8\"", 0));
- match qwriter.q.write_event_async(quick_xml::events::Event::Decl(decl)).await {
- Ok(_) => (),
- Err(e) => tracing::error!(err=?e, "unable to write XML declaration <?xml ... >"),
- }
- match elem.qwrite(&mut qwriter).await {
- Ok(_) => tracing::debug!("fully serialized object"),
- Err(e) => tracing::error!(err=?e, "failed to serialize object"),
- }
- });
- // Build the reader
- let recv = tokio_stream::wrappers::ReceiverStream::new(rx);
- let stream = StreamBody::new(recv.map(|v| Ok(Frame::data(v))));
- let boxed_body = BoxBody::new(stream);
-
- let response = Response::builder()
- .status(status_ok)
- .header("content-type", "application/xml; charset=\"utf-8\"")
- .body(boxed_body)?;
-
- Ok(response)
-}
-
-
-/// Deserialize a request body to an XML request
-async fn deserialize<T: dxml::Node<T>>(req: Request<Incoming>) -> Result<T> {
- let stream_of_frames = BodyStream::new(req.into_body());
- let stream_of_bytes = stream_of_frames
- .try_filter_map(|frame| async move { Ok(frame.into_data().ok()) })
- .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err));
- let async_read = tokio_util::io::StreamReader::new(stream_of_bytes);
- let async_read = std::pin::pin!(async_read);
- let mut rdr = dxml::Reader::new(quick_xml::reader::NsReader::from_reader(async_read)).await?;
- let parsed = rdr.find::<T>().await?;
- Ok(parsed)
-}
-
//---
use futures::{future, future::BoxFuture, future::FutureExt};