aboutsummaryrefslogtreecommitdiff
path: root/aero-proto/src/dav/codec.rs
blob: a441e7e4e2004d489b6c17833f8ee3bf3bf0194a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use anyhow::{bail, Result};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use futures::stream::TryStreamExt;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::BodyExt;
use http_body_util::BodyStream;
use http_body_util::Full;
use http_body_util::StreamBody;
use hyper::body::Frame;
use hyper::body::Incoming;
use hyper::{body::Bytes, Request, Response};
use std::io::{Error, ErrorKind};
use tokio_util::io::{CopyToBytes, SinkWriter};
use tokio_util::sync::PollSender;

use super::controller::HttpResponse;
use super::node::PutPolicy;
use aero_dav::types as dav;
use aero_dav::xml as dxml;

/// Parse the WebDAV `Depth` request header (RFC 4918 §10.2).
///
/// Recognized values are `0`, `1` and `infinity`. `infinity` is matched
/// case-insensitively: the RFC spells it in lowercase and its ABNF grammar
/// treats string literals as case-insensitive, while some clients send
/// `Infinity`. A missing header, a header that is not visible ASCII, or any
/// other value falls back to [`dav::Depth::Zero`].
///
/// NOTE(review): RFC 4918 says a PROPFIND without `Depth` must be treated as
/// `infinity`; this helper returns `Zero` in that case — confirm callers rely
/// on this conservative default.
pub(crate) fn depth(req: &Request<impl hyper::body::Body>) -> dav::Depth {
    let Some(Ok(value)) = req
        .headers()
        .get("Depth")
        .map(hyper::header::HeaderValue::to_str)
    else {
        return dav::Depth::Zero;
    };

    match value.trim() {
        "0" => dav::Depth::Zero,
        "1" => dav::Depth::One,
        v if v.eq_ignore_ascii_case("infinity") => dav::Depth::Infinity,
        _ => dav::Depth::Zero,
    }
}

/// Derive the PUT conflict policy from the request's conditional headers.
///
/// * `If-Match: "<etag>"` → [`PutPolicy::ReplaceEtag`] carrying the raw header
///   value (double quotes included). `If-Match` is consulted first, so it wins
///   when both conditional headers are present.
/// * `If-None-Match: *` → [`PutPolicy::CreateOnly`].
/// * neither header → [`PutPolicy::OverwriteAll`].
///
/// # Errors
/// Fails when the consulted header is not visible ASCII, when `If-Match` does
/// not contain exactly two double quotes (only a single etag is supported), or
/// when `If-None-Match` is anything other than `*`.
pub(crate) fn put_policy(req: &Request<impl hyper::body::Body>) -> Result<PutPolicy> {
    let headers = req.headers();

    if let Some(raw) = headers.get("If-Match") {
        let etag = raw.to_str()?;
        // Exactly one quoted etag means exactly one pair of double quotes.
        let single_etag = etag.matches('"').count() == 2;
        if !single_etag {
            bail!("Either If-Match value is invalid or it's not supported (only single etag is supported)");
        }
        return Ok(PutPolicy::ReplaceEtag(etag.into()));
    }

    if let Some(raw) = headers.get("If-None-Match") {
        return match raw.to_str()? {
            "*" => Ok(PutPolicy::CreateOnly),
            _ => bail!("Either If-None-Match value is invalid or it's not supported (only asterisk is supported)"),
        };
    }

    Ok(PutPolicy::OverwriteAll)
}

/// Wrap a static string into a boxed, type-erased HTTP body.
///
/// The body is a single in-memory chunk and cannot fail. Its `Infallible`
/// error type is mapped to `std::io::Error` only so that it fits the
/// `UnsyncBoxBody<Bytes, std::io::Error>` return type.
pub(crate) fn text_body(txt: &'static str) -> UnsyncBoxBody<Bytes, std::io::Error> {
    let chunk = Full::new(Bytes::from_static(txt.as_bytes()));
    let as_io_body = chunk.map_err(|never| -> std::io::Error { match never {} });
    UnsyncBoxBody::new(as_io_body)
}

pub(crate) fn serialize<T: dxml::QWrite + Send + 'static>(
    status_ok: hyper::StatusCode,
    elem: T,
) -> Result<HttpResponse> {
    let (tx, rx) = tokio::sync::mpsc::channel::<Bytes>(1);

    // Build the writer
    tokio::task::spawn(async move {
        let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
        let mut writer = SinkWriter::new(CopyToBytes::new(sink));
        let q = quick_xml::writer::Writer::new_with_indent(&mut writer, b' ', 4);
        let ns_to_apply = vec![
            ("xmlns:D".into(), "DAV:".into()),
            ("xmlns:C".into(), "urn:ietf:params:xml:ns:caldav".into()),
        ];
        let mut qwriter = dxml::Writer { q, ns_to_apply };
        let decl =
            quick_xml::events::BytesDecl::from_start(quick_xml::events::BytesStart::from_content(
                "xml version=\"1.0\" encoding=\"utf-8\"",
                0,
            ));
        match qwriter
            .q
            .write_event_async(quick_xml::events::Event::Decl(decl))
            .await
        {
            Ok(_) => (),
            Err(e) => tracing::error!(err=?e, "unable to write XML declaration <?xml ... >"),
        }
        match elem.qwrite(&mut qwriter).await {
            Ok(_) => tracing::debug!("fully serialized object"),
            Err(e) => tracing::error!(err=?e, "failed to serialize object"),
        }
    });

    // Build the reader
    let recv = tokio_stream::wrappers::ReceiverStream::new(rx);
    let stream = StreamBody::new(recv.map(|v| Ok(Frame::data(v))));
    let boxed_body = UnsyncBoxBody::new(stream);

    let response = Response::builder()
        .status(status_ok)
        .header("content-type", "application/xml; charset=\"utf-8\"")
        .body(boxed_body)?;

    Ok(response)
}

/// Deserialize a request body to an XML request
pub(crate) async fn deserialize<T: dxml::Node<T>>(req: Request<Incoming>) -> Result<T> {
    let stream_of_frames = BodyStream::new(req.into_body());
    let stream_of_bytes = stream_of_frames
        .map_ok(|frame| frame.into_data())
        .map(|obj| match obj {
            Ok(Ok(v)) => Ok(v),
            Ok(Err(_)) => Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "conversion error",
            )),
            Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
        });
    let async_read = tokio_util::io::StreamReader::new(stream_of_bytes);
    let async_read = std::pin::pin!(async_read);
    let mut rdr = dxml::Reader::new(quick_xml::reader::NsReader::from_reader(async_read)).await?;
    let parsed = rdr.find::<T>().await?;
    Ok(parsed)
}