diff options
Diffstat (limited to 'aero-proto/src')
28 files changed, 7095 insertions, 0 deletions
diff --git a/aero-proto/src/dav/codec.rs b/aero-proto/src/dav/codec.rs new file mode 100644 index 0000000..a441e7e --- /dev/null +++ b/aero-proto/src/dav/codec.rs @@ -0,0 +1,135 @@ +use anyhow::{bail, Result}; +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use futures::stream::TryStreamExt; +use http_body_util::combinators::UnsyncBoxBody; +use http_body_util::BodyExt; +use http_body_util::BodyStream; +use http_body_util::Full; +use http_body_util::StreamBody; +use hyper::body::Frame; +use hyper::body::Incoming; +use hyper::{body::Bytes, Request, Response}; +use std::io::{Error, ErrorKind}; +use tokio_util::io::{CopyToBytes, SinkWriter}; +use tokio_util::sync::PollSender; + +use super::controller::HttpResponse; +use super::node::PutPolicy; +use aero_dav::types as dav; +use aero_dav::xml as dxml; + +pub(crate) fn depth(req: &Request<impl hyper::body::Body>) -> dav::Depth { + match req + .headers() + .get("Depth") + .map(hyper::header::HeaderValue::to_str) + { + Some(Ok("0")) => dav::Depth::Zero, + Some(Ok("1")) => dav::Depth::One, + Some(Ok("Infinity")) => dav::Depth::Infinity, + _ => dav::Depth::Zero, + } +} + +pub(crate) fn put_policy(req: &Request<impl hyper::body::Body>) -> Result<PutPolicy> { + if let Some(maybe_txt_etag) = req + .headers() + .get("If-Match") + .map(hyper::header::HeaderValue::to_str) + { + let etag = maybe_txt_etag?; + let dquote_count = etag.chars().filter(|c| *c == '"').count(); + if dquote_count != 2 { + bail!("Either If-Match value is invalid or it's not supported (only single etag is supported)"); + } + + return Ok(PutPolicy::ReplaceEtag(etag.into())); + } + + if let Some(maybe_txt_etag) = req + .headers() + .get("If-None-Match") + .map(hyper::header::HeaderValue::to_str) + { + let etag = maybe_txt_etag?; + if etag == "*" { + return Ok(PutPolicy::CreateOnly); + } + bail!("Either If-None-Match value is invalid or it's not supported (only asterisk is supported)") + } + + Ok(PutPolicy::OverwriteAll) +} + +pub(crate) fn text_body(txt: &'static str) -> UnsyncBoxBody<Bytes, std::io::Error> { + UnsyncBoxBody::new(Full::new(Bytes::from(txt)).map_err(|e| match e {})) +} + +pub(crate) fn serialize<T: dxml::QWrite + Send + 'static>( + status_ok: hyper::StatusCode, + elem: T, +) -> Result<HttpResponse> { + let (tx, rx) = tokio::sync::mpsc::channel::<Bytes>(1); + + // Build the writer + tokio::task::spawn(async move { + let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe)); + let mut writer = SinkWriter::new(CopyToBytes::new(sink)); + let q = quick_xml::writer::Writer::new_with_indent(&mut writer, b' ', 4); + let ns_to_apply = vec![ + ("xmlns:D".into(), "DAV:".into()), + ("xmlns:C".into(), "urn:ietf:params:xml:ns:caldav".into()), + ]; + let mut qwriter = dxml::Writer { q, ns_to_apply }; + let decl = + quick_xml::events::BytesDecl::from_start(quick_xml::events::BytesStart::from_content( + "xml version=\"1.0\" encoding=\"utf-8\"", + 0, + )); + match qwriter + .q + .write_event_async(quick_xml::events::Event::Decl(decl)) + .await + { + Ok(_) => (), + Err(e) => tracing::error!(err=?e, "unable to write XML declaration <?xml ... >"), + } + match elem.qwrite(&mut qwriter).await { + Ok(_) => tracing::debug!("fully serialized object"), + Err(e) => tracing::error!(err=?e, "failed to serialize object"), + } + }); + + // Build the reader + let recv = tokio_stream::wrappers::ReceiverStream::new(rx); + let stream = StreamBody::new(recv.map(|v| Ok(Frame::data(v)))); + let boxed_body = UnsyncBoxBody::new(stream); + + let response = Response::builder() + .status(status_ok) + .header("content-type", "application/xml; charset=\"utf-8\"") + .body(boxed_body)?; + + Ok(response) +} + +/// Deserialize a request body to an XML request +pub(crate) async fn deserialize<T: dxml::Node<T>>(req: Request<Incoming>) -> Result<T> { + let stream_of_frames = BodyStream::new(req.into_body()); + let stream_of_bytes = stream_of_frames + .map_ok(|frame| frame.into_data()) + .map(|obj| match obj { + Ok(Ok(v)) => Ok(v), + Ok(Err(_)) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "conversion error", + )), + Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), + }); + let async_read = tokio_util::io::StreamReader::new(stream_of_bytes); + let async_read = std::pin::pin!(async_read); + let mut rdr = dxml::Reader::new(quick_xml::reader::NsReader::from_reader(async_read)).await?; + let parsed = rdr.find::<T>().await?; + Ok(parsed) +} diff --git a/aero-proto/src/dav/controller.rs b/aero-proto/src/dav/controller.rs new file mode 100644 index 0000000..8c53c6b --- /dev/null +++ b/aero-proto/src/dav/controller.rs @@ -0,0 +1,436 @@ +use anyhow::Result; +use futures::stream::{StreamExt, TryStreamExt}; +use http_body_util::combinators::UnsyncBoxBody; +use http_body_util::BodyStream; +use http_body_util::StreamBody; +use hyper::body::Frame; +use hyper::body::Incoming; +use hyper::{body::Bytes, Request, Response}; + +use aero_collections::{davdag::Token, user::User}; +use aero_dav::caltypes as cal; +use aero_dav::realization::{self, All}; +use aero_dav::synctypes as sync; +use aero_dav::types as dav; +use aero_dav::versioningtypes as vers; +use aero_ical::query::is_component_match; + +use crate::dav::codec; +use crate::dav::codec::{depth, deserialize, serialize, text_body}; +use crate::dav::node::DavNode; +use crate::dav::resource::{RootNode, BASE_TOKEN_URI}; + +pub(super) type ArcUser = std::sync::Arc<User>; +pub(super) type HttpResponse = Response<UnsyncBoxBody<Bytes, std::io::Error>>; + +const ALLPROP: [dav::PropertyRequest<All>; 10] = [ + dav::PropertyRequest::CreationDate, + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::GetContentLanguage, + dav::PropertyRequest::GetContentLength, + dav::PropertyRequest::GetContentType, + dav::PropertyRequest::GetEtag, + dav::PropertyRequest::GetLastModified, + dav::PropertyRequest::LockDiscovery, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::SupportedLock, +]; + +pub(crate) struct Controller { + node: Box<dyn DavNode>, + user: std::sync::Arc<User>, + req: Request<Incoming>, +} +impl Controller { + pub(crate) async fn route( + user: std::sync::Arc<User>, + req: Request<Incoming>, + ) -> Result<HttpResponse> { + let path = req.uri().path().to_string(); + let path_segments: Vec<_> = path.split("/").filter(|s| *s != "").collect(); + let method = req.method().as_str().to_uppercase(); + + let can_create = matches!(method.as_str(), "PUT" | "MKCOL" | "MKCALENDAR"); + let node = match (RootNode {}).fetch(&user, &path_segments, can_create).await { + Ok(v) => v, + Err(e) => { + tracing::warn!(err=?e, "dav node fetch failed"); + return Ok(Response::builder() + .status(404) + .body(codec::text_body("Resource not found"))?); + } + }; + + let dav_hdrs = node.dav_header(); + let ctrl = Self { node, user, req }; + + match method.as_str() { + "OPTIONS" => Ok(Response::builder() + .status(200) + .header("DAV", dav_hdrs) + .header("Allow", "HEAD,GET,PUT,OPTIONS,DELETE,PROPFIND,PROPPATCH,MKCOL,COPY,MOVE,LOCK,UNLOCK,MKCALENDAR,REPORT") + .body(codec::text_body(""))?), + "HEAD" => { + tracing::warn!("HEAD might not correctly implemented: should return ETags & co"); + Ok(Response::builder() + .status(200) + .body(codec::text_body(""))?) + }, + "GET" => ctrl.get().await, + "PUT" => ctrl.put().await, + "DELETE" => ctrl.delete().await, + "PROPFIND" => ctrl.propfind().await, + "REPORT" => ctrl.report().await, + _ => Ok(Response::builder() + .status(501) + .body(codec::text_body("HTTP Method not implemented"))?), + } + } + + // --- Per-method functions --- + + /// REPORT has been first described in the "Versioning Extension" of WebDAV + /// It allows more complex queries compared to PROPFIND + /// + /// Note: current implementation is not generic at all, it is heavily tied to CalDAV. + /// A rewrite would be required to make it more generic (with the extension system that has + /// been introduced in aero-dav) + async fn report(self) -> Result<HttpResponse> { + let status = hyper::StatusCode::from_u16(207)?; + + let cal_report = match deserialize::<vers::Report<All>>(self.req).await { + Ok(v) => v, + Err(e) => { + tracing::error!(err=?e, "unable to decode REPORT body"); + return Ok(Response::builder() + .status(400) + .body(text_body("Bad request"))?); + } + }; + + // Internal representation that will handle processed request + let (mut ok_node, mut not_found) = (Vec::new(), Vec::new()); + let calprop: Option<cal::CalendarSelector<All>>; + let extension: Option<realization::Multistatus>; + + // Extracting request information + match cal_report { + vers::Report::Extension(realization::ReportType::Cal(cal::ReportType::Multiget(m))) => { + // Multiget is really like a propfind where Depth: 0|1|Infinity is replaced by an arbitrary + // list of URLs + // Getting the list of nodes + for h in m.href.into_iter() { + let maybe_collected_node = match Path::new(h.0.as_str()) { + Ok(Path::Abs(p)) => RootNode {} + .fetch(&self.user, p.as_slice(), false) + .await + .or(Err(h)), + Ok(Path::Rel(p)) => self + .node + .fetch(&self.user, p.as_slice(), false) + .await + .or(Err(h)), + Err(_) => Err(h), + }; + + match maybe_collected_node { + Ok(v) => ok_node.push(v), + Err(h) => not_found.push(h), + }; + } + calprop = m.selector; + extension = None; + } + vers::Report::Extension(realization::ReportType::Cal(cal::ReportType::Query(q))) => { + calprop = q.selector; + extension = None; + ok_node = apply_filter(self.node.children(&self.user).await, &q.filter) + .try_collect() + .await?; + } + vers::Report::Extension(realization::ReportType::Sync(sync_col)) => { + calprop = Some(cal::CalendarSelector::Prop(sync_col.prop)); + + if sync_col.limit.is_some() { + tracing::warn!("limit is not supported, ignoring"); + } + if matches!(sync_col.sync_level, sync::SyncLevel::Infinite) { + tracing::debug!("aerogramme calendar collections are not nested"); + } + + let token = match sync_col.sync_token { + sync::SyncTokenRequest::InitialSync => None, + sync::SyncTokenRequest::IncrementalSync(token_raw) => { + // parse token + if token_raw.len() != BASE_TOKEN_URI.len() + 48 { + anyhow::bail!("invalid token length") + } + let token = token_raw[BASE_TOKEN_URI.len()..] + .parse() + .or(Err(anyhow::anyhow!("can't parse token")))?; + Some(token) + } + }; + // do the diff + let new_token: Token; + (new_token, ok_node, not_found) = match self.node.diff(token).await { + Ok(t) => t, + Err(e) => match e.kind() { + std::io::ErrorKind::NotFound => return Ok(Response::builder() + .status(410) + .body(text_body("Diff failed, token might be expired"))?), + _ => return Ok(Response::builder() + .status(500) + .body(text_body("Server error, maybe this operation is not supported on this collection"))?), + }, + }; + extension = Some(realization::Multistatus::Sync(sync::Multistatus { + sync_token: sync::SyncToken(format!("{}{}", BASE_TOKEN_URI, new_token)), + })); + } + _ => { + return Ok(Response::builder() + .status(501) + .body(text_body("Not implemented"))?) + } + }; + + // Getting props + let props = match calprop { + None | Some(cal::CalendarSelector::AllProp) => Some(dav::PropName(ALLPROP.to_vec())), + Some(cal::CalendarSelector::PropName) => None, + Some(cal::CalendarSelector::Prop(inner)) => Some(inner), + }; + + serialize( + status, + Self::multistatus(&self.user, ok_node, not_found, props, extension).await, + ) + } + + /// PROPFIND is the standard way to fetch WebDAV properties + async fn propfind(self) -> Result<HttpResponse> { + let depth = depth(&self.req); + if matches!(depth, dav::Depth::Infinity) { + return Ok(Response::builder() + .status(501) + .body(text_body("Depth: Infinity not implemented"))?); + } + + let status = hyper::StatusCode::from_u16(207)?; + + // A client may choose not to submit a request body. An empty PROPFIND + // request body MUST be treated as if it were an 'allprop' request. + // @FIXME here we handle any invalid data as an allprop, an empty request is thus correctly + // handled, but corrupted requests are also silently handled as allprop. + let propfind = deserialize::<dav::PropFind<All>>(self.req) + .await + .unwrap_or_else(|_| dav::PropFind::<All>::AllProp(None)); + tracing::debug!(recv=?propfind, "inferred propfind request"); + + // Collect nodes as PROPFIND is not limited to the targeted node + let mut nodes = vec![]; + if matches!(depth, dav::Depth::One | dav::Depth::Infinity) { + nodes.extend(self.node.children(&self.user).await); + } + nodes.push(self.node); + + // Expand properties request + let propname = match propfind { + dav::PropFind::PropName => None, + dav::PropFind::AllProp(None) => Some(dav::PropName(ALLPROP.to_vec())), + dav::PropFind::AllProp(Some(dav::Include(mut include))) => { + include.extend_from_slice(&ALLPROP); + Some(dav::PropName(include)) + } + dav::PropFind::Prop(inner) => Some(inner), + }; + + // Not Found is currently impossible considering the way we designed this function + let not_found = vec![]; + serialize( + status, + Self::multistatus(&self.user, nodes, not_found, propname, None).await, + ) + } + + async fn put(self) -> Result<HttpResponse> { + let put_policy = codec::put_policy(&self.req)?; + + let stream_of_frames = BodyStream::new(self.req.into_body()); + let stream_of_bytes = stream_of_frames + .map_ok(|frame| frame.into_data()) + .map(|obj| match obj { + Ok(Ok(v)) => Ok(v), + Ok(Err(_)) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "conversion error", + )), + Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), + }) + .boxed(); + + let etag = match self.node.put(put_policy, stream_of_bytes).await { + Ok(etag) => etag, + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + tracing::warn!("put pre-condition failed"); + let response = Response::builder().status(412).body(text_body(""))?; + return Ok(response); + } + Err(e) => Err(e)?, + }; + + let response = Response::builder() + .status(201) + .header("ETag", etag) + //.header("content-type", "application/xml; charset=\"utf-8\"") + .body(text_body(""))?; + + Ok(response) + } + + async fn get(self) -> Result<HttpResponse> { + let stream_body = StreamBody::new(self.node.content().map_ok(|v| Frame::data(v))); + let boxed_body = UnsyncBoxBody::new(stream_body); + + let mut builder = Response::builder().status(200); + builder = builder.header("content-type", self.node.content_type()); + if let Some(etag) = self.node.etag().await { + builder = builder.header("etag", etag); + } + let response = builder.body(boxed_body)?; + + Ok(response) + } + + async fn delete(self) -> Result<HttpResponse> { + self.node.delete().await?; + let response = Response::builder() + .status(204) + //.header("content-type", "application/xml; charset=\"utf-8\"") + .body(text_body(""))?; + Ok(response) + } + + // --- Common utility functions --- + /// Build a multistatus response from a list of DavNodes + async fn multistatus( + user: &ArcUser, + nodes: Vec<Box<dyn DavNode>>, + not_found: Vec<dav::Href>, + props: Option<dav::PropName<All>>, + extension: Option<realization::Multistatus>, + ) -> dav::Multistatus<All> { + // Collect properties on existing objects + let mut responses: Vec<dav::Response<All>> = match props { + Some(props) => { + futures::stream::iter(nodes) + .then(|n| n.response_props(user, props.clone())) + .collect() + .await + } + None => nodes + .into_iter() + .map(|n| n.response_propname(user)) + .collect(), + }; + + // Register not found objects only if relevant + if !not_found.is_empty() { + responses.push(dav::Response { + status_or_propstat: dav::StatusOrPropstat::Status( + not_found, + dav::Status(hyper::StatusCode::NOT_FOUND), + ), + error: None, + location: None, + responsedescription: None, + }); + } + + // Build response + let multistatus = dav::Multistatus::<All> { + responses, + responsedescription: None, + extension, + }; + + tracing::debug!(multistatus=?multistatus, "multistatus response"); + multistatus + } +} + +/// Path is a voluntarily feature limited +/// compared to the expressiveness of a UNIX path +/// For example getting parent with ../ is not supported, scheme is not supported, etc. +/// More complex support could be added later if needed by clients +enum Path<'a> { + Abs(Vec<&'a str>), + Rel(Vec<&'a str>), +} +impl<'a> Path<'a> { + fn new(path: &'a str) -> Result<Self> { + // This check is naive, it does not aim at detecting all fully qualified + // URL or protect from any attack, its only goal is to help debugging. + if path.starts_with("http://") || path.starts_with("https://") { + anyhow::bail!("Full URL are not supported") + } + + let path_segments: Vec<_> = path.split("/").filter(|s| *s != "" && *s != ".").collect(); + if path.starts_with("/") { + return Ok(Path::Abs(path_segments)); + } + Ok(Path::Rel(path_segments)) + } +} + +//@FIXME naive implementation, must be refactored later +use futures::stream::Stream; +fn apply_filter<'a>( + nodes: Vec<Box<dyn DavNode>>, + filter: &'a cal::Filter, +) -> impl Stream<Item = std::result::Result<Box<dyn DavNode>, std::io::Error>> + 'a { + futures::stream::iter(nodes).filter_map(move |single_node| async move { + // Get ICS + let chunks: Vec<_> = match single_node.content().try_collect().await { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + let raw_ics = chunks.iter().fold(String::new(), |mut acc, single_chunk| { + let str_fragment = std::str::from_utf8(single_chunk.as_ref()); + acc.extend(str_fragment); + acc + }); + + // Parse ICS + let ics = match icalendar::parser::read_calendar(&raw_ics) { + Ok(v) => v, + Err(e) => { + tracing::warn!(err=?e, "Unable to parse ICS in calendar-query"); + return Some(Err(std::io::Error::from(std::io::ErrorKind::InvalidData))); + } + }; + + // Do checks + // @FIXME: icalendar does not consider VCALENDAR as a component + // but WebDAV does... + // Build a fake VCALENDAR component for icalendar compatibility, it's a hack + let root_filter = &filter.0; + let fake_vcal_component = icalendar::parser::Component { + name: cal::Component::VCalendar.as_str().into(), + properties: ics.properties, + components: ics.components, + }; + tracing::debug!(filter=?root_filter, "calendar-query filter"); + + // Adjust return value according to filter + match is_component_match( + &fake_vcal_component, + &[fake_vcal_component.clone()], + root_filter, + ) { + true => Some(Ok(single_node)), + _ => None, + } + }) +} diff --git a/aero-proto/src/dav/middleware.rs b/aero-proto/src/dav/middleware.rs new file mode 100644 index 0000000..8964699 --- /dev/null +++ b/aero-proto/src/dav/middleware.rs @@ -0,0 +1,70 @@ +use anyhow::{anyhow, Result}; +use base64::Engine; +use hyper::body::Incoming; +use hyper::{Request, Response}; + +use aero_collections::user::User; +use aero_user::login::ArcLoginProvider; + +use super::codec::text_body; +use super::controller::HttpResponse; + +type ArcUser = std::sync::Arc<User>; + +pub(super) async fn auth<'a>( + login: ArcLoginProvider, + req: Request<Incoming>, + next: impl Fn(ArcUser, Request<Incoming>) -> futures::future::BoxFuture<'a, Result<HttpResponse>>, +) -> Result<HttpResponse> { + let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) { + Some(hv) => hv.to_str()?, + None => { + tracing::info!("Missing authorization field"); + return Ok(Response::builder() + .status(401) + .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"") + .body(text_body("Missing Authorization field"))?); + } + }; + + let b64_creds_maybe_padded = match auth_val.split_once(" ") { + Some(("Basic", b64)) => b64, + _ => { + tracing::info!("Unsupported authorization field"); + return Ok(Response::builder() + .status(400) + .body(text_body("Unsupported Authorization field"))?); + } + }; + + // base64urlencoded may have trailing equals, base64urlsafe has not + // theoretically authorization is padded but "be liberal in what you accept" + let b64_creds_clean = b64_creds_maybe_padded.trim_end_matches('='); + + // Decode base64 + let creds = base64::engine::general_purpose::STANDARD_NO_PAD.decode(b64_creds_clean)?; + let str_creds = std::str::from_utf8(&creds)?; + + // Split username and password + let (username, password) = str_creds.split_once(':').ok_or(anyhow!( + "Missing colon in Authorization, can't split decoded value into a username/password pair" + ))?; + + // Call login provider + let creds = match login.login(username, password).await { + Ok(c) => c, + Err(_) => { + tracing::info!(user = username, "Wrong credentials"); + return Ok(Response::builder() + .status(401) + .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"") + .body(text_body("Wrong credentials"))?); + } + }; + + // Build a user + let user = User::new(username.into(), creds).await?; + + // Call router with user + next(user, req).await +} diff --git a/aero-proto/src/dav/mod.rs b/aero-proto/src/dav/mod.rs new file mode 100644 index 0000000..a3dd58d --- /dev/null +++ b/aero-proto/src/dav/mod.rs @@ -0,0 +1,195 @@ +mod codec; +mod controller; +mod middleware; +mod node; +mod resource; + +use std::net::SocketAddr; +use std::sync::Arc; + +use anyhow::Result; +use futures::future::FutureExt; +use futures::stream::{FuturesUnordered, StreamExt}; +use hyper::rt::{Read, Write}; +use hyper::server::conn::http1 as http; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use rustls_pemfile::{certs, private_key}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::watch; +use tokio_rustls::TlsAcceptor; + +use aero_user::config::{DavConfig, DavUnsecureConfig}; +use aero_user::login::ArcLoginProvider; + +use crate::dav::controller::Controller; + +pub struct Server { + bind_addr: SocketAddr, + login_provider: ArcLoginProvider, + tls: Option<TlsAcceptor>, +} + +pub fn new_unsecure(config: DavUnsecureConfig, login: ArcLoginProvider) -> Server { + Server { + bind_addr: config.bind_addr, + login_provider: login, + tls: None, + } +} + +pub fn new(config: DavConfig, login: ArcLoginProvider) -> Result<Server> { + let loaded_certs = certs(&mut std::io::BufReader::new(std::fs::File::open( + config.certs, + )?)) + .collect::<Result<Vec<_>, _>>()?; + let loaded_key = private_key(&mut std::io::BufReader::new(std::fs::File::open( + config.key, + )?))? + .unwrap(); + + let tls_config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(loaded_certs, loaded_key)?; + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); + + Ok(Server { + bind_addr: config.bind_addr, + login_provider: login, + tls: Some(acceptor), + }) +} + +trait Stream: Read + Write + Send + Unpin {} +impl<T: Unpin + AsyncRead + AsyncWrite + Send> Stream for TokioIo<T> {} + +impl Server { + pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> { + let tcp = TcpListener::bind(self.bind_addr).await?; + tracing::info!("DAV server listening on {:#}", self.bind_addr); + + let mut connections = FuturesUnordered::new(); + while !*must_exit.borrow() { + let wait_conn_finished = async { + if connections.is_empty() { + futures::future::pending().await + } else { + connections.next().await + } + }; + let (socket, remote_addr) = tokio::select! { + a = tcp.accept() => a?, + _ = wait_conn_finished => continue, + _ = must_exit.changed() => continue, + }; + tracing::info!("Accepted connection from {}", remote_addr); + let stream = match self.build_stream(socket).await { + Ok(v) => v, + Err(e) => { + tracing::error!(err=?e, "TLS acceptor failed"); + continue; + } + }; + + let login = self.login_provider.clone(); + let conn = tokio::spawn(async move { + //@FIXME should create a generic "public web" server on which "routers" could be + //abitrarily bound + //@FIXME replace with a handler supporting http2 + + match http::Builder::new() + .serve_connection( + stream, + service_fn(|req: Request<hyper::body::Incoming>| { + let login = login.clone(); + tracing::info!("{:?} {:?}", req.method(), req.uri()); + tracing::debug!(req=?req, "full request"); + async { + let response = match middleware::auth(login, req, |user, request| { + async { Controller::route(user, request).await }.boxed() + }) + .await + { + Ok(v) => Ok(v), + Err(e) => { + tracing::error!(err=?e, "internal error"); + Response::builder() + .status(500) + .body(codec::text_body("Internal error")) + } + }; + tracing::debug!(resp=?response, "full response"); + response + } + }), + ) + .await + { + Err(e) => tracing::warn!(err=?e, "connection failed"), + Ok(()) => tracing::trace!("connection terminated with success"), + } + }); + connections.push(conn); + } + drop(tcp); + + tracing::info!("Server shutting down, draining remaining connections..."); + while connections.next().await.is_some() {} + + Ok(()) + } + + async fn build_stream(&self, socket: TcpStream) -> Result<Box<dyn Stream>> { + match self.tls.clone() { + Some(acceptor) => { + let stream = acceptor.accept(socket).await?; + Ok(Box::new(TokioIo::new(stream))) + } + None => Ok(Box::new(TokioIo::new(socket))), + } + } +} + +// <D:propfind xmlns:D='DAV:' xmlns:A='http://apple.com/ns/ical/'> +// <D:prop> +// <D:getcontenttype/> +// <D:resourcetype/> +// <D:displayname/> +// <A:calendar-color/> +// </D:prop> +// </D:propfind> + +// <D:propfind xmlns:D='DAV:' xmlns:A='http://apple.com/ns/ical/' xmlns:C='urn:ietf:params:xml:ns:caldav'> +// <D:prop> +// <D:resourcetype/> +// <D:owner/> +// <D:displayname/> +// <D:current-user-principal/> +// <D:current-user-privilege-set/> +// <A:calendar-color/> +// <C:calendar-home-set/> +// </D:prop> +// </D:propfind> + +// <D:propfind xmlns:D='DAV:' xmlns:C='urn:ietf:params:xml:ns:caldav' xmlns:CS='http://calendarserver.org/ns/'> +// <D:prop> +// <D:resourcetype/> +// <D:owner/> +// <D:current-user-principal/> +// <D:current-user-privilege-set/> +// <D:supported-report-set/> +// <C:supported-calendar-component-set/> +// <CS:getctag/> +// </D:prop> +// </D:propfind> + +// <C:calendar-multiget xmlns:D="DAV:" xmlns:C="urn:ietf:params:xml:ns:caldav"> +// <D:prop> +// <D:getetag/> +// <C:calendar-data/> +// </D:prop> +// <D:href>/alice/calendar/personal/something.ics</D:href> +// </C:calendar-multiget> diff --git a/aero-proto/src/dav/node.rs b/aero-proto/src/dav/node.rs new file mode 100644 index 0000000..3af3b81 --- /dev/null +++ b/aero-proto/src/dav/node.rs @@ -0,0 +1,145 @@ +use anyhow::Result; +use futures::future::{BoxFuture, FutureExt}; +use futures::stream::{BoxStream, StreamExt}; +use hyper::body::Bytes; + +use aero_collections::davdag::{Etag, Token}; +use aero_dav::realization::All; +use aero_dav::types as dav; + +use super::controller::ArcUser; + +pub(crate) type Content<'a> = BoxStream<'a, std::result::Result<Bytes, std::io::Error>>; +pub(crate) type PropertyStream<'a> = + BoxStream<'a, std::result::Result<dav::Property<All>, dav::PropertyRequest<All>>>; + +pub(crate) enum PutPolicy { + OverwriteAll, + CreateOnly, + ReplaceEtag(String), +} + +/// A DAV node should implement the following methods +/// @FIXME not satisfied by BoxFutures but I have no better idea currently +pub(crate) trait DavNode: Send { + // recurence, filesystem hierarchy + /// This node direct children + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>>; + /// Recursively fetch a child (progress inside the filesystem hierarchy) + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>>; + + // node properties + /// Get the path + fn path(&self, user: &ArcUser) -> String; + /// Get the supported WebDAV properties + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All>; + /// Get the values for the given properties + fn properties(&self, user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static>; + /// Get the value of the DAV header to return + fn dav_header(&self) -> String; + + /// Put an element (create or update) + fn put<'a>( + &'a self, + policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>>; + /// Content type of the element + fn content_type(&self) -> &str; + /// Get ETag + fn etag(&self) -> BoxFuture<Option<Etag>>; + /// Get content + fn content<'a>(&self) -> Content<'a>; + /// Delete + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>>; + /// Sync + fn diff<'a>( + &self, + sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + >; + + /// Utility function to get a propname response from a node + fn response_propname(&self, user: &ArcUser) -> dav::Response<All> { + dav::Response { + status_or_propstat: dav::StatusOrPropstat::PropStat( + dav::Href(self.path(user)), + vec![dav::PropStat { + status: dav::Status(hyper::StatusCode::OK), + prop: dav::AnyProp( + self.supported_properties(user) + .0 + .into_iter() + .map(dav::AnyProperty::Request) + .collect(), + ), + error: None, + responsedescription: None, + }], + ), + error: None, + location: None, + responsedescription: None, + } + } + + /// Utility function to get a prop response from a node & a list of propname + fn response_props( + &self, + user: &ArcUser, + props: dav::PropName<All>, + ) -> BoxFuture<'static, dav::Response<All>> { + //@FIXME we should make the DAV parsed object a stream... + let mut result_stream = self.properties(user, props); + let path = self.path(user); + + async move { + let mut prop_desc = vec![]; + let (mut found, mut not_found) = (vec![], vec![]); + while let Some(maybe_prop) = result_stream.next().await { + match maybe_prop { + Ok(v) => found.push(dav::AnyProperty::Value(v)), + Err(v) => not_found.push(dav::AnyProperty::Request(v)), + } + } + + // If at least one property has been found on this object, adding a HTTP 200 propstat to + // the response + if !found.is_empty() { + prop_desc.push(dav::PropStat { + status: dav::Status(hyper::StatusCode::OK), + prop: dav::AnyProp(found), + error: None, + responsedescription: None, + }); + } + + // If at least one property can't be found on this object, adding a HTTP 404 propstat to + // the response + if !not_found.is_empty() { + prop_desc.push(dav::PropStat { + status: dav::Status(hyper::StatusCode::NOT_FOUND), + prop: dav::AnyProp(not_found), + error: None, + responsedescription: None, + }) + } + + // Build the finale response + dav::Response { + status_or_propstat: dav::StatusOrPropstat::PropStat(dav::Href(path), prop_desc), + error: None, + location: None, + responsedescription: None, + } + } + .boxed() + } +} diff --git a/aero-proto/src/dav/resource.rs b/aero-proto/src/dav/resource.rs new file mode 100644 index 0000000..b5ae029 --- /dev/null +++ b/aero-proto/src/dav/resource.rs @@ -0,0 +1,999 @@ +use std::sync::Arc; +type ArcUser = std::sync::Arc<User>; + +use anyhow::{anyhow, Result}; +use futures::io::AsyncReadExt; +use futures::stream::{StreamExt, TryStreamExt}; +use futures::{future::BoxFuture, future::FutureExt}; + +use aero_collections::{ + calendar::Calendar, + davdag::{BlobId, Etag, SyncChange, Token}, + user::User, +}; +use aero_dav::acltypes as acl; +use aero_dav::caltypes as cal; +use aero_dav::realization::{self as all, All}; +use aero_dav::synctypes as sync; +use aero_dav::types as dav; +use aero_dav::versioningtypes as vers; + +use super::node::PropertyStream; +use crate::dav::node::{Content, DavNode, PutPolicy}; + +/// Why "https://aerogramme.0"? +/// Because tokens must be valid URI. +/// And numeric TLD are ~mostly valid in URI (check the .42 TLD experience) +/// and at the same time, they are not used sold by the ICANN and there is no plan to use them. +/// So I am sure that the URL remains invalid, avoiding leaking requests to an hardcoded URL in the +/// future. +/// The best option would be to make it configurable ofc, so someone can put a domain name +/// that they control, it would probably improve compatibility (maybe some WebDAV spec tells us +/// how to handle/resolve this URI but I am not aware of that...). But that's not the plan for +/// now. So here we are: https://aerogramme.0. +pub const BASE_TOKEN_URI: &str = "https://aerogramme.0/sync/"; + +#[derive(Clone)] +pub(crate) struct RootNode {} +impl DavNode for RootNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let this = self.clone(); + return async { Ok(Box::new(this) as Box<dyn DavNode>) }.boxed(); + } + + if path[0] == user.username { + let child = Box::new(HomeNode {}); + return child.fetch(user, &path[1..], create); + } + + //@NOTE: We can't create a node at this level + async { Err(anyhow!("Not found")) }.boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + async { vec![Box::new(HomeNode {}) as Box<dyn DavNode>] }.boxed() + } + + fn path(&self, user: &ArcUser) -> String { + "/".into() + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![ + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::GetContentType, + dav::PropertyRequest::Extension(all::PropertyRequest::Acl( + acl::PropertyRequest::CurrentUserPrincipal, + )), + ]) + } + + fn properties(&self, user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + let user = user.clone(); + futures::stream::iter(prop.0) + .map(move |n| { + let prop = match n { + dav::PropertyRequest::DisplayName => { + dav::Property::DisplayName("DAV Root".to_string()) + } + dav::PropertyRequest::ResourceType => { + dav::Property::ResourceType(vec![dav::ResourceType::Collection]) + } + dav::PropertyRequest::GetContentType => { + dav::Property::GetContentType("httpd/unix-directory".into()) + } + dav::PropertyRequest::Extension(all::PropertyRequest::Acl( + acl::PropertyRequest::CurrentUserPrincipal, + )) => dav::Property::Extension(all::Property::Acl( + acl::Property::CurrentUserPrincipal(acl::User::Authenticated(dav::Href( + HomeNode {}.path(&user), + ))), + )), + v => return Err(v), + }; + Ok(prop) + }) + .boxed() + } + + fn put<'a>( + &'a self, + _policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + futures::future::err(std::io::Error::from(std::io::ErrorKind::Unsupported)).boxed() + } + + fn content<'a>(&self) -> Content<'a> { + futures::stream::once(futures::future::err(std::io::Error::from( + std::io::ErrorKind::Unsupported, + ))) + .boxed() + } + + fn content_type(&self) -> &str { + "text/plain" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + async { None }.boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + async { Err(std::io::Error::from(std::io::ErrorKind::PermissionDenied)) }.boxed() + } + + fn diff<'a>( + &self, + _sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + async { Err(std::io::Error::from(std::io::ErrorKind::Unsupported)) }.boxed() + } + + fn dav_header(&self) -> String { + "1".into() + } +} + +#[derive(Clone)] +pub(crate) struct HomeNode {} +impl DavNode for HomeNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let node = Box::new(self.clone()) as Box<dyn DavNode>; + return async { Ok(node) }.boxed(); + } + + if path[0] == "calendar" { + return async move { + let child = Box::new(CalendarListNode::new(user).await?); + child.fetch(user, &path[1..], create).await + } + .boxed(); + } + + //@NOTE: we can't create a node at this level + async { Err(anyhow!("Not found")) }.boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + async { + CalendarListNode::new(user) + .await + .map(|c| vec![Box::new(c) as Box<dyn DavNode>]) + .unwrap_or(vec![]) + } + .boxed() + } + + fn path(&self, user: &ArcUser) -> String { + format!("/{}/", user.username) + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![ + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::GetContentType, + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::CalendarHomeSet, + )), + ]) + } + fn properties(&self, user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + let user = user.clone(); + + futures::stream::iter(prop.0) + .map(move |n| { + let prop = match n { + dav::PropertyRequest::DisplayName => { + dav::Property::DisplayName(format!("{} home", user.username)) + } + dav::PropertyRequest::ResourceType => dav::Property::ResourceType(vec![ + dav::ResourceType::Collection, + dav::ResourceType::Extension(all::ResourceType::Acl( + acl::ResourceType::Principal, + )), + ]), + dav::PropertyRequest::GetContentType => { + dav::Property::GetContentType("httpd/unix-directory".into()) + } + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::CalendarHomeSet, + )) => dav::Property::Extension(all::Property::Cal( + cal::Property::CalendarHomeSet(dav::Href( + //@FIXME we are hardcoding the calendar path, instead we would want to use + //objects + format!("/{}/calendar/", user.username), + )), + )), + v => return Err(v), + }; + Ok(prop) + }) + .boxed() + } + + fn put<'a>( + &'a self, + _policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + futures::future::err(std::io::Error::from(std::io::ErrorKind::Unsupported)).boxed() + } + + fn content<'a>(&self) -> Content<'a> { + futures::stream::once(futures::future::err(std::io::Error::from( + std::io::ErrorKind::Unsupported, + ))) + .boxed() + } + + fn content_type(&self) -> &str { + "text/plain" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + async { None }.boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + async { Err(std::io::Error::from(std::io::ErrorKind::PermissionDenied)) }.boxed() + } + fn diff<'a>( + &self, + _sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + async { Err(std::io::Error::from(std::io::ErrorKind::Unsupported)) }.boxed() + } + + fn dav_header(&self) -> String { + "1, access-control, calendar-access".into() + } +} + +#[derive(Clone)] +pub(crate) struct CalendarListNode { + list: Vec<String>, +} +impl CalendarListNode { + async fn new(user: &ArcUser) -> Result<Self> { + let list = user.calendars.list(user).await?; + Ok(Self { list }) + } +} +impl DavNode for CalendarListNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let node = Box::new(self.clone()) as Box<dyn DavNode>; + return async { Ok(node) }.boxed(); + } + + async move { + //@FIXME: we should create a node if the open returns a "not found". + let cal = user + .calendars + .open(user, path[0]) + .await? + .ok_or(anyhow!("Not found"))?; + let child = Box::new(CalendarNode { + col: cal, + calname: path[0].to_string(), + }); + child.fetch(user, &path[1..], create).await + } + .boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + let list = self.list.clone(); + async move { + //@FIXME maybe we want to be lazy here?! + futures::stream::iter(list.iter()) + .filter_map(|name| async move { + user.calendars + .open(user, name) + .await + .ok() + .flatten() + .map(|v| (name, v)) + }) + .map(|(name, cal)| { + Box::new(CalendarNode { + col: cal, + calname: name.to_string(), + }) as Box<dyn DavNode> + }) + .collect::<Vec<Box<dyn DavNode>>>() + .await + } + .boxed() + } + + fn path(&self, user: &ArcUser) -> String { + format!("/{}/calendar/", user.username) + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![ + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::GetContentType, + ]) + } + fn properties(&self, user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + let user = user.clone(); + + futures::stream::iter(prop.0) + .map(move |n| { + let prop = match n { + dav::PropertyRequest::DisplayName => { + dav::Property::DisplayName(format!("{} calendars", user.username)) + } + dav::PropertyRequest::ResourceType => { + dav::Property::ResourceType(vec![dav::ResourceType::Collection]) + } + dav::PropertyRequest::GetContentType => { + dav::Property::GetContentType("httpd/unix-directory".into()) + } + v => return Err(v), + }; + Ok(prop) + }) + .boxed() + } + + fn put<'a>( + &'a self, + _policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + futures::future::err(std::io::Error::from(std::io::ErrorKind::Unsupported)).boxed() + } + + fn content<'a>(&self) -> Content<'a> { + futures::stream::once(futures::future::err(std::io::Error::from( + std::io::ErrorKind::Unsupported, + ))) + .boxed() + } + + fn content_type(&self) -> &str { + "text/plain" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + async { None }.boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + async { Err(std::io::Error::from(std::io::ErrorKind::PermissionDenied)) }.boxed() + } + fn diff<'a>( + &self, + _sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + async { Err(std::io::Error::from(std::io::ErrorKind::Unsupported)) }.boxed() + } + + fn dav_header(&self) -> String { + "1, access-control, calendar-access".into() + } +} + +#[derive(Clone)] +pub(crate) struct CalendarNode { + col: Arc<Calendar>, + calname: String, +} +impl DavNode for CalendarNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let node = Box::new(self.clone()) as Box<dyn DavNode>; + return async { Ok(node) }.boxed(); + } + + let col = self.col.clone(); + let calname = self.calname.clone(); + async move { + match (col.dag().await.idx_by_filename.get(path[0]), create) { + (Some(blob_id), _) => { + let child = Box::new(EventNode { + col: col.clone(), + calname, + filename: path[0].to_string(), + blob_id: *blob_id, + }); + child.fetch(user, &path[1..], create).await + } + (None, true) => { + let child = Box::new(CreateEventNode { + col: col.clone(), + calname, + filename: path[0].to_string(), + }); + child.fetch(user, &path[1..], create).await + } + _ => Err(anyhow!("Not found")), + } + } + .boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + let col = self.col.clone(); + let calname = self.calname.clone(); + + async move { + col.dag() + .await + .idx_by_filename + .iter() + .map(|(filename, blob_id)| { + Box::new(EventNode { + col: col.clone(), + calname: calname.clone(), + filename: filename.to_string(), + blob_id: *blob_id, + }) as Box<dyn DavNode> + }) + .collect() + } + .boxed() + } + + fn path(&self, user: &ArcUser) -> String { + format!("/{}/calendar/{}/", user.username, self.calname) + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![ + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::GetContentType, + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::SupportedCalendarComponentSet, + )), + dav::PropertyRequest::Extension(all::PropertyRequest::Sync( + sync::PropertyRequest::SyncToken, + )), + dav::PropertyRequest::Extension(all::PropertyRequest::Vers( + vers::PropertyRequest::SupportedReportSet, + )), + ]) + } + fn properties(&self, _user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + let calname = self.calname.to_string(); + let col = self.col.clone(); + + futures::stream::iter(prop.0) + .then(move |n| { + let calname = calname.clone(); + let col = col.clone(); + + async move { + let prop = match n { + dav::PropertyRequest::DisplayName => { + dav::Property::DisplayName(format!("{} calendar", calname)) + } + dav::PropertyRequest::ResourceType => dav::Property::ResourceType(vec![ + dav::ResourceType::Collection, + dav::ResourceType::Extension(all::ResourceType::Cal( + cal::ResourceType::Calendar, + )), + ]), + //dav::PropertyRequest::GetContentType => dav::AnyProperty::Value(dav::Property::GetContentType("httpd/unix-directory".into())), + //@FIXME seems wrong but seems to be what Thunderbird expects... + dav::PropertyRequest::GetContentType => { + dav::Property::GetContentType("text/calendar".into()) + } + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::SupportedCalendarComponentSet, + )) => dav::Property::Extension(all::Property::Cal( + cal::Property::SupportedCalendarComponentSet(vec![ + cal::CompSupport(cal::Component::VEvent), + cal::CompSupport(cal::Component::VTodo), + cal::CompSupport(cal::Component::VJournal), + ]), + )), + dav::PropertyRequest::Extension(all::PropertyRequest::Sync( + sync::PropertyRequest::SyncToken, + )) => match col.token().await { + Ok(token) => dav::Property::Extension(all::Property::Sync( + sync::Property::SyncToken(sync::SyncToken(format!( + "{}{}", + BASE_TOKEN_URI, token + ))), + )), + _ => return Err(n.clone()), + }, + dav::PropertyRequest::Extension(all::PropertyRequest::Vers( + vers::PropertyRequest::SupportedReportSet, + )) => dav::Property::Extension(all::Property::Vers( + vers::Property::SupportedReportSet(vec![ + vers::SupportedReport(vers::ReportName::Extension( + all::ReportTypeName::Cal(cal::ReportTypeName::Multiget), + )), + vers::SupportedReport(vers::ReportName::Extension( + all::ReportTypeName::Cal(cal::ReportTypeName::Query), + )), + vers::SupportedReport(vers::ReportName::Extension( + all::ReportTypeName::Sync(sync::ReportTypeName::SyncCollection), + )), + ]), + )), + v => return Err(v), + }; + Ok(prop) + } + }) + .boxed() + } + + fn put<'a>( + &'a self, + _policy: PutPolicy, + _stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + futures::future::err(std::io::Error::from(std::io::ErrorKind::Unsupported)).boxed() + } + + fn content<'a>(&self) -> Content<'a> { + futures::stream::once(futures::future::err(std::io::Error::from( + std::io::ErrorKind::Unsupported, + ))) + .boxed() + } + + fn content_type(&self) -> &str { + "text/plain" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + async { None }.boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + async { Err(std::io::Error::from(std::io::ErrorKind::PermissionDenied)) }.boxed() + } + fn diff<'a>( + &self, + sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + let col = self.col.clone(); + let calname = self.calname.clone(); + async move { + let sync_token = match sync_token { + Some(v) => v, + None => { + let token = col + .token() + .await + .or(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)))?; + let ok_nodes = col + .dag() + .await + .idx_by_filename + .iter() + .map(|(filename, blob_id)| { + Box::new(EventNode { + col: col.clone(), + calname: calname.clone(), + filename: filename.to_string(), + blob_id: *blob_id, + }) as Box<dyn DavNode> + }) + .collect(); + + return Ok((token, ok_nodes, vec![])); + } + }; + let (new_token, listed_changes) = match col.diff(sync_token).await { + Ok(v) => v, + Err(e) => { + tracing::info!(err=?e, "token resolution failed, maybe a forgotten token"); + return Err(std::io::Error::from(std::io::ErrorKind::NotFound)); + } + }; + + let mut ok_nodes: Vec<Box<dyn DavNode>> = vec![]; + let mut rm_nodes: Vec<dav::Href> = vec![]; + for change in listed_changes.into_iter() { + match change { + SyncChange::Ok((filename, blob_id)) => { + let child = Box::new(EventNode { + col: col.clone(), + calname: calname.clone(), + filename, + blob_id, + }); + ok_nodes.push(child); + } + SyncChange::NotFound(filename) => { + rm_nodes.push(dav::Href(filename)); + } + } + } + + Ok((new_token, ok_nodes, rm_nodes)) + } + .boxed() + } + fn dav_header(&self) -> String { + "1, access-control, calendar-access".into() + } +} + +#[derive(Clone)] +pub(crate) struct EventNode { + col: Arc<Calendar>, + calname: String, + filename: String, + blob_id: BlobId, +} + +impl DavNode for EventNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let node = Box::new(self.clone()) as Box<dyn DavNode>; + return async { Ok(node) }.boxed(); + } + + async { + Err(anyhow!( + "Not supported: can't create a child on an event node" + )) + } + .boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + async { vec![] }.boxed() + } + + fn path(&self, user: &ArcUser) -> String { + format!( + "/{}/calendar/{}/{}", + user.username, self.calname, self.filename + ) + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![ + dav::PropertyRequest::DisplayName, + dav::PropertyRequest::ResourceType, + dav::PropertyRequest::GetEtag, + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::CalendarData(cal::CalendarDataRequest::default()), + )), + ]) + } + fn properties(&self, _user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + let this = self.clone(); + + futures::stream::iter(prop.0) + .then(move |n| { + let this = this.clone(); + + async move { + let prop = match &n { + dav::PropertyRequest::DisplayName => { + dav::Property::DisplayName(format!("{} event", this.filename)) + } + dav::PropertyRequest::ResourceType => dav::Property::ResourceType(vec![]), + dav::PropertyRequest::GetContentType => { + dav::Property::GetContentType("text/calendar".into()) + } + dav::PropertyRequest::GetEtag => { + let etag = this.etag().await.ok_or(n.clone())?; + dav::Property::GetEtag(etag) + } + dav::PropertyRequest::Extension(all::PropertyRequest::Cal( + cal::PropertyRequest::CalendarData(req), + )) => { + let ics = String::from_utf8( + this.col.get(this.blob_id).await.or(Err(n.clone()))?, + ) + .or(Err(n.clone()))?; + + let new_ics = match &req.comp { + None => ics, + Some(prune_comp) => { + // parse content + let ics = match icalendar::parser::read_calendar(&ics) { + Ok(v) => v, + Err(e) => { + tracing::warn!(err=?e, "Unable to parse ICS in calendar-query"); + return Err(n.clone()) + } + }; + + // build a fake vcal component for caldav compat + let fake_vcal_component = icalendar::parser::Component { + name: cal::Component::VCalendar.as_str().into(), + properties: ics.properties, + components: ics.components, + }; + + // rebuild component + let new_comp = match aero_ical::prune::component(&fake_vcal_component, prune_comp) { + Some(v) => v, + None => return Err(n.clone()), + }; + + // reserialize + format!("{}", icalendar::parser::Calendar { properties: new_comp.properties, components: new_comp.components }) + }, + }; + + + + dav::Property::Extension(all::Property::Cal( + cal::Property::CalendarData(cal::CalendarDataPayload { + mime: None, + payload: new_ics, + }), + )) + } + _ => return Err(n), + }; + Ok(prop) + } + }) + .boxed() + } + + fn put<'a>( + &'a self, + policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + async { + let existing_etag = self + .etag() + .await + .ok_or(std::io::Error::new(std::io::ErrorKind::Other, "Etag error"))?; + match policy { + PutPolicy::CreateOnly => { + return Err(std::io::Error::from(std::io::ErrorKind::AlreadyExists)) + } + PutPolicy::ReplaceEtag(etag) if etag != existing_etag.as_str() => { + return Err(std::io::Error::from(std::io::ErrorKind::AlreadyExists)) + } + _ => (), + }; + + //@FIXME for now, our storage interface does not allow streaming, + // so we load everything in memory + let mut evt = Vec::new(); + let mut reader = stream.into_async_read(); + reader + .read_to_end(&mut evt) + .await + .or(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))?; + let (_token, entry) = self + .col + .put(self.filename.as_str(), evt.as_ref()) + .await + .or(Err(std::io::ErrorKind::Interrupted))?; + self.col + .opportunistic_sync() + .await + .or(Err(std::io::ErrorKind::ConnectionReset))?; + Ok(entry.2) + } + .boxed() + } + + fn content<'a>(&self) -> Content<'a> { + //@FIXME for now, our storage interface does not allow streaming, + // so we load everything in memory + let calendar = self.col.clone(); + let blob_id = self.blob_id.clone(); + let calblob = async move { + let raw_ics = calendar + .get(blob_id) + .await + .or(Err(std::io::Error::from(std::io::ErrorKind::Interrupted)))?; + + Ok(hyper::body::Bytes::from(raw_ics)) + }; + futures::stream::once(Box::pin(calblob)).boxed() + } + + fn content_type(&self) -> &str { + "text/calendar" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + let calendar = self.col.clone(); + + async move { + calendar + .dag() + .await + .table + .get(&self.blob_id) + .map(|(_, _, etag)| etag.to_string()) + } + .boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + let calendar = self.col.clone(); + let blob_id = self.blob_id.clone(); + + async move { + let _token = match calendar.delete(blob_id).await { + Ok(v) => v, + Err(e) => { + tracing::error!(err=?e, "delete event node"); + return Err(std::io::Error::from(std::io::ErrorKind::Interrupted)); + } + }; + calendar + .opportunistic_sync() + .await + .or(Err(std::io::ErrorKind::ConnectionReset))?; + Ok(()) + } + .boxed() + } + fn diff<'a>( + &self, + _sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + async { Err(std::io::Error::from(std::io::ErrorKind::Unsupported)) }.boxed() + } + + fn dav_header(&self) -> String { + "1, access-control".into() + } +} + +#[derive(Clone)] +pub(crate) struct CreateEventNode { + col: Arc<Calendar>, + calname: String, + filename: String, +} +impl DavNode for CreateEventNode { + fn fetch<'a>( + &self, + user: &'a ArcUser, + path: &'a [&str], + create: bool, + ) -> BoxFuture<'a, Result<Box<dyn DavNode>>> { + if path.len() == 0 { + let node = Box::new(self.clone()) as Box<dyn DavNode>; + return async { Ok(node) }.boxed(); + } + + async { + Err(anyhow!( + "Not supported: can't create a child on an event node" + )) + } + .boxed() + } + + fn children<'a>(&self, user: &'a ArcUser) -> BoxFuture<'a, Vec<Box<dyn DavNode>>> { + async { vec![] }.boxed() + } + + fn path(&self, user: &ArcUser) -> String { + format!( + "/{}/calendar/{}/{}", + user.username, self.calname, self.filename + ) + } + + fn supported_properties(&self, user: &ArcUser) -> dav::PropName<All> { + dav::PropName(vec![]) + } + + fn properties(&self, _user: &ArcUser, prop: dav::PropName<All>) -> PropertyStream<'static> { + futures::stream::iter(vec![]).boxed() + } + + fn put<'a>( + &'a self, + _policy: PutPolicy, + stream: Content<'a>, + ) -> BoxFuture<'a, std::result::Result<Etag, std::io::Error>> { + //@NOTE: policy might not be needed here: whatever we put, there is no known entries here + + async { + //@FIXME for now, our storage interface does not allow for streaming + let mut evt = Vec::new(); + let mut reader = stream.into_async_read(); + reader.read_to_end(&mut evt).await.unwrap(); + let (_token, entry) = self + .col + .put(self.filename.as_str(), evt.as_ref()) + .await + .or(Err(std::io::ErrorKind::Interrupted))?; + self.col + .opportunistic_sync() + .await + .or(Err(std::io::ErrorKind::ConnectionReset))?; + Ok(entry.2) + } + .boxed() + } + + fn content<'a>(&self) -> Content<'a> { + futures::stream::once(futures::future::err(std::io::Error::from( + std::io::ErrorKind::Unsupported, + ))) + .boxed() + } + + fn content_type(&self) -> &str { + "text/plain" + } + + fn etag(&self) -> BoxFuture<Option<Etag>> { + async { None }.boxed() + } + + fn delete(&self) -> BoxFuture<std::result::Result<(), std::io::Error>> { + // Nothing to delete + async { Ok(()) }.boxed() + } + fn diff<'a>( + &self, + _sync_token: Option<Token>, + ) -> BoxFuture< + 'a, + std::result::Result<(Token, Vec<Box<dyn DavNode>>, Vec<dav::Href>), std::io::Error>, + > { + async { Err(std::io::Error::from(std::io::ErrorKind::Unsupported)) }.boxed() + } + + fn dav_header(&self) -> String { + "1, access-control".into() + } +} diff --git a/aero-proto/src/imap/attributes.rs b/aero-proto/src/imap/attributes.rs new file mode 100644 index 0000000..89446a8 --- /dev/null +++ b/aero-proto/src/imap/attributes.rs @@ -0,0 +1,77 @@ +use imap_codec::imap_types::command::FetchModifier; +use imap_codec::imap_types::fetch::{MacroOrMessageDataItemNames, MessageDataItemName, Section}; + +/// Internal decisions based on fetched attributes +/// passed by the client + +pub struct AttributesProxy { + pub attrs: Vec<MessageDataItemName<'static>>, +} +impl AttributesProxy { + pub fn new( + attrs: &MacroOrMessageDataItemNames<'static>, + modifiers: &[FetchModifier], + is_uid_fetch: bool, + ) -> Self { + // Expand macros + let mut fetch_attrs = match attrs { + MacroOrMessageDataItemNames::Macro(m) => { + use imap_codec::imap_types::fetch::Macro; + use MessageDataItemName::*; + match m { + Macro::All => vec![Flags, InternalDate, Rfc822Size, Envelope], + Macro::Fast => vec![Flags, InternalDate, Rfc822Size], + Macro::Full => vec![Flags, InternalDate, Rfc822Size, Envelope, Body], + _ => { + tracing::error!("unimplemented macro"); + vec![] + } + } + } + MacroOrMessageDataItemNames::MessageDataItemNames(a) => a.clone(), + }; + + // Handle uids + if is_uid_fetch && !fetch_attrs.contains(&MessageDataItemName::Uid) { + fetch_attrs.push(MessageDataItemName::Uid); + } + + // Handle inferred MODSEQ tag + let is_changed_since = modifiers + .iter() + .any(|m| matches!(m, FetchModifier::ChangedSince(..))); + if is_changed_since && !fetch_attrs.contains(&MessageDataItemName::ModSeq) { + fetch_attrs.push(MessageDataItemName::ModSeq); + } + + Self { attrs: fetch_attrs } + } + + pub fn is_enabling_condstore(&self) -> bool { + self.attrs + .iter() + .any(|x| matches!(x, MessageDataItemName::ModSeq)) + } + + pub fn need_body(&self) -> bool { + self.attrs.iter().any(|x| match x { + MessageDataItemName::Body + | MessageDataItemName::Rfc822 + | MessageDataItemName::Rfc822Text + | MessageDataItemName::BodyStructure => true, + + MessageDataItemName::BodyExt { + section: Some(section), + partial: _, + peek: _, + } => match section { + Section::Header(None) + | Section::HeaderFields(None, _) + | Section::HeaderFieldsNot(None, _) => false, + _ => true, + }, + MessageDataItemName::BodyExt { .. } => true, + _ => false, + }) + } +} diff --git a/aero-proto/src/imap/capability.rs b/aero-proto/src/imap/capability.rs new file mode 100644 index 0000000..c76b51c --- /dev/null +++ b/aero-proto/src/imap/capability.rs @@ -0,0 +1,159 @@ +use imap_codec::imap_types::command::{FetchModifier, SelectExamineModifier, StoreModifier}; +use imap_codec::imap_types::core::Vec1; +use imap_codec::imap_types::extensions::enable::{CapabilityEnable, Utf8Kind}; +use imap_codec::imap_types::response::Capability; +use std::collections::HashSet; + +use crate::imap::attributes::AttributesProxy; + +fn capability_unselect() -> Capability<'static> { + Capability::try_from("UNSELECT").unwrap() +} + +fn capability_condstore() -> Capability<'static> { + Capability::try_from("CONDSTORE").unwrap() +} + +fn capability_uidplus() -> Capability<'static> { + Capability::try_from("UIDPLUS").unwrap() +} + +fn capability_liststatus() -> Capability<'static> { + Capability::try_from("LIST-STATUS").unwrap() +} + +/* +fn capability_qresync() -> Capability<'static> { + Capability::try_from("QRESYNC").unwrap() +} +*/ + +#[derive(Debug, Clone)] +pub struct ServerCapability(HashSet<Capability<'static>>); + +impl Default for ServerCapability { + fn default() -> Self { + Self(HashSet::from([ + Capability::Imap4Rev1, + Capability::Enable, + Capability::Move, + Capability::LiteralPlus, + Capability::Idle, + capability_unselect(), + capability_condstore(), + capability_uidplus(), + capability_liststatus(), + //capability_qresync(), + ])) + } +} + +impl ServerCapability { + pub fn to_vec(&self) -> Vec1<Capability<'static>> { + self.0 + .iter() + .map(|v| v.clone()) + .collect::<Vec<_>>() + .try_into() + .unwrap() + } + + #[allow(dead_code)] + pub fn support(&self, cap: &Capability<'static>) -> bool { + self.0.contains(cap) + } +} + +#[derive(Clone)] +pub enum ClientStatus { + NotSupportedByServer, + Disabled, + Enabled, +} +impl ClientStatus { + pub fn is_enabled(&self) -> bool { + matches!(self, Self::Enabled) + } + + pub fn enable(&self) -> Self { + match self { + Self::Disabled => Self::Enabled, + other => other.clone(), + } + } +} + +pub struct ClientCapability { + pub condstore: ClientStatus, + pub utf8kind: Option<Utf8Kind>, +} + +impl ClientCapability { + pub fn new(sc: &ServerCapability) -> Self { + Self { + condstore: match sc.0.contains(&capability_condstore()) { + true => ClientStatus::Disabled, + _ => ClientStatus::NotSupportedByServer, + }, + utf8kind: None, + } + } + + pub fn enable_condstore(&mut self) { + self.condstore = self.condstore.enable(); + } + + pub fn attributes_enable(&mut self, ap: &AttributesProxy) { + if ap.is_enabling_condstore() { + self.enable_condstore() + } + } + + pub fn fetch_modifiers_enable(&mut self, mods: &[FetchModifier]) { + if mods + .iter() + .any(|x| matches!(x, FetchModifier::ChangedSince(..))) + { + self.enable_condstore() + } + } + + pub fn store_modifiers_enable(&mut self, mods: &[StoreModifier]) { + if mods + .iter() + .any(|x| matches!(x, StoreModifier::UnchangedSince(..))) + { + self.enable_condstore() + } + } + + pub fn select_enable(&mut self, mods: &[SelectExamineModifier]) { + for m in mods.iter() { + match m { + SelectExamineModifier::Condstore => self.enable_condstore(), + } + } + } + + pub fn try_enable( + &mut self, + caps: &[CapabilityEnable<'static>], + ) -> Vec<CapabilityEnable<'static>> { + let mut enabled = vec![]; + for cap in caps { + match cap { + CapabilityEnable::CondStore if matches!(self.condstore, ClientStatus::Disabled) => { + self.condstore = ClientStatus::Enabled; + enabled.push(cap.clone()); + } + CapabilityEnable::Utf8(kind) if Some(kind) != self.utf8kind.as_ref() => { + self.utf8kind = Some(kind.clone()); + enabled.push(cap.clone()); + } + _ => (), + } + } + + enabled + } +} diff --git a/aero-proto/src/imap/command/anonymous.rs b/aero-proto/src/imap/command/anonymous.rs new file mode 100644 index 0000000..f23ec17 --- /dev/null +++ b/aero-proto/src/imap/command/anonymous.rs @@ -0,0 +1,84 @@ +use anyhow::Result; +use imap_codec::imap_types::command::{Command, CommandBody}; +use imap_codec::imap_types::core::AString; +use imap_codec::imap_types::response::Code; +use imap_codec::imap_types::secret::Secret; + +use aero_collections::user::User; +use aero_user::login::ArcLoginProvider; + +use crate::imap::capability::ServerCapability; +use crate::imap::command::anystate; +use crate::imap::flow; +use crate::imap::response::Response; + +//--- dispatching + +pub struct AnonymousContext<'a> { + pub req: &'a Command<'static>, + pub server_capabilities: &'a ServerCapability, + pub login_provider: &'a ArcLoginProvider, +} + +pub async fn dispatch(ctx: AnonymousContext<'_>) -> Result<(Response<'static>, flow::Transition)> { + match &ctx.req.body { + // Any State + CommandBody::Noop => anystate::noop_nothing(ctx.req.tag.clone()), + CommandBody::Capability => { + anystate::capability(ctx.req.tag.clone(), ctx.server_capabilities) + } + CommandBody::Logout => anystate::logout(), + + // Specific to anonymous context (3 commands) + CommandBody::Login { username, password } => ctx.login(username, password).await, + CommandBody::Authenticate { .. } => { + anystate::not_implemented(ctx.req.tag.clone(), "authenticate") + } + //StartTLS is not implemented for now, we will probably go full TLS. + + // Collect other commands + _ => anystate::wrong_state(ctx.req.tag.clone()), + } +} + +//--- Command controllers, private + +impl<'a> AnonymousContext<'a> { + async fn login( + self, + username: &AString<'a>, + password: &Secret<AString<'a>>, + ) -> Result<(Response<'static>, flow::Transition)> { + let (u, p) = ( + std::str::from_utf8(username.as_ref())?, + std::str::from_utf8(password.declassify().as_ref())?, + ); + tracing::info!(user = %u, "command.login"); + + let creds = match self.login_provider.login(&u, &p).await { + Err(e) => { + tracing::debug!(error=%e, "authentication failed"); + return Ok(( + Response::build() + .to_req(self.req) + .message("Authentication failed") + .no()?, + flow::Transition::None, + )); + } + Ok(c) => c, + }; + + let user = User::new(u.to_string(), creds).await?; + + tracing::info!(username=%u, "connected"); + Ok(( + Response::build() + .to_req(self.req) + .code(Code::Capability(self.server_capabilities.to_vec())) + .message("Completed") + .ok()?, + flow::Transition::Authenticate(user), + )) + } +} diff --git a/aero-proto/src/imap/command/anystate.rs b/aero-proto/src/imap/command/anystate.rs new file mode 100644 index 0000000..718ba3f --- /dev/null +++ b/aero-proto/src/imap/command/anystate.rs @@ -0,0 +1,54 @@ +use anyhow::Result; +use imap_codec::imap_types::core::Tag; +use imap_codec::imap_types::response::Data; + +use crate::imap::capability::ServerCapability; +use crate::imap::flow; +use crate::imap::response::Response; + +pub(crate) fn capability( + tag: Tag<'static>, + cap: &ServerCapability, +) -> Result<(Response<'static>, flow::Transition)> { + let res = Response::build() + .tag(tag) + .message("Server capabilities") + .data(Data::Capability(cap.to_vec())) + .ok()?; + + Ok((res, flow::Transition::None)) +} + +pub(crate) fn noop_nothing(tag: Tag<'static>) -> Result<(Response<'static>, flow::Transition)> { + Ok(( + Response::build().tag(tag).message("Noop completed.").ok()?, + flow::Transition::None, + )) +} + +pub(crate) fn logout() -> Result<(Response<'static>, flow::Transition)> { + Ok((Response::bye()?, flow::Transition::Logout)) +} + +pub(crate) fn not_implemented<'a>( + tag: Tag<'a>, + what: &str, +) -> Result<(Response<'a>, flow::Transition)> { + Ok(( + Response::build() + .tag(tag) + .message(format!("Command not implemented {}", what)) + .bad()?, + flow::Transition::None, + )) +} + +pub(crate) fn wrong_state(tag: Tag<'static>) -> Result<(Response<'static>, flow::Transition)> { + Ok(( + Response::build() + .tag(tag) + .message("Command not authorized in this state") + .bad()?, + flow::Transition::None, + )) +} diff --git a/aero-proto/src/imap/command/authenticated.rs b/aero-proto/src/imap/command/authenticated.rs new file mode 100644 index 0000000..5bd34cb --- /dev/null +++ b/aero-proto/src/imap/command/authenticated.rs @@ -0,0 +1,682 @@ +use std::collections::BTreeMap; +use std::sync::Arc; +use thiserror::Error; + +use anyhow::{anyhow, bail, Result}; +use imap_codec::imap_types::command::{ + Command, CommandBody, ListReturnItem, SelectExamineModifier, +}; +use imap_codec::imap_types::core::{Atom, Literal, QuotedChar, Vec1}; +use imap_codec::imap_types::datetime::DateTime; +use imap_codec::imap_types::extensions::enable::CapabilityEnable; +use imap_codec::imap_types::flag::{Flag, FlagNameAttribute}; +use imap_codec::imap_types::mailbox::{ListMailbox, Mailbox as MailboxCodec}; +use imap_codec::imap_types::response::{Code, CodeOther, Data}; +use imap_codec::imap_types::status::{StatusDataItem, StatusDataItemName}; + +use aero_collections::mail::namespace::MAILBOX_HIERARCHY_DELIMITER as MBX_HIER_DELIM_RAW; +use aero_collections::mail::uidindex::*; +use aero_collections::mail::IMF; +use aero_collections::user::User; + +use crate::imap::capability::{ClientCapability, ServerCapability}; +use crate::imap::command::{anystate, MailboxName}; +use crate::imap::flow; +use crate::imap::mailbox_view::MailboxView; +use crate::imap::response::Response; + +pub struct AuthenticatedContext<'a> { + pub req: &'a Command<'static>, + pub server_capabilities: &'a ServerCapability, + pub client_capabilities: &'a mut ClientCapability, + pub user: &'a Arc<User>, +} + +pub async fn dispatch<'a>( + mut ctx: AuthenticatedContext<'a>, +) -> Result<(Response<'static>, flow::Transition)> { + match &ctx.req.body { + // Any state + CommandBody::Noop => anystate::noop_nothing(ctx.req.tag.clone()), + CommandBody::Capability => { + anystate::capability(ctx.req.tag.clone(), ctx.server_capabilities) + } + CommandBody::Logout => anystate::logout(), + + // Specific to this state (11 commands) + CommandBody::Create { mailbox } => ctx.create(mailbox).await, + CommandBody::Delete { mailbox } => ctx.delete(mailbox).await, + CommandBody::Rename { from, to } => ctx.rename(from, to).await, + CommandBody::Lsub { + reference, + mailbox_wildcard, + } => ctx.list(reference, mailbox_wildcard, &[], true).await, + CommandBody::List { + reference, + mailbox_wildcard, + r#return, + } => ctx.list(reference, mailbox_wildcard, r#return, false).await, + CommandBody::Status { + mailbox, + item_names, + } => ctx.status(mailbox, item_names).await, + CommandBody::Subscribe { mailbox } => ctx.subscribe(mailbox).await, + CommandBody::Unsubscribe { mailbox } => ctx.unsubscribe(mailbox).await, + CommandBody::Select { mailbox, modifiers } => ctx.select(mailbox, modifiers).await, + CommandBody::Examine { mailbox, modifiers } => ctx.examine(mailbox, modifiers).await, + CommandBody::Append { + mailbox, + flags, + date, + message, + } => ctx.append(mailbox, flags, date, message).await, + + // rfc5161 ENABLE + CommandBody::Enable { capabilities } => ctx.enable(capabilities), + + // Collect other commands + _ => anystate::wrong_state(ctx.req.tag.clone()), + } +} + +// --- PRIVATE --- +impl<'a> AuthenticatedContext<'a> { + async fn create( + self, + mailbox: &MailboxCodec<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let name = match mailbox { + MailboxCodec::Inbox => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Cannot create INBOX") + .bad()?, + flow::Transition::None, + )); + } + MailboxCodec::Other(aname) => std::str::from_utf8(aname.as_ref())?, + }; + + match self.user.create_mailbox(&name).await { + Ok(()) => Ok(( + Response::build() + .to_req(self.req) + .message("CREATE complete") + .ok()?, + flow::Transition::None, + )), + Err(e) => Ok(( + Response::build() + .to_req(self.req) + .message(&e.to_string()) + .no()?, + flow::Transition::None, + )), + } + } + + async fn delete( + self, + mailbox: &MailboxCodec<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let name: &str = MailboxName(mailbox).try_into()?; + + match self.user.delete_mailbox(&name).await { + Ok(()) => Ok(( + Response::build() + .to_req(self.req) + .message("DELETE complete") + .ok()?, + flow::Transition::None, + )), + Err(e) => Ok(( + Response::build() + .to_req(self.req) + .message(e.to_string()) + .no()?, + flow::Transition::None, + )), + } + } + + async fn rename( + self, + from: &MailboxCodec<'a>, + to: &MailboxCodec<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let name: &str = MailboxName(from).try_into()?; + let new_name: &str = MailboxName(to).try_into()?; + + match self.user.rename_mailbox(&name, &new_name).await { + Ok(()) => Ok(( + Response::build() + .to_req(self.req) + .message("RENAME complete") + .ok()?, + flow::Transition::None, + )), + Err(e) => Ok(( + Response::build() + .to_req(self.req) + .message(e.to_string()) + .no()?, + flow::Transition::None, + )), + } + } + + async fn list( + &mut self, + reference: &MailboxCodec<'a>, + mailbox_wildcard: &ListMailbox<'a>, + must_return: &[ListReturnItem], + is_lsub: bool, + ) -> Result<(Response<'static>, flow::Transition)> { + let mbx_hier_delim: QuotedChar = QuotedChar::unvalidated(MBX_HIER_DELIM_RAW); + + let reference: &str = MailboxName(reference).try_into()?; + if !reference.is_empty() { + return Ok(( + Response::build() + .to_req(self.req) + .message("References not supported") + .bad()?, + flow::Transition::None, + )); + } + + let status_item_names = must_return.iter().find_map(|m| match m { + ListReturnItem::Status(v) => Some(v), + _ => None, + }); + + // @FIXME would probably need a rewrite to better use the imap_codec library + let wildcard = match mailbox_wildcard { + ListMailbox::Token(v) => std::str::from_utf8(v.as_ref())?, + ListMailbox::String(v) => std::str::from_utf8(v.as_ref())?, + }; + if wildcard.is_empty() { + if is_lsub { + return Ok(( + Response::build() + .to_req(self.req) + .message("LSUB complete") + .data(Data::Lsub { + items: vec![], + delimiter: Some(mbx_hier_delim), + mailbox: "".try_into().unwrap(), + }) + .ok()?, + flow::Transition::None, + )); + } else { + return Ok(( + Response::build() + .to_req(self.req) + .message("LIST complete") + .data(Data::List { + items: vec![], + delimiter: Some(mbx_hier_delim), + mailbox: "".try_into().unwrap(), + }) + .ok()?, + flow::Transition::None, + )); + } + } + + let mailboxes = self.user.list_mailboxes().await?; + let mut vmailboxes = BTreeMap::new(); + for mb in mailboxes.iter() { + for (i, _) in mb.match_indices(MBX_HIER_DELIM_RAW) { + if i > 0 { + let smb = &mb[..i]; + vmailboxes.entry(smb).or_insert(false); + } + } + vmailboxes.insert(mb, true); + } + + let mut ret = vec![]; + for (mb, is_real) in vmailboxes.iter() { + if matches_wildcard(&wildcard, mb) { + let mailbox: MailboxCodec = mb + .to_string() + .try_into() + .map_err(|_| anyhow!("invalid mailbox name"))?; + let mut items = vec![FlagNameAttribute::from(Atom::unvalidated("Subscribed"))]; + + // Decoration + if !*is_real { + items.push(FlagNameAttribute::Noselect); + } else { + match *mb { + "Drafts" => items.push(Atom::unvalidated("Drafts").into()), + "Archive" => items.push(Atom::unvalidated("Archive").into()), + "Sent" => items.push(Atom::unvalidated("Sent").into()), + "Trash" => items.push(Atom::unvalidated("Trash").into()), + _ => (), + }; + } + + // Result type + if is_lsub { + ret.push(Data::Lsub { + items, + delimiter: Some(mbx_hier_delim), + mailbox: mailbox.clone(), + }); + } else { + ret.push(Data::List { + items, + delimiter: Some(mbx_hier_delim), + mailbox: mailbox.clone(), + }); + } + + // Also collect status + if let Some(sin) = status_item_names { + let ret_attrs = match self.status_items(mb, sin).await { + Ok(a) => a, + Err(e) => { + tracing::error!(err=?e, mailbox=%mb, "Unable to fetch status for mailbox"); + continue; + } + }; + + let data = Data::Status { + mailbox, + items: ret_attrs.into(), + }; + + ret.push(data); + } + } + } + + let msg = if is_lsub { + "LSUB completed" + } else { + "LIST completed" + }; + Ok(( + Response::build() + .to_req(self.req) + .message(msg) + .many_data(ret) + .ok()?, + flow::Transition::None, + )) + } + + async fn status( + &mut self, + mailbox: &MailboxCodec<'static>, + attributes: &[StatusDataItemName], + ) -> Result<(Response<'static>, flow::Transition)> { + let name: &str = MailboxName(mailbox).try_into()?; + + let ret_attrs = match self.status_items(name, attributes).await { + Ok(v) => v, + Err(e) => match e.downcast_ref::<CommandError>() { + Some(CommandError::MailboxNotFound) => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Mailbox does not exist") + .no()?, + flow::Transition::None, + )) + } + _ => return Err(e.into()), + }, + }; + + let data = Data::Status { + mailbox: mailbox.clone(), + items: ret_attrs.into(), + }; + + Ok(( + Response::build() + .to_req(self.req) + .message("STATUS completed") + .data(data) + .ok()?, + flow::Transition::None, + )) + } + + async fn status_items( + &mut self, + name: &str, + attributes: &[StatusDataItemName], + ) -> Result<Vec<StatusDataItem>> { + let mb_opt = self.user.open_mailbox(name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => return Err(CommandError::MailboxNotFound.into()), + }; + + let view = MailboxView::new(mb, self.client_capabilities.condstore.is_enabled()).await; + + let mut ret_attrs = vec![]; + for attr in attributes.iter() { + ret_attrs.push(match attr { + StatusDataItemName::Messages => StatusDataItem::Messages(view.exists()?), + StatusDataItemName::Unseen => StatusDataItem::Unseen(view.unseen_count() as u32), + StatusDataItemName::Recent => StatusDataItem::Recent(view.recent()?), + StatusDataItemName::UidNext => StatusDataItem::UidNext(view.uidnext()), + StatusDataItemName::UidValidity => { + StatusDataItem::UidValidity(view.uidvalidity()) + } + StatusDataItemName::Deleted => { + bail!("quota not implemented, can't return deleted elements waiting for EXPUNGE"); + }, + StatusDataItemName::DeletedStorage => { + bail!("quota not implemented, can't return freed storage after EXPUNGE will be run"); + }, + StatusDataItemName::HighestModSeq => { + self.client_capabilities.enable_condstore(); + StatusDataItem::HighestModSeq(view.highestmodseq().get()) + }, + }); + } + Ok(ret_attrs) + } + + async fn subscribe( + self, + mailbox: &MailboxCodec<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let name: &str = MailboxName(mailbox).try_into()?; + + if self.user.has_mailbox(&name).await? { + Ok(( + Response::build() + .to_req(self.req) + .message("SUBSCRIBE complete") + .ok()?, + flow::Transition::None, + )) + } else { + Ok(( + Response::build() + .to_req(self.req) + .message(format!("Mailbox {} does not exist", name)) + .bad()?, + flow::Transition::None, + )) + } + } + + async fn unsubscribe( + self, + mailbox: &MailboxCodec<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let name: &str = MailboxName(mailbox).try_into()?; + + if self.user.has_mailbox(&name).await? { + Ok(( + Response::build() + .to_req(self.req) + .message(format!( + "Cannot unsubscribe from mailbox {}: not supported by Aerogramme", + name + )) + .bad()?, + flow::Transition::None, + )) + } else { + Ok(( + Response::build() + .to_req(self.req) + .message(format!("Mailbox {} does not exist", name)) + .no()?, + flow::Transition::None, + )) + } + } + + /* + * TRACE BEGIN --- + + + Example: C: A142 SELECT INBOX + S: * 172 EXISTS + S: * 1 RECENT + S: * OK [UNSEEN 12] Message 12 is first unseen + S: * OK [UIDVALIDITY 3857529045] UIDs valid + S: * OK [UIDNEXT 4392] Predicted next UID + S: * FLAGS (\Answered \Flagged \Deleted \Seen \Draft) + S: * OK [PERMANENTFLAGS (\Deleted \Seen \*)] Limited + S: A142 OK [READ-WRITE] SELECT completed + + --- a mailbox with no unseen message -> no unseen entry + NOTES: + RFC3501 (imap4rev1) says if there is no OK [UNSEEN] response, client must make no assumption, + it is therefore correct to not return it even if there are unseen messages + RFC9051 (imap4rev2) says that OK [UNSEEN] responses are deprecated after SELECT and EXAMINE + For Aerogramme, we just don't send the OK [UNSEEN], it's correct to do in both specifications. + + + 20 select "INBOX.achats" + * FLAGS (\Answered \Flagged \Deleted \Seen \Draft $Forwarded JUNK $label1) + * OK [PERMANENTFLAGS (\Answered \Flagged \Deleted \Seen \Draft $Forwarded JUNK $label1 \*)] Flags permitted. + * 88 EXISTS + * 0 RECENT + * OK [UIDVALIDITY 1347986788] UIDs valid + * OK [UIDNEXT 91] Predicted next UID + * OK [HIGHESTMODSEQ 72] Highest + 20 OK [READ-WRITE] Select completed (0.001 + 0.000 secs). + + * TRACE END --- + */ + async fn select( + self, + mailbox: &MailboxCodec<'a>, + modifiers: &[SelectExamineModifier], + ) -> Result<(Response<'static>, flow::Transition)> { + self.client_capabilities.select_enable(modifiers); + + let name: &str = MailboxName(mailbox).try_into()?; + + let mb_opt = self.user.open_mailbox(&name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Mailbox does not exist") + .no()?, + flow::Transition::None, + )) + } + }; + tracing::info!(username=%self.user.username, mailbox=%name, "mailbox.selected"); + + let mb = MailboxView::new(mb, self.client_capabilities.condstore.is_enabled()).await; + let data = mb.summary()?; + + Ok(( + Response::build() + .message("Select completed") + .to_req(self.req) + .code(Code::ReadWrite) + .set_body(data) + .ok()?, + flow::Transition::Select(mb, flow::MailboxPerm::ReadWrite), + )) + } + + async fn examine( + self, + mailbox: &MailboxCodec<'a>, + modifiers: &[SelectExamineModifier], + ) -> Result<(Response<'static>, flow::Transition)> { + self.client_capabilities.select_enable(modifiers); + + let name: &str = MailboxName(mailbox).try_into()?; + + let mb_opt = self.user.open_mailbox(&name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Mailbox does not exist") + .no()?, + flow::Transition::None, + )) + } + }; + tracing::info!(username=%self.user.username, mailbox=%name, "mailbox.examined"); + + let mb = MailboxView::new(mb, self.client_capabilities.condstore.is_enabled()).await; + let data = mb.summary()?; + + Ok(( + Response::build() + .to_req(self.req) + .message("Examine completed") + .code(Code::ReadOnly) + .set_body(data) + .ok()?, + flow::Transition::Select(mb, flow::MailboxPerm::ReadOnly), + )) + } + + //@FIXME we should write a specific version for the "selected" state + //that returns some unsollicited responses + async fn append( + self, + mailbox: &MailboxCodec<'a>, + flags: &[Flag<'a>], + date: &Option<DateTime>, + message: &Literal<'a>, + ) -> Result<(Response<'static>, flow::Transition)> { + let append_tag = self.req.tag.clone(); + match self.append_internal(mailbox, flags, date, message).await { + Ok((_mb_view, uidvalidity, uid, _modseq)) => Ok(( + Response::build() + .tag(append_tag) + .message("APPEND completed") + .code(Code::Other(CodeOther::unvalidated( + format!("APPENDUID {} {}", uidvalidity, uid).into_bytes(), + ))) + .ok()?, + flow::Transition::None, + )), + Err(e) => Ok(( + Response::build() + .tag(append_tag) + .message(e.to_string()) + .no()?, + flow::Transition::None, + )), + } + } + + fn enable( + self, + cap_enable: &Vec1<CapabilityEnable<'static>>, + ) -> Result<(Response<'static>, flow::Transition)> { + let mut response_builder = Response::build().to_req(self.req); + let capabilities = self.client_capabilities.try_enable(cap_enable.as_ref()); + if capabilities.len() > 0 { + response_builder = response_builder.data(Data::Enabled { capabilities }); + } + Ok(( + response_builder.message("ENABLE completed").ok()?, + flow::Transition::None, + )) + } + + //@FIXME should be refactored and integrated to the mailbox view + pub(crate) async fn append_internal( + self, + mailbox: &MailboxCodec<'a>, + flags: &[Flag<'a>], + date: &Option<DateTime>, + message: &Literal<'a>, + ) -> Result<(MailboxView, ImapUidvalidity, ImapUid, ModSeq)> { + let name: &str = MailboxName(mailbox).try_into()?; + + let mb_opt = self.user.open_mailbox(&name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => bail!("Mailbox does not exist"), + }; + let view = MailboxView::new(mb, self.client_capabilities.condstore.is_enabled()).await; + + if date.is_some() { + tracing::warn!("Cannot set date when appending message"); + } + + let msg = + IMF::try_from(message.data()).map_err(|_| anyhow!("Could not parse e-mail message"))?; + let flags = flags.iter().map(|x| x.to_string()).collect::<Vec<_>>(); + // TODO: filter allowed flags? ping @Quentin + + let (uidvalidity, uid, modseq) = + view.internal.mailbox.append(msg, None, &flags[..]).await?; + //let unsollicited = view.update(UpdateParameters::default()).await?; + + Ok((view, uidvalidity, uid, modseq)) + } +} + +fn matches_wildcard(wildcard: &str, name: &str) -> bool { + let wildcard = wildcard.chars().collect::<Vec<char>>(); + let name = name.chars().collect::<Vec<char>>(); + + let mut matches = vec![vec![false; wildcard.len() + 1]; name.len() + 1]; + + for i in 0..=name.len() { + for j in 0..=wildcard.len() { + matches[i][j] = (i == 0 && j == 0) + || (j > 0 + && matches[i][j - 1] + && (wildcard[j - 1] == '%' || wildcard[j - 1] == '*')) + || (i > 0 + && j > 0 + && matches[i - 1][j - 1] + && wildcard[j - 1] == name[i - 1] + && wildcard[j - 1] != '%' + && wildcard[j - 1] != '*') + || (i > 0 + && j > 0 + && matches[i - 1][j] + && (wildcard[j - 1] == '*' + || (wildcard[j - 1] == '%' && name[i - 1] != MBX_HIER_DELIM_RAW))); + } + } + + matches[name.len()][wildcard.len()] +} + +#[derive(Error, Debug)] +pub enum CommandError { + #[error("Mailbox not found")] + MailboxNotFound, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_wildcard_matches() { + assert!(matches_wildcard("INBOX", "INBOX")); + assert!(matches_wildcard("*", "INBOX")); + assert!(matches_wildcard("%", "INBOX")); + assert!(!matches_wildcard("%", "Test.Azerty")); + assert!(!matches_wildcard("INBOX.*", "INBOX")); + assert!(matches_wildcard("Sent.*", "Sent.A")); + assert!(matches_wildcard("Sent.*", "Sent.A.B")); + assert!(!matches_wildcard("Sent.%", "Sent.A.B")); + } +} diff --git a/aero-proto/src/imap/command/mod.rs b/aero-proto/src/imap/command/mod.rs new file mode 100644 index 0000000..5382d06 --- /dev/null +++ b/aero-proto/src/imap/command/mod.rs @@ -0,0 +1,20 @@ +pub mod anonymous; +pub mod anystate; +pub mod authenticated; +pub mod selected; + +use aero_collections::mail::namespace::INBOX; +use imap_codec::imap_types::mailbox::Mailbox as MailboxCodec; + +/// Convert an IMAP mailbox name/identifier representation +/// to an utf-8 string that is used internally in Aerogramme +struct MailboxName<'a>(&'a MailboxCodec<'a>); +impl<'a> TryInto<&'a str> for MailboxName<'a> { + type Error = std::str::Utf8Error; + fn try_into(self) -> Result<&'a str, Self::Error> { + match self.0 { + MailboxCodec::Inbox => Ok(INBOX), + MailboxCodec::Other(aname) => Ok(std::str::from_utf8(aname.as_ref())?), + } + } +} diff --git a/aero-proto/src/imap/command/selected.rs b/aero-proto/src/imap/command/selected.rs new file mode 100644 index 0000000..190949b --- /dev/null +++ b/aero-proto/src/imap/command/selected.rs @@ -0,0 +1,425 @@ +use std::num::NonZeroU64; +use std::sync::Arc; + +use anyhow::Result; +use imap_codec::imap_types::command::{Command, CommandBody, FetchModifier, StoreModifier}; +use imap_codec::imap_types::core::Charset; +use imap_codec::imap_types::fetch::MacroOrMessageDataItemNames; +use imap_codec::imap_types::flag::{Flag, StoreResponse, StoreType}; +use imap_codec::imap_types::mailbox::Mailbox as MailboxCodec; +use imap_codec::imap_types::response::{Code, CodeOther}; +use imap_codec::imap_types::search::SearchKey; +use imap_codec::imap_types::sequence::SequenceSet; + +use aero_collections::user::User; + +use crate::imap::attributes::AttributesProxy; +use crate::imap::capability::{ClientCapability, ServerCapability}; +use crate::imap::command::{anystate, authenticated, MailboxName}; +use crate::imap::flow; +use crate::imap::mailbox_view::{MailboxView, UpdateParameters}; +use crate::imap::response::Response; + +pub struct SelectedContext<'a> { + pub req: &'a Command<'static>, + pub user: &'a Arc<User>, + pub mailbox: &'a mut MailboxView, + pub server_capabilities: &'a ServerCapability, + pub client_capabilities: &'a mut ClientCapability, + pub perm: &'a flow::MailboxPerm, +} + +pub async fn dispatch<'a>( + ctx: SelectedContext<'a>, +) -> Result<(Response<'static>, flow::Transition)> { + match &ctx.req.body { + // Any State + // noop is specific to this state + CommandBody::Capability => { + anystate::capability(ctx.req.tag.clone(), ctx.server_capabilities) + } + CommandBody::Logout => anystate::logout(), + + // Specific to this state (7 commands + NOOP) + CommandBody::Close => match ctx.perm { + flow::MailboxPerm::ReadWrite => ctx.close().await, + flow::MailboxPerm::ReadOnly => ctx.examine_close().await, + }, + CommandBody::Noop | CommandBody::Check => ctx.noop().await, + CommandBody::Fetch { + sequence_set, + macro_or_item_names, + modifiers, + uid, + } => { + ctx.fetch(sequence_set, macro_or_item_names, modifiers, uid) + .await + } + //@FIXME SearchKey::And is a legacy hack, should be refactored + CommandBody::Search { + charset, + criteria, + uid, + } => { + ctx.search(charset, &SearchKey::And(criteria.clone()), uid) + .await + } + CommandBody::Expunge { + // UIDPLUS (rfc4315) + uid_sequence_set, + } => ctx.expunge(uid_sequence_set).await, + CommandBody::Store { + sequence_set, + kind, + response, + flags, + modifiers, + uid, + } => { + ctx.store(sequence_set, kind, response, flags, modifiers, uid) + .await + } + CommandBody::Copy { + sequence_set, + mailbox, + uid, + } => ctx.copy(sequence_set, mailbox, uid).await, + CommandBody::Move { + sequence_set, + mailbox, + uid, + } => ctx.r#move(sequence_set, mailbox, uid).await, + + // UNSELECT extension (rfc3691) + CommandBody::Unselect => ctx.unselect().await, + + // In selected mode, we fallback to authenticated when needed + _ => { + authenticated::dispatch(authenticated::AuthenticatedContext { + req: ctx.req, + server_capabilities: ctx.server_capabilities, + client_capabilities: ctx.client_capabilities, + user: ctx.user, + }) + .await + } + } +} + +// --- PRIVATE --- + +impl<'a> SelectedContext<'a> { + async fn close(self) -> Result<(Response<'static>, flow::Transition)> { + // We expunge messages, + // but we don't send the untagged EXPUNGE responses + let tag = self.req.tag.clone(); + self.expunge(&None).await?; + Ok(( + Response::build().tag(tag).message("CLOSE completed").ok()?, + flow::Transition::Unselect, + )) + } + + /// CLOSE in examined state is not the same as in selected state + /// (in selected state it also does an EXPUNGE, here it doesn't) + async fn examine_close(self) -> Result<(Response<'static>, flow::Transition)> { + Ok(( + Response::build() + .to_req(self.req) + .message("CLOSE completed") + .ok()?, + flow::Transition::Unselect, + )) + } + + async fn unselect(self) -> Result<(Response<'static>, flow::Transition)> { + Ok(( + Response::build() + .to_req(self.req) + .message("UNSELECT completed") + .ok()?, + flow::Transition::Unselect, + )) + } + + pub async fn fetch( + self, + sequence_set: &SequenceSet, + attributes: &'a MacroOrMessageDataItemNames<'static>, + modifiers: &[FetchModifier], + uid: &bool, + ) -> Result<(Response<'static>, flow::Transition)> { + let ap = AttributesProxy::new(attributes, modifiers, *uid); + let mut changed_since: Option<NonZeroU64> = None; + modifiers.iter().for_each(|m| match m { + FetchModifier::ChangedSince(val) => { + changed_since = Some(*val); + } + }); + + match self + .mailbox + .fetch(sequence_set, &ap, changed_since, uid) + .await + { + Ok(resp) => { + // Capabilities enabling logic only on successful command + // (according to my understanding of the spec) + self.client_capabilities.attributes_enable(&ap); + self.client_capabilities.fetch_modifiers_enable(modifiers); + + // Response to the client + Ok(( + Response::build() + .to_req(self.req) + .message("FETCH completed") + .set_body(resp) + .ok()?, + flow::Transition::None, + )) + } + Err(e) => Ok(( + Response::build() + .to_req(self.req) + .message(e.to_string()) + .no()?, + flow::Transition::None, + )), + } + } + + pub async fn search( + self, + charset: &Option<Charset<'a>>, + criteria: &SearchKey<'a>, + uid: &bool, + ) -> Result<(Response<'static>, flow::Transition)> { + let (found, enable_condstore) = self.mailbox.search(charset, criteria, *uid).await?; + if enable_condstore { + self.client_capabilities.enable_condstore(); + } + Ok(( + Response::build() + .to_req(self.req) + .set_body(found) + .message("SEARCH completed") + .ok()?, + flow::Transition::None, + )) + } + + pub async fn noop(self) -> Result<(Response<'static>, flow::Transition)> { + self.mailbox.internal.mailbox.force_sync().await?; + + let updates = self.mailbox.update(UpdateParameters::default()).await?; + Ok(( + Response::build() + .to_req(self.req) + .message("NOOP completed.") + .set_body(updates) + .ok()?, + flow::Transition::None, + )) + } + + async fn expunge( + self, + uid_sequence_set: &Option<SequenceSet>, + ) -> Result<(Response<'static>, flow::Transition)> { + if let Some(failed) = self.fail_read_only() { + return Ok((failed, flow::Transition::None)); + } + + let tag = self.req.tag.clone(); + let data = self.mailbox.expunge(uid_sequence_set).await?; + + Ok(( + Response::build() + .tag(tag) + .message("EXPUNGE completed") + .set_body(data) + .ok()?, + flow::Transition::None, + )) + } + + async fn store( + self, + sequence_set: &SequenceSet, + kind: &StoreType, + response: &StoreResponse, + flags: &[Flag<'a>], + modifiers: &[StoreModifier], + uid: &bool, + ) -> Result<(Response<'static>, flow::Transition)> { + if let Some(failed) = self.fail_read_only() { + return Ok((failed, flow::Transition::None)); + } + + let mut unchanged_since: Option<NonZeroU64> = None; + modifiers.iter().for_each(|m| match m { + StoreModifier::UnchangedSince(val) => { + unchanged_since = Some(*val); + } + }); + + let (data, modified) = self + .mailbox + .store(sequence_set, kind, response, flags, unchanged_since, uid) + .await?; + + let mut ok_resp = Response::build() + .to_req(self.req) + .message("STORE completed") + .set_body(data); + + match modified[..] { + [] => (), + [_head, ..] => { + let modified_str = format!( + "MODIFIED {}", + modified + .into_iter() + .map(|x| x.to_string()) + .collect::<Vec<_>>() + .join(",") + ); + ok_resp = ok_resp.code(Code::Other(CodeOther::unvalidated( + modified_str.into_bytes(), + ))); + } + }; + + self.client_capabilities.store_modifiers_enable(modifiers); + + Ok((ok_resp.ok()?, flow::Transition::None)) + } + + async fn copy( + self, + sequence_set: &SequenceSet, + mailbox: &MailboxCodec<'a>, + uid: &bool, + ) -> Result<(Response<'static>, flow::Transition)> { + //@FIXME Could copy be valid in EXAMINE mode? + if let Some(failed) = self.fail_read_only() { + return Ok((failed, flow::Transition::None)); + } + + let name: &str = MailboxName(mailbox).try_into()?; + + let mb_opt = self.user.open_mailbox(&name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Destination mailbox does not exist") + .code(Code::TryCreate) + .no()?, + flow::Transition::None, + )) + } + }; + + let (uidval, uid_map) = self.mailbox.copy(sequence_set, mb, uid).await?; + + let copyuid_str = format!( + "{} {} {}", + uidval, + uid_map + .iter() + .map(|(sid, _)| format!("{}", sid)) + .collect::<Vec<_>>() + .join(","), + uid_map + .iter() + .map(|(_, tuid)| format!("{}", tuid)) + .collect::<Vec<_>>() + .join(",") + ); + + Ok(( + Response::build() + .to_req(self.req) + .message("COPY completed") + .code(Code::Other(CodeOther::unvalidated( + format!("COPYUID {}", copyuid_str).into_bytes(), + ))) + .ok()?, + flow::Transition::None, + )) + } + + async fn r#move( + self, + sequence_set: &SequenceSet, + mailbox: &MailboxCodec<'a>, + uid: &bool, + ) -> Result<(Response<'static>, flow::Transition)> { + if let Some(failed) = self.fail_read_only() { + return Ok((failed, flow::Transition::None)); + } + + let name: &str = MailboxName(mailbox).try_into()?; + + let mb_opt = self.user.open_mailbox(&name).await?; + let mb = match mb_opt { + Some(mb) => mb, + None => { + return Ok(( + Response::build() + .to_req(self.req) + .message("Destination mailbox does not exist") + .code(Code::TryCreate) + .no()?, + flow::Transition::None, + )) + } + }; + + let (uidval, uid_map, data) = self.mailbox.r#move(sequence_set, mb, uid).await?; + + // compute code + let copyuid_str = format!( + "{} {} {}", + uidval, + uid_map + .iter() + .map(|(sid, _)| format!("{}", sid)) + .collect::<Vec<_>>() + .join(","), + uid_map + .iter() + .map(|(_, tuid)| format!("{}", tuid)) + .collect::<Vec<_>>() + .join(",") + ); + + Ok(( + Response::build() + .to_req(self.req) + .message("COPY completed") + .code(Code::Other(CodeOther::unvalidated( + format!("COPYUID {}", copyuid_str).into_bytes(), + ))) + .set_body(data) + .ok()?, + flow::Transition::None, + )) + } + + fn fail_read_only(&self) -> Option<Response<'static>> { + match self.perm { + flow::MailboxPerm::ReadWrite => None, + flow::MailboxPerm::ReadOnly => Some( + Response::build() + .to_req(self.req) + .message("Write command are forbidden while exmining mailbox") + .no() + .unwrap(), + ), + } + } +} diff --git a/aero-proto/src/imap/flags.rs b/aero-proto/src/imap/flags.rs new file mode 100644 index 0000000..0f6ec64 --- /dev/null +++ b/aero-proto/src/imap/flags.rs @@ -0,0 +1,30 @@ +use imap_codec::imap_types::core::Atom; +use imap_codec::imap_types::flag::{Flag, FlagFetch}; + +pub fn from_str(f: &str) -> Option<FlagFetch<'static>> { + match f.chars().next() { + Some('\\') => match f { + "\\Seen" => Some(FlagFetch::Flag(Flag::Seen)), + "\\Answered" => Some(FlagFetch::Flag(Flag::Answered)), + "\\Flagged" => Some(FlagFetch::Flag(Flag::Flagged)), + "\\Deleted" => Some(FlagFetch::Flag(Flag::Deleted)), + "\\Draft" => Some(FlagFetch::Flag(Flag::Draft)), + "\\Recent" => Some(FlagFetch::Recent), + _ => match Atom::try_from(f.strip_prefix('\\').unwrap().to_string()) { + Err(_) => { + tracing::error!(flag=%f, "Unable to encode flag as IMAP atom"); + None + } + Ok(a) => Some(FlagFetch::Flag(Flag::system(a))), + }, + }, + Some(_) => match Atom::try_from(f.to_string()) { + Err(_) => { + tracing::error!(flag=%f, "Unable to encode flag as IMAP atom"); + None + } + Ok(a) => Some(FlagFetch::Flag(Flag::keyword(a))), + }, + None => None, + } +} diff --git a/aero-proto/src/imap/flow.rs b/aero-proto/src/imap/flow.rs new file mode 100644 index 0000000..1986447 --- /dev/null +++ b/aero-proto/src/imap/flow.rs @@ -0,0 +1,115 @@ +use std::error::Error as StdError; +use std::fmt; +use std::sync::Arc; + +use imap_codec::imap_types::core::Tag; +use tokio::sync::Notify; + +use aero_collections::user::User; + +use crate::imap::mailbox_view::MailboxView; + +#[derive(Debug)] +pub enum Error { + ForbiddenTransition, +} +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Forbidden Transition") + } +} +impl StdError for Error {} + +pub enum State { + NotAuthenticated, + Authenticated(Arc<User>), + Selected(Arc<User>, MailboxView, MailboxPerm), + Idle( + Arc<User>, + MailboxView, + MailboxPerm, + Tag<'static>, + Arc<Notify>, + ), + Logout, +} +impl State { + pub fn notify(&self) -> Option<Arc<Notify>> { + match self { + Self::Idle(_, _, _, _, anotif) => Some(anotif.clone()), + _ => None, + } + } +} +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use State::*; + match self { + NotAuthenticated => write!(f, "NotAuthenticated"), + Authenticated(..) => write!(f, "Authenticated"), + Selected(..) => write!(f, "Selected"), + Idle(..) => write!(f, "Idle"), + Logout => write!(f, "Logout"), + } + } +} + +#[derive(Clone)] +pub enum MailboxPerm { + ReadOnly, + ReadWrite, +} + +pub enum Transition { + None, + Authenticate(Arc<User>), + Select(MailboxView, MailboxPerm), + Idle(Tag<'static>, Notify), + UnIdle, + Unselect, + Logout, +} +impl fmt::Display for Transition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Transition::*; + match self { + None => write!(f, "None"), + Authenticate(..) => write!(f, "Authenticated"), + Select(..) => write!(f, "Selected"), + Idle(..) => write!(f, "Idle"), + UnIdle => write!(f, "UnIdle"), + Unselect => write!(f, "Unselect"), + Logout => write!(f, "Logout"), + } + } +} + +// See RFC3501 section 3. +// https://datatracker.ietf.org/doc/html/rfc3501#page-13 +impl State { + pub fn apply(&mut self, tr: Transition) -> Result<(), Error> { + tracing::debug!(state=%self, transition=%tr, "try change state"); + + let new_state = match (std::mem::replace(self, State::Logout), tr) { + (s, Transition::None) => s, + (State::NotAuthenticated, Transition::Authenticate(u)) => State::Authenticated(u), + (State::Authenticated(u) | State::Selected(u, _, _), Transition::Select(m, p)) => { + State::Selected(u, m, p) + } + (State::Selected(u, _, _), Transition::Unselect) => State::Authenticated(u.clone()), + (State::Selected(u, m, p), Transition::Idle(t, s)) => { + State::Idle(u, m, p, t, Arc::new(s)) + } + (State::Idle(u, m, p, _, _), Transition::UnIdle) => State::Selected(u, m, p), + (_, Transition::Logout) => State::Logout, + (s, t) => { + tracing::error!(state=%s, transition=%t, "forbidden transition"); + return Err(Error::ForbiddenTransition); + } + }; + *self = new_state; + tracing::debug!(state=%self, "transition succeeded"); + + Ok(()) + } +} diff --git a/aero-proto/src/imap/imf_view.rs b/aero-proto/src/imap/imf_view.rs new file mode 100644 index 0000000..a4ca2e8 --- /dev/null +++ b/aero-proto/src/imap/imf_view.rs @@ -0,0 +1,109 @@ +use anyhow::{anyhow, Result}; +use chrono::naive::NaiveDate; + +use imap_codec::imap_types::core::{IString, NString}; +use imap_codec::imap_types::envelope::{Address, Envelope}; + +use eml_codec::imf; + +pub struct ImfView<'a>(pub &'a imf::Imf<'a>); + +impl<'a> ImfView<'a> { + pub fn naive_date(&self) -> Result<NaiveDate> { + Ok(self.0.date.ok_or(anyhow!("date is not set"))?.date_naive()) + } + + /// Envelope rules are defined in RFC 3501, section 7.4.2 + /// https://datatracker.ietf.org/doc/html/rfc3501#section-7.4.2 + /// + /// Some important notes: + /// + /// If the Sender or Reply-To lines are absent in the [RFC-2822] + /// header, or are present but empty, the server sets the + /// corresponding member of the envelope to be the same value as + /// the from member (the client is not expected to know to do + /// this). Note: [RFC-2822] requires that all messages have a valid + /// From header. Therefore, the from, sender, and reply-to + /// members in the envelope can not be NIL. + /// + /// If the Date, Subject, In-Reply-To, and Message-ID header lines + /// are absent in the [RFC-2822] header, the corresponding member + /// of the envelope is NIL; if these header lines are present but + /// empty the corresponding member of the envelope is the empty + /// string. + + //@FIXME return an error if the envelope is invalid instead of panicking + //@FIXME some fields must be defaulted if there are not set. + pub fn message_envelope(&self) -> Envelope<'static> { + let msg = self.0; + let from = msg.from.iter().map(convert_mbx).collect::<Vec<_>>(); + + Envelope { + date: NString( + msg.date + .as_ref() + .map(|d| IString::try_from(d.to_rfc3339()).unwrap()), + ), + subject: NString( + msg.subject + .as_ref() + .map(|d| IString::try_from(d.to_string()).unwrap()), + ), + sender: msg + .sender + .as_ref() + .map(|v| vec![convert_mbx(v)]) + .unwrap_or(from.clone()), + reply_to: if msg.reply_to.is_empty() { + from.clone() + } else { + convert_addresses(&msg.reply_to) + }, + from, + to: convert_addresses(&msg.to), + cc: convert_addresses(&msg.cc), + bcc: convert_addresses(&msg.bcc), + in_reply_to: NString( + msg.in_reply_to + .iter() + .next() + .map(|d| IString::try_from(d.to_string()).unwrap()), + ), + message_id: NString( + msg.msg_id + .as_ref() + .map(|d| IString::try_from(d.to_string()).unwrap()), + ), + } + } +} + +pub fn convert_addresses(addrlist: &Vec<imf::address::AddressRef>) -> Vec<Address<'static>> { + let mut acc = vec![]; + for item in addrlist { + match item { + imf::address::AddressRef::Single(a) => acc.push(convert_mbx(a)), + imf::address::AddressRef::Many(l) => acc.extend(l.participants.iter().map(convert_mbx)), + } + } + return acc; +} + +pub fn convert_mbx(addr: &imf::mailbox::MailboxRef) -> Address<'static> { + Address { + name: NString( + addr.name + .as_ref() + .map(|x| IString::try_from(x.to_string()).unwrap()), + ), + // SMTP at-domain-list (source route) seems obsolete since at least 1991 + // https://www.mhonarc.org/archive/html/ietf-822/1991-06/msg00060.html + adl: NString(None), + mailbox: NString(Some( + IString::try_from(addr.addrspec.local_part.to_string()).unwrap(), + )), + host: NString(Some( + IString::try_from(addr.addrspec.domain.to_string()).unwrap(), + )), + } +} diff --git a/aero-proto/src/imap/index.rs b/aero-proto/src/imap/index.rs new file mode 100644 index 0000000..afe6991 --- /dev/null +++ b/aero-proto/src/imap/index.rs @@ -0,0 +1,211 @@ +use std::num::{NonZeroU32, NonZeroU64}; + +use anyhow::{anyhow, Result}; +use imap_codec::imap_types::sequence::{SeqOrUid, Sequence, SequenceSet}; + +use aero_collections::mail::uidindex::{ImapUid, ModSeq, UidIndex}; +use aero_collections::unique_ident::UniqueIdent; + +pub struct Index<'a> { + pub imap_index: Vec<MailIndex<'a>>, + pub internal: &'a UidIndex, +} +impl<'a> Index<'a> { + pub fn new(internal: &'a UidIndex) -> Result<Self> { + let imap_index = internal + .idx_by_uid + .iter() + .enumerate() + .map(|(i_enum, (&uid, &uuid))| { + let (_, modseq, flags) = internal + .table + .get(&uuid) + .ok_or(anyhow!("mail is missing from index"))?; + let i_int: u32 = (i_enum + 1).try_into()?; + let i: NonZeroU32 = i_int.try_into()?; + + Ok(MailIndex { + i, + uid, + uuid, + modseq: *modseq, + flags, + }) + }) + .collect::<Result<Vec<_>>>()?; + + Ok(Self { + imap_index, + internal, + }) + } + + pub fn last(&'a self) -> Option<&'a MailIndex<'a>> { + self.imap_index.last() + } + + /// Fetch mail descriptors based on a sequence of UID + /// + /// Complexity analysis: + /// - Sort is O(n * log n) where n is the number of uid generated by the sequence + /// - Finding the starting point in the index O(log m) where m is the size of the mailbox + /// While n =< m, it's not clear if the difference is big or not. + /// + /// For now, the algorithm tries to be fast for small values of n, + /// as it is what is expected by clients. + /// + /// So we assume for our implementation that : n << m. + /// It's not true for full mailbox searches for example... + pub fn fetch_on_uid(&'a self, sequence_set: &SequenceSet) -> Vec<&'a MailIndex<'a>> { + if self.imap_index.is_empty() { + return vec![]; + } + let largest = self.last().expect("The mailbox is not empty").uid; + let mut unroll_seq = sequence_set.iter(largest).collect::<Vec<_>>(); + unroll_seq.sort(); + + let start_seq = match unroll_seq.iter().next() { + Some(elem) => elem, + None => return vec![], + }; + + // Quickly jump to the right point in the mailbox vector O(log m) instead + // of iterating one by one O(m). Works only because both unroll_seq & imap_index are sorted per uid. + let mut imap_idx = { + let start_idx = self + .imap_index + .partition_point(|mail_idx| &mail_idx.uid < start_seq); + &self.imap_index[start_idx..] + }; + + let mut acc = vec![]; + for wanted_uid in unroll_seq.iter() { + // Slide the window forward as long as its first element is lower than our wanted uid. + let start_idx = match imap_idx.iter().position(|midx| &midx.uid >= wanted_uid) { + Some(v) => v, + None => break, + }; + imap_idx = &imap_idx[start_idx..]; + + // If the beginning of our new window is the uid we want, we collect it + if &imap_idx[0].uid == wanted_uid { + acc.push(&imap_idx[0]); + } + } + + acc + } + + pub fn fetch_on_id(&'a self, sequence_set: &SequenceSet) -> Result<Vec<&'a MailIndex<'a>>> { + if self.imap_index.is_empty() { + return Ok(vec![]); + } + let largest = NonZeroU32::try_from(self.imap_index.len() as u32)?; + let mut acc = sequence_set + .iter(largest) + .map(|wanted_id| { + self.imap_index + .get((wanted_id.get() as usize) - 1) + .ok_or(anyhow!("Mail not found")) + }) + .collect::<Result<Vec<_>>>()?; + + // Sort the result to be consistent with UID + acc.sort_by(|a, b| a.i.cmp(&b.i)); + + Ok(acc) + } + + pub fn fetch( + self: &'a Index<'a>, + sequence_set: &SequenceSet, + by_uid: bool, + ) -> Result<Vec<&'a MailIndex<'a>>> { + match by_uid { + true => Ok(self.fetch_on_uid(sequence_set)), + _ => self.fetch_on_id(sequence_set), + } + } + + pub fn fetch_changed_since( + self: &'a Index<'a>, + sequence_set: &SequenceSet, + maybe_modseq: Option<NonZeroU64>, + by_uid: bool, + ) -> Result<Vec<&'a MailIndex<'a>>> { + let raw = self.fetch(sequence_set, by_uid)?; + let res = match maybe_modseq { + Some(pit) => raw.into_iter().filter(|midx| midx.modseq > pit).collect(), + None => raw, + }; + + Ok(res) + } + + pub fn fetch_unchanged_since( + self: &'a Index<'a>, + sequence_set: &SequenceSet, + maybe_modseq: Option<NonZeroU64>, + by_uid: bool, + ) -> Result<(Vec<&'a MailIndex<'a>>, Vec<&'a MailIndex<'a>>)> { + let raw = self.fetch(sequence_set, by_uid)?; + let res = match maybe_modseq { + Some(pit) => raw.into_iter().partition(|midx| midx.modseq <= pit), + None => (raw, vec![]), + }; + + Ok(res) + } +} + +#[derive(Clone, Debug)] +pub struct MailIndex<'a> { + pub i: NonZeroU32, + pub uid: ImapUid, + pub uuid: UniqueIdent, + pub modseq: ModSeq, + pub flags: &'a Vec<String>, +} + +impl<'a> MailIndex<'a> { + // The following functions are used to implement the SEARCH command + pub fn is_in_sequence_i(&self, seq: &Sequence) -> bool { + match seq { + Sequence::Single(SeqOrUid::Asterisk) => true, + Sequence::Single(SeqOrUid::Value(target)) => target == &self.i, + Sequence::Range(SeqOrUid::Asterisk, SeqOrUid::Value(x)) + | Sequence::Range(SeqOrUid::Value(x), SeqOrUid::Asterisk) => x <= &self.i, + Sequence::Range(SeqOrUid::Value(x1), SeqOrUid::Value(x2)) => { + if x1 < x2 { + x1 <= &self.i && &self.i <= x2 + } else { + x1 >= &self.i && &self.i >= x2 + } + } + Sequence::Range(SeqOrUid::Asterisk, SeqOrUid::Asterisk) => true, + } + } + + pub fn is_in_sequence_uid(&self, seq: &Sequence) -> bool { + match seq { + Sequence::Single(SeqOrUid::Asterisk) => true, + Sequence::Single(SeqOrUid::Value(target)) => target == &self.uid, + Sequence::Range(SeqOrUid::Asterisk, SeqOrUid::Value(x)) + | Sequence::Range(SeqOrUid::Value(x), SeqOrUid::Asterisk) => x <= &self.uid, + Sequence::Range(SeqOrUid::Value(x1), SeqOrUid::Value(x2)) => { + if x1 < x2 { + x1 <= &self.uid && &self.uid <= x2 + } else { + x1 >= &self.uid && &self.uid >= x2 + } + } + Sequence::Range(SeqOrUid::Asterisk, SeqOrUid::Asterisk) => true, + } + } + + pub fn is_flag_set(&self, flag: &str) -> bool { + self.flags + .iter() + .any(|candidate| candidate.as_str() == flag) + } +} diff --git a/aero-proto/src/imap/mail_view.rs b/aero-proto/src/imap/mail_view.rs new file mode 100644 index 0000000..054014a --- /dev/null +++ b/aero-proto/src/imap/mail_view.rs @@ -0,0 +1,306 @@ +use std::num::NonZeroU32; + +use anyhow::{anyhow, bail, Result}; +use chrono::{naive::NaiveDate, DateTime as ChronoDateTime, Local, Offset, TimeZone, Utc}; + +use imap_codec::imap_types::core::NString; +use imap_codec::imap_types::datetime::DateTime; +use imap_codec::imap_types::fetch::{ + MessageDataItem, MessageDataItemName, Section as FetchSection, +}; +use imap_codec::imap_types::flag::Flag; +use imap_codec::imap_types::response::Data; + +use eml_codec::{ + imf, + part::{composite::Message, AnyPart}, +}; + +use aero_collections::mail::query::QueryResult; + +use crate::imap::attributes::AttributesProxy; +use crate::imap::flags; +use crate::imap::imf_view::ImfView; +use crate::imap::index::MailIndex; +use crate::imap::mime_view; +use crate::imap::response::Body; + +pub struct MailView<'a> { + pub in_idx: &'a MailIndex<'a>, + pub query_result: &'a QueryResult, + pub content: FetchedMail<'a>, +} + +impl<'a> MailView<'a> { + pub fn new(query_result: &'a QueryResult, in_idx: &'a MailIndex<'a>) -> Result<MailView<'a>> { + Ok(Self { + in_idx, + query_result, + content: match query_result { + QueryResult::FullResult { content, .. } => { + let (_, parsed) = + eml_codec::parse_message(&content).or(Err(anyhow!("Invalid mail body")))?; + FetchedMail::full_from_message(parsed) + } + QueryResult::PartialResult { metadata, .. } => { + let (_, parsed) = eml_codec::parse_message(&metadata.headers) + .or(Err(anyhow!("unable to parse email headers")))?; + FetchedMail::partial_from_message(parsed) + } + QueryResult::IndexResult { .. } => FetchedMail::IndexOnly, + }, + }) + } + + pub fn imf(&self) -> Option<ImfView> { + self.content.as_imf().map(ImfView) + } + + pub fn selected_mime(&'a self) -> Option<mime_view::SelectedMime<'a>> { + self.content.as_anypart().ok().map(mime_view::SelectedMime) + } + + pub fn filter(&self, ap: &AttributesProxy) -> Result<(Body<'static>, SeenFlag)> { + let mut seen = SeenFlag::DoNothing; + let res_attrs = ap + .attrs + .iter() + .map(|attr| match attr { + MessageDataItemName::Uid => Ok(self.uid()), + MessageDataItemName::Flags => Ok(self.flags()), + MessageDataItemName::Rfc822Size => self.rfc_822_size(), + MessageDataItemName::Rfc822Header => self.rfc_822_header(), + MessageDataItemName::Rfc822Text => self.rfc_822_text(), + MessageDataItemName::Rfc822 => { + if self.is_not_yet_seen() { + seen = SeenFlag::MustAdd; + } + self.rfc822() + } + MessageDataItemName::Envelope => Ok(self.envelope()), + MessageDataItemName::Body => self.body(), + MessageDataItemName::BodyStructure => self.body_structure(), + MessageDataItemName::BodyExt { + section, + partial, + peek, + } => { + let (body, has_seen) = self.body_ext(section, partial, peek)?; + seen = has_seen; + Ok(body) + } + MessageDataItemName::InternalDate => self.internal_date(), + MessageDataItemName::ModSeq => Ok(self.modseq()), + }) + .collect::<Result<Vec<_>, _>>()?; + + Ok(( + Body::Data(Data::Fetch { + seq: self.in_idx.i, + items: res_attrs.try_into()?, + }), + seen, + )) + } + + pub fn stored_naive_date(&self) -> Result<NaiveDate> { + let mail_meta = self.query_result.metadata().expect("metadata were fetched"); + let mail_ts: i64 = mail_meta.internaldate.try_into()?; + let msg_date: ChronoDateTime<Local> = ChronoDateTime::from_timestamp(mail_ts, 0) + .ok_or(anyhow!("unable to parse timestamp"))? + .with_timezone(&Local); + + Ok(msg_date.date_naive()) + } + + pub fn is_header_contains_pattern(&self, hdr: &[u8], pattern: &[u8]) -> bool { + let mime = match self.selected_mime() { + None => return false, + Some(x) => x, + }; + + let val = match mime.header_value(hdr) { + None => return false, + Some(x) => x, + }; + + val.windows(pattern.len()).any(|win| win == pattern) + } + + // Private function, mainly for filter! + fn uid(&self) -> MessageDataItem<'static> { + MessageDataItem::Uid(self.in_idx.uid.clone()) + } + + fn flags(&self) -> MessageDataItem<'static> { + MessageDataItem::Flags( + self.in_idx + .flags + .iter() + .filter_map(|f| flags::from_str(f)) + .collect(), + ) + } + + fn rfc_822_size(&self) -> Result<MessageDataItem<'static>> { + let sz = self + .query_result + .metadata() + .ok_or(anyhow!("mail metadata are required"))? + .rfc822_size; + Ok(MessageDataItem::Rfc822Size(sz as u32)) + } + + fn rfc_822_header(&self) -> Result<MessageDataItem<'static>> { + let hdrs: NString = self + .query_result + .metadata() + .ok_or(anyhow!("mail metadata are required"))? + .headers + .to_vec() + .try_into()?; + Ok(MessageDataItem::Rfc822Header(hdrs)) + } + + fn rfc_822_text(&self) -> Result<MessageDataItem<'static>> { + let txt: NString = self.content.as_msg()?.raw_body.to_vec().try_into()?; + Ok(MessageDataItem::Rfc822Text(txt)) + } + + fn rfc822(&self) -> Result<MessageDataItem<'static>> { + let full: NString = self.content.as_msg()?.raw_part.to_vec().try_into()?; + Ok(MessageDataItem::Rfc822(full)) + } + + fn envelope(&self) -> MessageDataItem<'static> { + MessageDataItem::Envelope( + self.imf() + .expect("an imf object is derivable from fetchedmail") + .message_envelope(), + ) + } + + fn body(&self) -> Result<MessageDataItem<'static>> { + Ok(MessageDataItem::Body(mime_view::bodystructure( + self.content.as_msg()?.child.as_ref(), + false, + )?)) + } + + fn body_structure(&self) -> Result<MessageDataItem<'static>> { + Ok(MessageDataItem::BodyStructure(mime_view::bodystructure( + self.content.as_msg()?.child.as_ref(), + true, + )?)) + } + + fn is_not_yet_seen(&self) -> bool { + let seen_flag = Flag::Seen.to_string(); + !self.in_idx.flags.iter().any(|x| *x == seen_flag) + } + + /// maps to BODY[<section>]<<partial>> and BODY.PEEK[<section>]<<partial>> + /// peek does not implicitly set the \Seen flag + /// eg. BODY[HEADER.FIELDS (DATE FROM)] + /// eg. BODY[]<0.2048> + fn body_ext( + &self, + section: &Option<FetchSection<'static>>, + partial: &Option<(u32, NonZeroU32)>, + peek: &bool, + ) -> Result<(MessageDataItem<'static>, SeenFlag)> { + // Manage Seen flag + let mut seen = SeenFlag::DoNothing; + if !peek && self.is_not_yet_seen() { + // Add \Seen flag + //self.mailbox.add_flags(uuid, &[seen_flag]).await?; + seen = SeenFlag::MustAdd; + } + + // Process message + let (text, origin) = + match mime_view::body_ext(self.content.as_anypart()?, section, partial)? { + mime_view::BodySection::Full(body) => (body, None), + mime_view::BodySection::Slice { body, origin_octet } => (body, Some(origin_octet)), + }; + + let data: NString = text.to_vec().try_into()?; + + return Ok(( + MessageDataItem::BodyExt { + section: section.as_ref().map(|fs| fs.clone()), + origin, + data, + }, + seen, + )); + } + + fn internal_date(&self) -> Result<MessageDataItem<'static>> { + let dt = Utc + .fix() + .timestamp_opt( + i64::try_from( + self.query_result + .metadata() + .ok_or(anyhow!("mail metadata were not fetched"))? + .internaldate + / 1000, + )?, + 0, + ) + .earliest() + .ok_or(anyhow!("Unable to parse internal date"))?; + Ok(MessageDataItem::InternalDate(DateTime::unvalidated(dt))) + } + + fn modseq(&self) -> MessageDataItem<'static> { + MessageDataItem::ModSeq(self.in_idx.modseq) + } +} + +pub enum SeenFlag { + DoNothing, + MustAdd, +} + +// ------------------- + +pub enum FetchedMail<'a> { + IndexOnly, + Partial(AnyPart<'a>), + Full(AnyPart<'a>), +} +impl<'a> FetchedMail<'a> { + pub fn full_from_message(msg: Message<'a>) -> Self { + Self::Full(AnyPart::Msg(msg)) + } + + pub fn partial_from_message(msg: Message<'a>) -> Self { + Self::Partial(AnyPart::Msg(msg)) + } + + pub fn as_anypart(&self) -> Result<&AnyPart<'a>> { + match self { + FetchedMail::Full(x) => Ok(&x), + FetchedMail::Partial(x) => Ok(&x), + _ => bail!("The full message must be fetched, not only its headers"), + } + } + + pub fn as_msg(&self) -> Result<&Message<'a>> { + match self { + FetchedMail::Full(AnyPart::Msg(x)) => Ok(&x), + FetchedMail::Partial(AnyPart::Msg(x)) => Ok(&x), + _ => bail!("The full message must be fetched, not only its headers AND it must be an AnyPart::Msg."), + } + } + + pub fn as_imf(&self) -> Option<&imf::Imf<'a>> { + match self { + FetchedMail::Full(AnyPart::Msg(x)) => Some(&x.imf), + FetchedMail::Partial(AnyPart::Msg(x)) => Some(&x.imf), + _ => None, + } + } +} diff --git a/aero-proto/src/imap/mailbox_view.rs b/aero-proto/src/imap/mailbox_view.rs new file mode 100644 index 0000000..0b808aa --- /dev/null +++ b/aero-proto/src/imap/mailbox_view.rs @@ -0,0 +1,772 @@ +use std::collections::HashSet; +use std::num::{NonZeroU32, NonZeroU64}; +use std::sync::Arc; + +use anyhow::{anyhow, Error, Result}; + +use futures::stream::{StreamExt, TryStreamExt}; + +use imap_codec::imap_types::core::Charset; +use imap_codec::imap_types::fetch::MessageDataItem; +use imap_codec::imap_types::flag::{Flag, FlagFetch, FlagPerm, StoreResponse, StoreType}; +use imap_codec::imap_types::response::{Code, CodeOther, Data, Status}; +use imap_codec::imap_types::search::SearchKey; +use imap_codec::imap_types::sequence::SequenceSet; + +use aero_collections::mail::mailbox::Mailbox; +use aero_collections::mail::query::QueryScope; +use aero_collections::mail::snapshot::FrozenMailbox; +use aero_collections::mail::uidindex::{ImapUid, ImapUidvalidity, ModSeq}; +use aero_collections::unique_ident::UniqueIdent; + +use crate::imap::attributes::AttributesProxy; +use crate::imap::flags; +use crate::imap::index::Index; +use crate::imap::mail_view::{MailView, SeenFlag}; +use crate::imap::response::Body; +use crate::imap::search; + +const DEFAULT_FLAGS: [Flag; 5] = [ + Flag::Seen, + Flag::Answered, + Flag::Flagged, + Flag::Deleted, + Flag::Draft, +]; + +pub struct UpdateParameters { + pub silence: HashSet<UniqueIdent>, + pub with_modseq: bool, + pub with_uid: bool, +} +impl Default for UpdateParameters { + fn default() -> Self { + Self { + silence: HashSet::new(), + with_modseq: false, + with_uid: false, + } + } +} + +/// A MailboxView is responsible for giving the client the information +/// it needs about a mailbox, such as an initial summary of the mailbox's +/// content and continuous updates indicating when the content +/// of the mailbox has been changed. +/// To do this, it keeps a variable `known_state` that corresponds to +/// what the client knows, and produces IMAP messages to be sent to the +/// client that go along updates to `known_state`. +pub struct MailboxView { + pub internal: FrozenMailbox, + pub is_condstore: bool, +} + +impl MailboxView { + /// Creates a new IMAP view into a mailbox. + pub async fn new(mailbox: Arc<Mailbox>, is_cond: bool) -> Self { + Self { + internal: mailbox.frozen().await, + is_condstore: is_cond, + } + } + + /// Create an updated view, useful to make a diff + /// between what the client knows and new stuff + /// Produces a set of IMAP responses describing the change between + /// what the client knows and what is actually in the mailbox. + /// This does NOT trigger a sync, it bases itself on what is currently + /// loaded in RAM by Bayou. + pub async fn update(&mut self, params: UpdateParameters) -> Result<Vec<Body<'static>>> { + let old_snapshot = self.internal.update().await; + let new_snapshot = &self.internal.snapshot; + + let mut data = Vec::<Body>::new(); + + // Calculate diff between two mailbox states + // See example in IMAP RFC in section on NOOP command: + // we want to produce something like this: + // C: a047 NOOP + // S: * 22 EXPUNGE + // S: * 23 EXISTS + // S: * 14 FETCH (UID 1305 FLAGS (\Seen \Deleted)) + // S: a047 OK Noop completed + // In other words: + // - notify client of expunged mails + // - if new mails arrived, notify client of number of existing mails + // - if flags changed for existing mails, tell client + // (for this last step: if uidvalidity changed, do nothing, + // just notify of new uidvalidity and they will resync) + + // - notify client of expunged mails + let mut n_expunge = 0; + for (i, (_uid, uuid)) in old_snapshot.idx_by_uid.iter().enumerate() { + if !new_snapshot.table.contains_key(uuid) { + data.push(Body::Data(Data::Expunge( + NonZeroU32::try_from((i + 1 - n_expunge) as u32).unwrap(), + ))); + n_expunge += 1; + } + } + + // - if new mails arrived, notify client of number of existing mails + if new_snapshot.table.len() != old_snapshot.table.len() - n_expunge + || new_snapshot.uidvalidity != old_snapshot.uidvalidity + { + data.push(self.exists_status()?); + } + + if new_snapshot.uidvalidity != old_snapshot.uidvalidity { + // TODO: do we want to push less/more info than this? + data.push(self.uidvalidity_status()?); + data.push(self.uidnext_status()?); + } else { + // - if flags changed for existing mails, tell client + for (i, (_uid, uuid)) in new_snapshot.idx_by_uid.iter().enumerate() { + if params.silence.contains(uuid) { + continue; + } + + let old_mail = old_snapshot.table.get(uuid); + let new_mail = new_snapshot.table.get(uuid); + if old_mail.is_some() && old_mail != new_mail { + if let Some((uid, modseq, flags)) = new_mail { + let mut items = vec![MessageDataItem::Flags( + flags.iter().filter_map(|f| flags::from_str(f)).collect(), + )]; + + if params.with_uid { + items.push(MessageDataItem::Uid(*uid)); + } + + if params.with_modseq { + items.push(MessageDataItem::ModSeq(*modseq)); + } + + data.push(Body::Data(Data::Fetch { + seq: NonZeroU32::try_from((i + 1) as u32).unwrap(), + items: items.try_into()?, + })); + } + } + } + } + Ok(data) + } + + /// Generates the necessary IMAP messages so that the client + /// has a satisfactory summary of the current mailbox's state. + /// These are the messages that are sent in response to a SELECT command. + pub fn summary(&self) -> Result<Vec<Body<'static>>> { + let mut data = Vec::<Body>::new(); + data.push(self.exists_status()?); + data.push(self.recent_status()?); + data.extend(self.flags_status()?.into_iter()); + data.push(self.uidvalidity_status()?); + data.push(self.uidnext_status()?); + if self.is_condstore { + data.push(self.highestmodseq_status()?); + } + /*self.unseen_first_status()? + .map(|unseen_status| data.push(unseen_status));*/ + + Ok(data) + } + + pub async fn store<'a>( + &mut self, + sequence_set: &SequenceSet, + kind: &StoreType, + response: &StoreResponse, + flags: &[Flag<'a>], + unchanged_since: Option<NonZeroU64>, + is_uid_store: &bool, + ) -> Result<(Vec<Body<'static>>, Vec<NonZeroU32>)> { + self.internal.sync().await?; + + let flags = flags.iter().map(|x| x.to_string()).collect::<Vec<_>>(); + + let idx = self.index()?; + let (editable, in_conflict) = + idx.fetch_unchanged_since(sequence_set, unchanged_since, *is_uid_store)?; + + for mi in editable.iter() { + match kind { + StoreType::Add => { + self.internal.mailbox.add_flags(mi.uuid, &flags[..]).await?; + } + StoreType::Remove => { + self.internal.mailbox.del_flags(mi.uuid, &flags[..]).await?; + } + StoreType::Replace => { + self.internal.mailbox.set_flags(mi.uuid, &flags[..]).await?; + } + } + } + + let silence = match response { + StoreResponse::Answer => HashSet::new(), + StoreResponse::Silent => editable.iter().map(|midx| midx.uuid).collect(), + }; + + let conflict_id_or_uid = match is_uid_store { + true => in_conflict.into_iter().map(|midx| midx.uid).collect(), + _ => in_conflict.into_iter().map(|midx| midx.i).collect(), + }; + + let summary = self + .update(UpdateParameters { + with_uid: *is_uid_store, + with_modseq: unchanged_since.is_some(), + silence, + }) + .await?; + + Ok((summary, conflict_id_or_uid)) + } + + pub async fn idle_sync(&mut self) -> Result<Vec<Body<'static>>> { + self.internal + .mailbox + .notify() + .await + .upgrade() + .ok_or(anyhow!("test"))? + .notified() + .await; + self.internal.mailbox.opportunistic_sync().await?; + self.update(UpdateParameters::default()).await + } + + pub async fn expunge( + &mut self, + maybe_seq_set: &Option<SequenceSet>, + ) -> Result<Vec<Body<'static>>> { + // Get a recent view to apply our change + self.internal.sync().await?; + let state = self.internal.peek().await; + let idx = Index::new(&state)?; + + // Build a default sequence set for the default case + use imap_codec::imap_types::sequence::{SeqOrUid, Sequence}; + let seq = match maybe_seq_set { + Some(s) => s.clone(), + None => SequenceSet( + vec![Sequence::Range( + SeqOrUid::Value(NonZeroU32::MIN), + SeqOrUid::Asterisk, + )] + .try_into() + .unwrap(), + ), + }; + + let deleted_flag = Flag::Deleted.to_string(); + let msgs = idx + .fetch_on_uid(&seq) + .into_iter() + .filter(|midx| midx.flags.iter().any(|x| *x == deleted_flag)) + .map(|midx| midx.uuid); + + for msg in msgs { + self.internal.mailbox.delete(msg).await?; + } + + self.update(UpdateParameters::default()).await + } + + pub async fn copy( + &self, + sequence_set: &SequenceSet, + to: Arc<Mailbox>, + is_uid_copy: &bool, + ) -> Result<(ImapUidvalidity, Vec<(ImapUid, ImapUid)>)> { + let idx = self.index()?; + let mails = idx.fetch(sequence_set, *is_uid_copy)?; + + let mut new_uuids = vec![]; + for mi in mails.iter() { + new_uuids.push(to.copy_from(&self.internal.mailbox, mi.uuid).await?); + } + + let mut ret = vec![]; + let to_state = to.current_uid_index().await; + for (mi, new_uuid) in mails.iter().zip(new_uuids.iter()) { + let dest_uid = to_state + .table + .get(new_uuid) + .ok_or(anyhow!("copied mail not in destination mailbox"))? + .0; + ret.push((mi.uid, dest_uid)); + } + + Ok((to_state.uidvalidity, ret)) + } + + pub async fn r#move( + &mut self, + sequence_set: &SequenceSet, + to: Arc<Mailbox>, + is_uid_copy: &bool, + ) -> Result<(ImapUidvalidity, Vec<(ImapUid, ImapUid)>, Vec<Body<'static>>)> { + let idx = self.index()?; + let mails = idx.fetch(sequence_set, *is_uid_copy)?; + + for mi in mails.iter() { + to.move_from(&self.internal.mailbox, mi.uuid).await?; + } + + let mut ret = vec![]; + let to_state = to.current_uid_index().await; + for mi in mails.iter() { + let dest_uid = to_state + .table + .get(&mi.uuid) + .ok_or(anyhow!("moved mail not in destination mailbox"))? + .0; + ret.push((mi.uid, dest_uid)); + } + + let update = self + .update(UpdateParameters { + with_uid: *is_uid_copy, + ..UpdateParameters::default() + }) + .await?; + + Ok((to_state.uidvalidity, ret, update)) + } + + /// Looks up state changes in the mailbox and produces a set of IMAP + /// responses describing the new state. + pub async fn fetch<'b>( + &self, + sequence_set: &SequenceSet, + ap: &AttributesProxy, + changed_since: Option<NonZeroU64>, + is_uid_fetch: &bool, + ) -> Result<Vec<Body<'static>>> { + // [1/6] Pre-compute data + // a. what are the uuids of the emails we want? + // b. do we need to fetch the full body? + //let ap = AttributesProxy::new(attributes, *is_uid_fetch); + let query_scope = match ap.need_body() { + true => QueryScope::Full, + _ => QueryScope::Partial, + }; + tracing::debug!("Query scope {:?}", query_scope); + let idx = self.index()?; + let mail_idx_list = idx.fetch_changed_since(sequence_set, changed_since, *is_uid_fetch)?; + + // [2/6] Fetch the emails + let uuids = mail_idx_list + .iter() + .map(|midx| midx.uuid) + .collect::<Vec<_>>(); + + let query = self.internal.query(&uuids, query_scope); + //let query_result = self.internal.query(&uuids, query_scope).fetch().await?; + + let query_stream = query + .fetch() + .zip(futures::stream::iter(mail_idx_list)) + // [3/6] Derive an IMAP-specific view from the results, apply the filters + .map(|(maybe_qr, midx)| match maybe_qr { + Ok(qr) => Ok((MailView::new(&qr, midx)?.filter(&ap)?, midx)), + Err(e) => Err(e), + }) + // [4/6] Apply the IMAP transformation + .then(|maybe_ret| async move { + let ((body, seen), midx) = maybe_ret?; + + // [5/6] Register the \Seen flags + if matches!(seen, SeenFlag::MustAdd) { + let seen_flag = Flag::Seen.to_string(); + self.internal + .mailbox + .add_flags(midx.uuid, &[seen_flag]) + .await?; + } + + Ok::<_, anyhow::Error>(body) + }); + + // [6/6] Build the final result that will be sent to the client. + query_stream.try_collect().await + } + + /// A naive search implementation... + pub async fn search<'a>( + &self, + _charset: &Option<Charset<'a>>, + search_key: &SearchKey<'a>, + uid: bool, + ) -> Result<(Vec<Body<'static>>, bool)> { + // 1. Compute the subset of sequence identifiers we need to fetch + // based on the search query + let crit = search::Criteria(search_key); + let (seq_set, seq_type) = crit.to_sequence_set(); + + // 2. Get the selection + let idx = self.index()?; + let selection = idx.fetch(&seq_set, seq_type.is_uid())?; + + // 3. Filter the selection based on the ID / UID / Flags + let (kept_idx, to_fetch) = crit.filter_on_idx(&selection); + + // 4.a Fetch additional info about the emails + let query_scope = crit.query_scope(); + let uuids = to_fetch.iter().map(|midx| midx.uuid).collect::<Vec<_>>(); + let query = self.internal.query(&uuids, query_scope); + + // 4.b We don't want to keep all data in memory, so we do the computing in a stream + let query_stream = query + .fetch() + .zip(futures::stream::iter(&to_fetch)) + // 5.a Build a mailview with the body, might fail with an error + // 5.b If needed, filter the selection based on the body, but keep the errors + // 6. Drop the query+mailbox, keep only the mail index + // Here we release a lot of memory, this is the most important part ^^ + .filter_map(|(maybe_qr, midx)| { + let r = match maybe_qr { + Ok(qr) => match MailView::new(&qr, midx).map(|mv| crit.is_keep_on_query(&mv)) { + Ok(true) => Some(Ok(*midx)), + Ok(_) => None, + Err(e) => Some(Err(e)), + }, + Err(e) => Some(Err(e)), + }; + futures::future::ready(r) + }); + + // 7. Chain both streams (part resolved from index, part resolved from metadata+body) + let main_stream = futures::stream::iter(kept_idx) + .map(Ok) + .chain(query_stream) + .map_ok(|idx| match uid { + true => (idx.uid, idx.modseq), + _ => (idx.i, idx.modseq), + }); + + // 8. Do the actual computation + let internal_result: Vec<_> = main_stream.try_collect().await?; + let (selection, modseqs): (Vec<_>, Vec<_>) = internal_result.into_iter().unzip(); + + // 9. Aggregate the maximum modseq value + let maybe_modseq = match crit.is_modseq() { + true => modseqs.into_iter().max(), + _ => None, + }; + + // 10. Return the final result + Ok(( + vec![Body::Data(Data::Search(selection, maybe_modseq))], + maybe_modseq.is_some(), + )) + } + + // ---- + /// @FIXME index should be stored for longer than a single request + /// Instead they should be tied to the FrozenMailbox refresh + /// It's not trivial to refactor the code to do that, so we are doing + /// some useless computation for now... + fn index<'a>(&'a self) -> Result<Index<'a>> { + Index::new(&self.internal.snapshot) + } + + /// Produce an OK [UIDVALIDITY _] message corresponding to `known_state` + fn uidvalidity_status(&self) -> Result<Body<'static>> { + let uid_validity = Status::ok( + None, + Some(Code::UidValidity(self.uidvalidity())), + "UIDs valid", + ) + .map_err(Error::msg)?; + Ok(Body::Status(uid_validity)) + } + + pub(crate) fn uidvalidity(&self) -> ImapUidvalidity { + self.internal.snapshot.uidvalidity + } + + /// Produce an OK [UIDNEXT _] message corresponding to `known_state` + fn uidnext_status(&self) -> Result<Body<'static>> { + let next_uid = Status::ok( + None, + Some(Code::UidNext(self.uidnext())), + "Predict next UID", + ) + .map_err(Error::msg)?; + Ok(Body::Status(next_uid)) + } + + pub(crate) fn uidnext(&self) -> ImapUid { + self.internal.snapshot.uidnext + } + + pub(crate) fn highestmodseq_status(&self) -> Result<Body<'static>> { + Ok(Body::Status(Status::ok( + None, + Some(Code::Other(CodeOther::unvalidated( + format!("HIGHESTMODSEQ {}", self.highestmodseq()).into_bytes(), + ))), + "Highest", + )?)) + } + + pub(crate) fn highestmodseq(&self) -> ModSeq { + self.internal.snapshot.highestmodseq + } + + /// Produce an EXISTS message corresponding to the number of mails + /// in `known_state` + fn exists_status(&self) -> Result<Body<'static>> { + Ok(Body::Data(Data::Exists(self.exists()?))) + } + + pub(crate) fn exists(&self) -> Result<u32> { + Ok(u32::try_from(self.internal.snapshot.idx_by_uid.len())?) + } + + /// Produce a RECENT message corresponding to the number of + /// recent mails in `known_state` + fn recent_status(&self) -> Result<Body<'static>> { + Ok(Body::Data(Data::Recent(self.recent()?))) + } + + #[allow(dead_code)] + fn unseen_first_status(&self) -> Result<Option<Body<'static>>> { + Ok(self + .unseen_first()? + .map(|unseen_id| { + Status::ok(None, Some(Code::Unseen(unseen_id)), "First unseen.").map(Body::Status) + }) + .transpose()?) + } + + #[allow(dead_code)] + fn unseen_first(&self) -> Result<Option<NonZeroU32>> { + Ok(self + .internal + .snapshot + .table + .values() + .enumerate() + .find(|(_i, (_imap_uid, _modseq, flags))| !flags.contains(&"\\Seen".to_string())) + .map(|(i, _)| NonZeroU32::try_from(i as u32 + 1)) + .transpose()?) + } + + pub(crate) fn recent(&self) -> Result<u32> { + let recent = self + .internal + .snapshot + .idx_by_flag + .get(&"\\Recent".to_string()) + .map(|os| os.len()) + .unwrap_or(0); + Ok(u32::try_from(recent)?) + } + + /// Produce a FLAGS and a PERMANENTFLAGS message that indicates + /// the flags that are in `known_state` + default flags + fn flags_status(&self) -> Result<Vec<Body<'static>>> { + let mut body = vec![]; + + // 1. Collecting all the possible flags in the mailbox + // 1.a Fetch them from our index + let mut known_flags: Vec<Flag> = self + .internal + .snapshot + .idx_by_flag + .flags() + .filter_map(|f| match flags::from_str(f) { + Some(FlagFetch::Flag(fl)) => Some(fl), + _ => None, + }) + .collect(); + // 1.b Merge it with our default flags list + for f in DEFAULT_FLAGS.iter() { + if !known_flags.contains(f) { + known_flags.push(f.clone()); + } + } + // 1.c Create the IMAP message + body.push(Body::Data(Data::Flags(known_flags.clone()))); + + // 2. Returning flags that are persisted + // 2.a Always advertise our default flags + let mut permanent = DEFAULT_FLAGS + .iter() + .map(|f| FlagPerm::Flag(f.clone())) + .collect::<Vec<_>>(); + // 2.b Say that we support any keyword flag + permanent.push(FlagPerm::Asterisk); + // 2.c Create the IMAP message + let permanent_flags = Status::ok( + None, + Some(Code::PermanentFlags(permanent)), + "Flags permitted", + ) + .map_err(Error::msg)?; + body.push(Body::Status(permanent_flags)); + + // Done! + Ok(body) + } + + pub(crate) fn unseen_count(&self) -> usize { + let total = self.internal.snapshot.table.len(); + let seen = self + .internal + .snapshot + .idx_by_flag + .get(&Flag::Seen.to_string()) + .map(|x| x.len()) + .unwrap_or(0); + total - seen + } +} + +#[cfg(test)] +mod tests { + use super::*; + use imap_codec::encode::Encoder; + use imap_codec::imap_types::core::Vec1; + use imap_codec::imap_types::fetch::Section; + use imap_codec::imap_types::fetch::{MacroOrMessageDataItemNames, MessageDataItemName}; + use imap_codec::imap_types::response::Response; + use imap_codec::ResponseCodec; + use std::fs; + + use aero_collections::mail::mailbox::MailMeta; + use aero_collections::mail::query::QueryResult; + use aero_collections::unique_ident; + use aero_user::cryptoblob; + + use crate::imap::index::MailIndex; + use crate::imap::mime_view; + + #[test] + fn mailview_body_ext() -> Result<()> { + let ap = AttributesProxy::new( + &MacroOrMessageDataItemNames::MessageDataItemNames(vec![ + MessageDataItemName::BodyExt { + section: Some(Section::Header(None)), + partial: None, + peek: false, + }, + ]), + &[], + false, + ); + + let key = cryptoblob::gen_key(); + let meta = MailMeta { + internaldate: 0u64, + headers: vec![], + message_key: key, + rfc822_size: 8usize, + }; + + let index_entry = (NonZeroU32::MIN, NonZeroU64::MIN, vec![]); + let mail_in_idx = MailIndex { + i: NonZeroU32::MIN, + uid: index_entry.0, + modseq: index_entry.1, + uuid: unique_ident::gen_ident(), + flags: &index_entry.2, + }; + let rfc822 = b"Subject: hello\r\nFrom: a@a.a\r\nTo: b@b.b\r\nDate: Thu, 12 Oct 2023 08:45:28 +0000\r\n\r\nhello world"; + let qr = QueryResult::FullResult { + uuid: mail_in_idx.uuid.clone(), + metadata: meta, + content: rfc822.to_vec(), + }; + + let mv = MailView::new(&qr, &mail_in_idx)?; + let (res_body, _seen) = mv.filter(&ap)?; + + let fattr = match res_body { + Body::Data(Data::Fetch { + seq: _seq, + items: attr, + }) => Ok(attr), + _ => Err(anyhow!("Not a fetch body")), + }?; + + assert_eq!(fattr.as_ref().len(), 1); + + let (sec, _orig, _data) = match &fattr.as_ref()[0] { + MessageDataItem::BodyExt { + section, + origin, + data, + } => Ok((section, origin, data)), + _ => Err(anyhow!("not a body ext message attribute")), + }?; + + assert_eq!(sec.as_ref().unwrap(), &Section::Header(None)); + + Ok(()) + } + + /// Future automated test. We use lossy utf8 conversion + lowercase everything, + /// so this test might allow invalid results. But at least it allows us to quickly test a + /// large variety of emails. + /// Keep in mind that special cases must still be tested manually! + #[test] + fn fetch_body() -> Result<()> { + let prefixes = [ + /* *** MY OWN DATASET *** */ + "tests/emails/dxflrs/0001_simple", + "tests/emails/dxflrs/0002_mime", + "tests/emails/dxflrs/0003_mime-in-mime", + "tests/emails/dxflrs/0004_msg-in-msg", + // eml_codec do not support continuation for the moment + //"tests/emails/dxflrs/0005_mail-parser-readme", + "tests/emails/dxflrs/0006_single-mime", + "tests/emails/dxflrs/0007_raw_msg_in_rfc822", + /* *** (STRANGE) RFC *** */ + //"tests/emails/rfc/000", // must return text/enriched, we return text/plain + //"tests/emails/rfc/001", // does not recognize the multipart/external-body, breaks the + // whole parsing + //"tests/emails/rfc/002", // wrong date in email + + //"tests/emails/rfc/003", // dovecot fixes \r\r: the bytes number is wrong + text/enriched + + /* *** THIRD PARTY *** */ + //"tests/emails/thirdparty/000", // dovecot fixes \r\r: the bytes number is wrong + //"tests/emails/thirdparty/001", // same + "tests/emails/thirdparty/002", // same + + /* *** LEGACY *** */ + //"tests/emails/legacy/000", // same issue with \r\r + ]; + + for pref in prefixes.iter() { + println!("{}", pref); + let txt = fs::read(format!("../{}.eml", pref))?; + let oracle = fs::read(format!("../{}.dovecot.body", pref))?; + let message = eml_codec::parse_message(&txt).unwrap().1; + + let test_repr = Response::Data(Data::Fetch { + seq: NonZeroU32::new(1).unwrap(), + items: Vec1::from(MessageDataItem::Body(mime_view::bodystructure( + &message.child, + false, + )?)), + }); + let test_bytes = ResponseCodec::new().encode(&test_repr).dump(); + let test_str = String::from_utf8_lossy(&test_bytes).to_lowercase(); + + let oracle_str = + format!("* 1 FETCH {}\r\n", String::from_utf8_lossy(&oracle)).to_lowercase(); + + println!("aerogramme: {}\n\ndovecot: {}\n\n", test_str, oracle_str); + //println!("\n\n {} \n\n", String::from_utf8_lossy(&resp)); + assert_eq!(test_str, oracle_str); + } + + Ok(()) + } +} diff --git a/aero-proto/src/imap/mime_view.rs b/aero-proto/src/imap/mime_view.rs new file mode 100644 index 0000000..fd0f4b0 --- /dev/null +++ b/aero-proto/src/imap/mime_view.rs @@ -0,0 +1,582 @@ +use std::borrow::Cow; +use std::collections::HashSet; +use std::num::NonZeroU32; + +use anyhow::{anyhow, bail, Result}; + +use imap_codec::imap_types::body::{ + BasicFields, Body as FetchBody, BodyStructure, MultiPartExtensionData, SinglePartExtensionData, + SpecificFields, +}; +use imap_codec::imap_types::core::{AString, IString, NString, Vec1}; +use imap_codec::imap_types::fetch::{Part as FetchPart, Section as FetchSection}; + +use eml_codec::{ + header, mime, mime::r#type::Deductible, part::composite, part::discrete, part::AnyPart, +}; + +use crate::imap::imf_view::ImfView; + +pub enum BodySection<'a> { + Full(Cow<'a, [u8]>), + Slice { + body: Cow<'a, [u8]>, + origin_octet: u32, + }, +} + +/// Logic for BODY[<section>]<<partial>> +/// Works in 3 times: +/// 1. Find the section (RootMime::subset) +/// 2. Apply the extraction logic (SelectedMime::extract), like TEXT, HEADERS, etc. +/// 3. Keep only the given subset provided by partial +/// +/// Example of message sections: +/// +/// ```text +/// HEADER ([RFC-2822] header of the message) +/// TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED +/// 1 TEXT/PLAIN +/// 2 APPLICATION/OCTET-STREAM +/// 3 MESSAGE/RFC822 +/// 3.HEADER ([RFC-2822] header of the message) +/// 3.TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED +/// 3.1 TEXT/PLAIN +/// 3.2 APPLICATION/OCTET-STREAM +/// 4 MULTIPART/MIXED +/// 4.1 IMAGE/GIF +/// 4.1.MIME ([MIME-IMB] header for the IMAGE/GIF) +/// 4.2 MESSAGE/RFC822 +/// 4.2.HEADER ([RFC-2822] header of the message) +/// 4.2.TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED +/// 4.2.1 TEXT/PLAIN +/// 4.2.2 MULTIPART/ALTERNATIVE +/// 4.2.2.1 TEXT/PLAIN +/// 4.2.2.2 TEXT/RICHTEXT +/// ``` +pub fn body_ext<'a>( + part: &'a AnyPart<'a>, + section: &'a Option<FetchSection<'a>>, + partial: &'a Option<(u32, NonZeroU32)>, +) -> Result<BodySection<'a>> { + let root_mime = NodeMime(part); + let (extractor, path) = SubsettedSection::from(section); + let selected_mime = root_mime.subset(path)?; + let extracted_full = selected_mime.extract(&extractor)?; + Ok(extracted_full.to_body_section(partial)) +} + +/// Logic for BODY and BODYSTRUCTURE +/// +/// ```raw +/// b fetch 29878:29879 (BODY) +/// * 29878 FETCH (BODY (("text" "plain" ("charset" "utf-8") NIL NIL "quoted-printable" 3264 82)("text" "html" ("charset" "utf-8") NIL NIL "quoted-printable" 31834 643) "alternative")) +/// * 29879 FETCH (BODY ("text" "html" ("charset" "us-ascii") NIL NIL "7bit" 4107 131)) +/// ^^^^^^^^^^^^^^^^^^^^^^ ^^^ ^^^ ^^^^^^ ^^^^ ^^^ +/// | | | | | | number of lines +/// | | | | | size +/// | | | | content transfer encoding +/// | | | description +/// | | id +/// | parameter list +/// b OK Fetch completed (0.001 + 0.000 secs). +/// ``` +pub fn bodystructure(part: &AnyPart, is_ext: bool) -> Result<BodyStructure<'static>> { + NodeMime(part).structure(is_ext) +} + +/// NodeMime +/// +/// Used for recursive logic on MIME. +/// See SelectedMime for inspection. +struct NodeMime<'a>(&'a AnyPart<'a>); +impl<'a> NodeMime<'a> { + /// A MIME object is a tree of elements. + /// The path indicates which element must be picked. + /// This function returns the picked element as the new view + fn subset(self, path: Option<&'a FetchPart>) -> Result<SelectedMime<'a>> { + match path { + None => Ok(SelectedMime(self.0)), + Some(v) => self.rec_subset(v.0.as_ref()), + } + } + + fn rec_subset(self, path: &'a [NonZeroU32]) -> Result<SelectedMime> { + if path.is_empty() { + Ok(SelectedMime(self.0)) + } else { + match self.0 { + AnyPart::Mult(x) => { + let next = Self(x.children + .get(path[0].get() as usize - 1) + .ok_or(anyhow!("Unable to resolve subpath {:?}, current multipart has only {} elements", path, x.children.len()))?); + next.rec_subset(&path[1..]) + }, + AnyPart::Msg(x) => { + let next = Self(x.child.as_ref()); + next.rec_subset(path) + }, + _ => bail!("You tried to access a subpart on an atomic part (text or binary). Unresolved subpath {:?}", path), + } + } + } + + fn structure(&self, is_ext: bool) -> Result<BodyStructure<'static>> { + match self.0 { + AnyPart::Txt(x) => NodeTxt(self, x).structure(is_ext), + AnyPart::Bin(x) => NodeBin(self, x).structure(is_ext), + AnyPart::Mult(x) => NodeMult(self, x).structure(is_ext), + AnyPart::Msg(x) => NodeMsg(self, x).structure(is_ext), + } + } +} + +//---------------------------------------------------------- + +/// A FetchSection must be handled in 2 times: +/// - First we must extract the MIME part +/// - Then we must process it as desired +/// The given struct mixes both work, so +/// we separate this work here. +enum SubsettedSection<'a> { + Part, + Header, + HeaderFields(&'a Vec1<AString<'a>>), + HeaderFieldsNot(&'a Vec1<AString<'a>>), + Text, + Mime, +} +impl<'a> SubsettedSection<'a> { + fn from(section: &'a Option<FetchSection>) -> (Self, Option<&'a FetchPart>) { + match section { + Some(FetchSection::Text(maybe_part)) => (Self::Text, maybe_part.as_ref()), + Some(FetchSection::Header(maybe_part)) => (Self::Header, maybe_part.as_ref()), + Some(FetchSection::HeaderFields(maybe_part, fields)) => { + (Self::HeaderFields(fields), maybe_part.as_ref()) + } + Some(FetchSection::HeaderFieldsNot(maybe_part, fields)) => { + (Self::HeaderFieldsNot(fields), maybe_part.as_ref()) + } + Some(FetchSection::Mime(part)) => (Self::Mime, Some(part)), + Some(FetchSection::Part(part)) => (Self::Part, Some(part)), + None => (Self::Part, None), + } + } +} + +/// Used for current MIME inspection +/// +/// See NodeMime for recursive logic +pub struct SelectedMime<'a>(pub &'a AnyPart<'a>); +impl<'a> SelectedMime<'a> { + pub fn header_value(&'a self, to_match_ext: &[u8]) -> Option<&'a [u8]> { + let to_match = to_match_ext.to_ascii_lowercase(); + + self.eml_mime() + .kv + .iter() + .filter_map(|field| match field { + header::Field::Good(header::Kv2(k, v)) => Some((k, v)), + _ => None, + }) + .find(|(k, _)| k.to_ascii_lowercase() == to_match) + .map(|(_, v)| v) + .copied() + } + + /// The subsetted fetch section basically tells us the + /// extraction logic to apply on our selected MIME. + /// This function acts as a router for these logic. + fn extract(&self, extractor: &SubsettedSection<'a>) -> Result<ExtractedFull<'a>> { + match extractor { + SubsettedSection::Text => self.text(), + SubsettedSection::Header => self.header(), + SubsettedSection::HeaderFields(fields) => self.header_fields(fields, false), + SubsettedSection::HeaderFieldsNot(fields) => self.header_fields(fields, true), + SubsettedSection::Part => self.part(), + SubsettedSection::Mime => self.mime(), + } + } + + fn mime(&self) -> Result<ExtractedFull<'a>> { + let bytes = match &self.0 { + AnyPart::Txt(p) => p.mime.fields.raw, + AnyPart::Bin(p) => p.mime.fields.raw, + AnyPart::Msg(p) => p.child.mime().raw, + AnyPart::Mult(p) => p.mime.fields.raw, + }; + Ok(ExtractedFull(bytes.into())) + } + + fn part(&self) -> Result<ExtractedFull<'a>> { + let bytes = match &self.0 { + AnyPart::Txt(p) => p.body, + AnyPart::Bin(p) => p.body, + AnyPart::Msg(p) => p.raw_part, + AnyPart::Mult(_) => bail!("Multipart part has no body"), + }; + Ok(ExtractedFull(bytes.to_vec().into())) + } + + fn eml_mime(&self) -> &eml_codec::mime::NaiveMIME<'_> { + match &self.0 { + AnyPart::Msg(msg) => msg.child.mime(), + other => other.mime(), + } + } + + /// The [...] HEADER.FIELDS, and HEADER.FIELDS.NOT part + /// specifiers refer to the [RFC-2822] header of the message or of + /// an encapsulated [MIME-IMT] MESSAGE/RFC822 message. + /// HEADER.FIELDS and HEADER.FIELDS.NOT are followed by a list of + /// field-name (as defined in [RFC-2822]) names, and return a + /// subset of the header. The subset returned by HEADER.FIELDS + /// contains only those header fields with a field-name that + /// matches one of the names in the list; similarly, the subset + /// returned by HEADER.FIELDS.NOT contains only the header fields + /// with a non-matching field-name. The field-matching is + /// case-insensitive but otherwise exact. + fn header_fields( + &self, + fields: &'a Vec1<AString<'a>>, + invert: bool, + ) -> Result<ExtractedFull<'a>> { + // Build a lowercase ascii hashset with the fields to fetch + let index = fields + .as_ref() + .iter() + .map(|x| { + match x { + AString::Atom(a) => a.inner().as_bytes(), + AString::String(IString::Literal(l)) => l.as_ref(), + AString::String(IString::Quoted(q)) => q.inner().as_bytes(), + } + .to_ascii_lowercase() + }) + .collect::<HashSet<_>>(); + + // Extract MIME headers + let mime = self.eml_mime(); + + // Filter our MIME headers based on the field index + // 1. Keep only the correctly formatted headers + // 2. Keep only based on the index presence or absence + // 3. Reduce as a byte vector + let buffer = mime + .kv + .iter() + .filter_map(|field| match field { + header::Field::Good(header::Kv2(k, v)) => Some((k, v)), + _ => None, + }) + .filter(|(k, _)| index.contains(&k.to_ascii_lowercase()) ^ invert) + .fold(vec![], |mut acc, (k, v)| { + acc.extend(*k); + acc.extend(b": "); + acc.extend(*v); + acc.extend(b"\r\n"); + acc + }); + + Ok(ExtractedFull(buffer.into())) + } + + /// The HEADER [...] part specifiers refer to the [RFC-2822] header of the message or of + /// an encapsulated [MIME-IMT] MESSAGE/RFC822 message. + /// ```raw + /// HEADER ([RFC-2822] header of the message) + /// ``` + fn header(&self) -> Result<ExtractedFull<'a>> { + let msg = self + .0 + .as_message() + .ok_or(anyhow!("Selected part must be a message/rfc822"))?; + Ok(ExtractedFull(msg.raw_headers.into())) + } + + /// The TEXT part specifier refers to the text body of the message, omitting the [RFC-2822] header. + fn text(&self) -> Result<ExtractedFull<'a>> { + let msg = self + .0 + .as_message() + .ok_or(anyhow!("Selected part must be a message/rfc822"))?; + Ok(ExtractedFull(msg.raw_body.into())) + } + + // ------------ + + /// Basic field of a MIME part that is + /// common to all parts + fn basic_fields(&self) -> Result<BasicFields<'static>> { + let sz = match self.0 { + AnyPart::Txt(x) => x.body.len(), + AnyPart::Bin(x) => x.body.len(), + AnyPart::Msg(x) => x.raw_part.len(), + AnyPart::Mult(_) => 0, + }; + let m = self.0.mime(); + let parameter_list = m + .ctype + .as_ref() + .map(|x| { + x.params + .iter() + .map(|p| { + ( + IString::try_from(String::from_utf8_lossy(p.name).to_string()), + IString::try_from(p.value.to_string()), + ) + }) + .filter(|(k, v)| k.is_ok() && v.is_ok()) + .map(|(k, v)| (k.unwrap(), v.unwrap())) + .collect() + }) + .unwrap_or(vec![]); + + Ok(BasicFields { + parameter_list, + id: NString( + m.id.as_ref() + .and_then(|ci| IString::try_from(ci.to_string()).ok()), + ), + description: NString( + m.description + .as_ref() + .and_then(|cd| IString::try_from(cd.to_string()).ok()), + ), + content_transfer_encoding: match m.transfer_encoding { + mime::mechanism::Mechanism::_8Bit => unchecked_istring("8bit"), + mime::mechanism::Mechanism::Binary => unchecked_istring("binary"), + mime::mechanism::Mechanism::QuotedPrintable => { + unchecked_istring("quoted-printable") + } + mime::mechanism::Mechanism::Base64 => unchecked_istring("base64"), + _ => unchecked_istring("7bit"), + }, + // @FIXME we can't compute the size of the message currently... + size: u32::try_from(sz)?, + }) + } +} + +// --------------------------- +struct NodeMsg<'a>(&'a NodeMime<'a>, &'a composite::Message<'a>); +impl<'a> NodeMsg<'a> { + fn structure(&self, is_ext: bool) -> Result<BodyStructure<'static>> { + let basic = SelectedMime(self.0 .0).basic_fields()?; + + Ok(BodyStructure::Single { + body: FetchBody { + basic, + specific: SpecificFields::Message { + envelope: Box::new(ImfView(&self.1.imf).message_envelope()), + body_structure: Box::new(NodeMime(&self.1.child).structure(is_ext)?), + number_of_lines: nol(self.1.raw_part), + }, + }, + extension_data: match is_ext { + true => Some(SinglePartExtensionData { + md5: NString(None), + tail: None, + }), + _ => None, + }, + }) + } +} + +#[allow(dead_code)] +struct NodeMult<'a>(&'a NodeMime<'a>, &'a composite::Multipart<'a>); +impl<'a> NodeMult<'a> { + fn structure(&self, is_ext: bool) -> Result<BodyStructure<'static>> { + let itype = &self.1.mime.interpreted_type; + let subtype = IString::try_from(itype.subtype.to_string()) + .unwrap_or(unchecked_istring("alternative")); + + let inner_bodies = self + .1 + .children + .iter() + .filter_map(|inner| NodeMime(&inner).structure(is_ext).ok()) + .collect::<Vec<_>>(); + + Vec1::validate(&inner_bodies)?; + let bodies = Vec1::unvalidated(inner_bodies); + + Ok(BodyStructure::Multi { + bodies, + subtype, + extension_data: match is_ext { + true => Some(MultiPartExtensionData { + parameter_list: vec![( + IString::try_from("boundary").unwrap(), + IString::try_from(self.1.mime.interpreted_type.boundary.to_string())?, + )], + tail: None, + }), + _ => None, + }, + }) + } +} +struct NodeTxt<'a>(&'a NodeMime<'a>, &'a discrete::Text<'a>); +impl<'a> NodeTxt<'a> { + fn structure(&self, is_ext: bool) -> Result<BodyStructure<'static>> { + let mut basic = SelectedMime(self.0 .0).basic_fields()?; + + // Get the interpreted content type, set it + let itype = match &self.1.mime.interpreted_type { + Deductible::Inferred(v) | Deductible::Explicit(v) => v, + }; + let subtype = + IString::try_from(itype.subtype.to_string()).unwrap_or(unchecked_istring("plain")); + + // Add charset to the list of parameters if we know it has been inferred as it will be + // missing from the parsed content. + if let Deductible::Inferred(charset) = &itype.charset { + basic.parameter_list.push(( + unchecked_istring("charset"), + IString::try_from(charset.to_string()).unwrap_or(unchecked_istring("us-ascii")), + )); + } + + Ok(BodyStructure::Single { + body: FetchBody { + basic, + specific: SpecificFields::Text { + subtype, + number_of_lines: nol(self.1.body), + }, + }, + extension_data: match is_ext { + true => Some(SinglePartExtensionData { + md5: NString(None), + tail: None, + }), + _ => None, + }, + }) + } +} + +struct NodeBin<'a>(&'a NodeMime<'a>, &'a discrete::Binary<'a>); +impl<'a> NodeBin<'a> { + fn structure(&self, is_ext: bool) -> Result<BodyStructure<'static>> { + let basic = SelectedMime(self.0 .0).basic_fields()?; + + let default = mime::r#type::NaiveType { + main: &b"application"[..], + sub: &b"octet-stream"[..], + params: vec![], + }; + let ct = self.1.mime.fields.ctype.as_ref().unwrap_or(&default); + + let r#type = IString::try_from(String::from_utf8_lossy(ct.main).to_string()).or(Err( + anyhow!("Unable to build IString from given Content-Type type given"), + ))?; + + let subtype = IString::try_from(String::from_utf8_lossy(ct.sub).to_string()).or(Err( + anyhow!("Unable to build IString from given Content-Type subtype given"), + ))?; + + Ok(BodyStructure::Single { + body: FetchBody { + basic, + specific: SpecificFields::Basic { r#type, subtype }, + }, + extension_data: match is_ext { + true => Some(SinglePartExtensionData { + md5: NString(None), + tail: None, + }), + _ => None, + }, + }) + } +} + +// --------------------------- + +struct ExtractedFull<'a>(Cow<'a, [u8]>); +impl<'a> ExtractedFull<'a> { + /// It is possible to fetch a substring of the designated text. + /// This is done by appending an open angle bracket ("<"), the + /// octet position of the first desired octet, a period, the + /// maximum number of octets desired, and a close angle bracket + /// (">") to the part specifier. If the starting octet is beyond + /// the end of the text, an empty string is returned. + /// + /// Any partial fetch that attempts to read beyond the end of the + /// text is truncated as appropriate. A partial fetch that starts + /// at octet 0 is returned as a partial fetch, even if this + /// truncation happened. + /// + /// Note: This means that BODY[]<0.2048> of a 1500-octet message + /// will return BODY[]<0> with a literal of size 1500, not + /// BODY[]. + /// + /// Note: A substring fetch of a HEADER.FIELDS or + /// HEADER.FIELDS.NOT part specifier is calculated after + /// subsetting the header. + fn to_body_section(self, partial: &'_ Option<(u32, NonZeroU32)>) -> BodySection<'a> { + match partial { + Some((begin, len)) => self.partialize(*begin, *len), + None => BodySection::Full(self.0), + } + } + + fn partialize(self, begin: u32, len: NonZeroU32) -> BodySection<'a> { + // Asked range is starting after the end of the content, + // returning an empty buffer + if begin as usize > self.0.len() { + return BodySection::Slice { + body: Cow::Borrowed(&[][..]), + origin_octet: begin, + }; + } + + // Asked range is ending after the end of the content, + // slice only the beginning of the buffer + if (begin + len.get()) as usize >= self.0.len() { + return BodySection::Slice { + body: match self.0 { + Cow::Borrowed(body) => Cow::Borrowed(&body[begin as usize..]), + Cow::Owned(body) => Cow::Owned(body[begin as usize..].to_vec()), + }, + origin_octet: begin, + }; + } + + // Range is included inside the considered content, + // this is the "happy case" + BodySection::Slice { + body: match self.0 { + Cow::Borrowed(body) => { + Cow::Borrowed(&body[begin as usize..(begin + len.get()) as usize]) + } + Cow::Owned(body) => { + Cow::Owned(body[begin as usize..(begin + len.get()) as usize].to_vec()) + } + }, + origin_octet: begin, + } + } +} + +/// ---- LEGACY + +/// s is set to static to ensure that only compile time values +/// checked by developpers are passed. +fn unchecked_istring(s: &'static str) -> IString { + IString::try_from(s).expect("this value is expected to be a valid imap-codec::IString") +} + +// Number Of Lines +fn nol(input: &[u8]) -> u32 { + input + .iter() + .filter(|x| **x == b'\n') + .count() + .try_into() + .unwrap_or(0) +} diff --git a/aero-proto/src/imap/mod.rs b/aero-proto/src/imap/mod.rs new file mode 100644 index 0000000..6a768b0 --- /dev/null +++ b/aero-proto/src/imap/mod.rs @@ -0,0 +1,336 @@ +mod attributes; +mod capability; +mod command; +mod flags; +mod flow; +mod imf_view; +mod index; +mod mail_view; +mod mailbox_view; +mod mime_view; +mod request; +mod response; +mod search; +mod session; + +use std::net::SocketAddr; + +use anyhow::{anyhow, bail, Result}; +use futures::stream::{FuturesUnordered, StreamExt}; +use imap_codec::imap_types::response::{Code, CommandContinuationRequest, Response, Status}; +use imap_codec::imap_types::{core::Text, response::Greeting}; +use imap_flow::server::{ServerFlow, ServerFlowEvent, ServerFlowOptions}; +use imap_flow::stream::AnyStream; +use rustls_pemfile::{certs, private_key}; +use tokio::net::TcpListener; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio_rustls::TlsAcceptor; + +use aero_user::config::{ImapConfig, ImapUnsecureConfig}; +use aero_user::login::ArcLoginProvider; + +use crate::imap::capability::ServerCapability; +use crate::imap::request::Request; +use crate::imap::response::{Body, ResponseOrIdle}; +use crate::imap::session::Instance; + +/// Server is a thin wrapper to register our Services in BÃ L +pub struct Server { + bind_addr: SocketAddr, + login_provider: ArcLoginProvider, + capabilities: ServerCapability, + tls: Option<TlsAcceptor>, +} + +#[derive(Clone)] +struct ClientContext { + addr: SocketAddr, + login_provider: ArcLoginProvider, + must_exit: watch::Receiver<bool>, + server_capabilities: ServerCapability, +} + +pub fn new(config: ImapConfig, login: ArcLoginProvider) -> Result<Server> { + let loaded_certs = certs(&mut std::io::BufReader::new(std::fs::File::open( + config.certs, + )?)) + .collect::<Result<Vec<_>, _>>()?; + let loaded_key = private_key(&mut std::io::BufReader::new(std::fs::File::open( + config.key, + )?))? + .unwrap(); + + let tls_config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(loaded_certs, loaded_key)?; + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); + + Ok(Server { + bind_addr: config.bind_addr, + login_provider: login, + capabilities: ServerCapability::default(), + tls: Some(acceptor), + }) +} + +pub fn new_unsecure(config: ImapUnsecureConfig, login: ArcLoginProvider) -> Server { + Server { + bind_addr: config.bind_addr, + login_provider: login, + capabilities: ServerCapability::default(), + tls: None, + } +} + +impl Server { + pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> { + let tcp = TcpListener::bind(self.bind_addr).await?; + tracing::info!("IMAP server listening on {:#}", self.bind_addr); + + let mut connections = FuturesUnordered::new(); + + while !*must_exit.borrow() { + let wait_conn_finished = async { + if connections.is_empty() { + futures::future::pending().await + } else { + connections.next().await + } + }; + let (socket, remote_addr) = tokio::select! { + a = tcp.accept() => a?, + _ = wait_conn_finished => continue, + _ = must_exit.changed() => continue, + }; + tracing::info!("IMAP: accepted connection from {}", remote_addr); + let stream = match self.tls.clone() { + Some(acceptor) => { + let stream = match acceptor.accept(socket).await { + Ok(v) => v, + Err(e) => { + tracing::error!(err=?e, "TLS negociation failed"); + continue; + } + }; + AnyStream::new(stream) + } + None => AnyStream::new(socket), + }; + + let client = ClientContext { + addr: remote_addr.clone(), + login_provider: self.login_provider.clone(), + must_exit: must_exit.clone(), + server_capabilities: self.capabilities.clone(), + }; + let conn = tokio::spawn(NetLoop::handler(client, stream)); + connections.push(conn); + } + drop(tcp); + + tracing::info!("IMAP server shutting down, draining remaining connections..."); + while connections.next().await.is_some() {} + + Ok(()) + } +} + +use std::sync::Arc; +use tokio::sync::mpsc::*; +use tokio::sync::Notify; + +const PIPELINABLE_COMMANDS: usize = 64; + +// @FIXME a full refactor of this part of the code will be needed sooner or later +struct NetLoop { + ctx: ClientContext, + server: ServerFlow, + cmd_tx: Sender<Request>, + resp_rx: UnboundedReceiver<ResponseOrIdle>, +} + +impl NetLoop { + async fn handler(ctx: ClientContext, sock: AnyStream) { + let addr = ctx.addr.clone(); + + let mut nl = match Self::new(ctx, sock).await { + Ok(nl) => { + tracing::debug!(addr=?addr, "netloop successfully initialized"); + nl + } + Err(e) => { + tracing::error!(addr=?addr, err=?e, "netloop can not be initialized, closing session"); + return; + } + }; + + match nl.core().await { + Ok(()) => { + tracing::debug!("closing successful netloop core for {:?}", addr); + } + Err(e) => { + tracing::error!("closing errored netloop core for {:?}: {}", addr, e); + } + } + } + + async fn new(ctx: ClientContext, sock: AnyStream) -> Result<Self> { + let mut opts = ServerFlowOptions::default(); + opts.crlf_relaxed = false; + opts.literal_accept_text = Text::unvalidated("OK"); + opts.literal_reject_text = Text::unvalidated("Literal rejected"); + + // Send greeting + let (server, _) = ServerFlow::send_greeting( + sock, + opts, + Greeting::ok( + Some(Code::Capability(ctx.server_capabilities.to_vec())), + "Aerogramme", + ) + .unwrap(), + ) + .await?; + + // Start a mailbox session in background + let (cmd_tx, cmd_rx) = mpsc::channel::<Request>(PIPELINABLE_COMMANDS); + let (resp_tx, resp_rx) = mpsc::unbounded_channel::<ResponseOrIdle>(); + tokio::spawn(Self::session(ctx.clone(), cmd_rx, resp_tx)); + + // Return the object + Ok(NetLoop { + ctx, + server, + cmd_tx, + resp_rx, + }) + } + + /// Coms with the background session + async fn session( + ctx: ClientContext, + mut cmd_rx: Receiver<Request>, + resp_tx: UnboundedSender<ResponseOrIdle>, + ) -> () { + let mut session = Instance::new(ctx.login_provider, ctx.server_capabilities); + loop { + let cmd = match cmd_rx.recv().await { + None => break, + Some(cmd_recv) => cmd_recv, + }; + + tracing::debug!(cmd=?cmd, sock=%ctx.addr, "command"); + let maybe_response = session.request(cmd).await; + tracing::debug!(cmd=?maybe_response, sock=%ctx.addr, "response"); + + match resp_tx.send(maybe_response) { + Err(_) => break, + Ok(_) => (), + }; + } + tracing::info!("runner is quitting"); + } + + async fn core(&mut self) -> Result<()> { + let mut maybe_idle: Option<Arc<Notify>> = None; + loop { + tokio::select! { + // Managing imap_flow stuff + srv_evt = self.server.progress() => match srv_evt? { + ServerFlowEvent::ResponseSent { handle: _handle, response } => { + match response { + Response::Status(Status::Bye(_)) => return Ok(()), + _ => tracing::trace!("sent to {} content {:?}", self.ctx.addr, response), + } + }, + ServerFlowEvent::CommandReceived { command } => { + match self.cmd_tx.try_send(Request::ImapCommand(command)) { + Ok(_) => (), + Err(mpsc::error::TrySendError::Full(_)) => { + self.server.enqueue_status(Status::bye(None, "Too fast").unwrap()); + tracing::error!("client {:?} is sending commands too fast, closing.", self.ctx.addr); + } + _ => { + self.server.enqueue_status(Status::bye(None, "Internal session exited").unwrap()); + tracing::error!("session task exited for {:?}, quitting", self.ctx.addr); + } + } + }, + ServerFlowEvent::IdleCommandReceived { tag } => { + match self.cmd_tx.try_send(Request::IdleStart(tag)) { + Ok(_) => (), + Err(mpsc::error::TrySendError::Full(_)) => { + self.server.enqueue_status(Status::bye(None, "Too fast").unwrap()); + tracing::error!("client {:?} is sending commands too fast, closing.", self.ctx.addr); + } + _ => { + self.server.enqueue_status(Status::bye(None, "Internal session exited").unwrap()); + tracing::error!("session task exited for {:?}, quitting", self.ctx.addr); + } + } + } + ServerFlowEvent::IdleDoneReceived => { + tracing::trace!("client sent DONE and want to stop IDLE"); + maybe_idle.ok_or(anyhow!("Received IDLE done but not idling currently"))?.notify_one(); + maybe_idle = None; + } + flow => { + self.server.enqueue_status(Status::bye(None, "Unsupported server flow event").unwrap()); + tracing::error!("session task exited for {:?} due to unsupported flow {:?}", self.ctx.addr, flow); + } + }, + + // Managing response generated by Aerogramme + maybe_msg = self.resp_rx.recv() => match maybe_msg { + Some(ResponseOrIdle::Response(response)) => { + tracing::trace!("Interactive, server has a response for the client"); + for body_elem in response.body.into_iter() { + let _handle = match body_elem { + Body::Data(d) => self.server.enqueue_data(d), + Body::Status(s) => self.server.enqueue_status(s), + }; + } + self.server.enqueue_status(response.completion); + }, + Some(ResponseOrIdle::IdleAccept(stop)) => { + tracing::trace!("Interactive, server agreed to switch in idle mode"); + let cr = CommandContinuationRequest::basic(None, "Idling")?; + self.server.idle_accept(cr).or(Err(anyhow!("refused continuation for idle accept")))?; + self.cmd_tx.try_send(Request::IdlePoll)?; + if maybe_idle.is_some() { + bail!("Can't start IDLE if already idling"); + } + maybe_idle = Some(stop); + }, + Some(ResponseOrIdle::IdleEvent(elems)) => { + tracing::trace!("server imap session has some change to communicate to the client"); + for body_elem in elems.into_iter() { + let _handle = match body_elem { + Body::Data(d) => self.server.enqueue_data(d), + Body::Status(s) => self.server.enqueue_status(s), + }; + } + self.cmd_tx.try_send(Request::IdlePoll)?; + }, + Some(ResponseOrIdle::IdleReject(response)) => { + tracing::trace!("inform client that session rejected idle"); + self.server + .idle_reject(response.completion) + .or(Err(anyhow!("wrong reject command")))?; + }, + None => { + self.server.enqueue_status(Status::bye(None, "Internal session exited").unwrap()); + tracing::error!("session task exited for {:?}, quitting", self.ctx.addr); + }, + }, + + // When receiving a CTRL+C + _ = self.ctx.must_exit.changed() => { + tracing::trace!("Interactive, CTRL+C, exiting"); + self.server.enqueue_status(Status::bye(None, "Server is being shutdown").unwrap()); + }, + }; + } + } +} diff --git a/aero-proto/src/imap/request.rs b/aero-proto/src/imap/request.rs new file mode 100644 index 0000000..cff18a3 --- /dev/null +++ b/aero-proto/src/imap/request.rs @@ -0,0 +1,9 @@ +use imap_codec::imap_types::command::Command; +use imap_codec::imap_types::core::Tag; + +#[derive(Debug)] +pub enum Request { + ImapCommand(Command<'static>), + IdleStart(Tag<'static>), + IdlePoll, +} diff --git a/aero-proto/src/imap/response.rs b/aero-proto/src/imap/response.rs new file mode 100644 index 0000000..b6a0e98 --- /dev/null +++ b/aero-proto/src/imap/response.rs @@ -0,0 +1,124 @@ +use anyhow::Result; +use imap_codec::imap_types::command::Command; +use imap_codec::imap_types::core::Tag; +use imap_codec::imap_types::response::{Code, Data, Status}; +use std::sync::Arc; +use tokio::sync::Notify; + +#[derive(Debug)] +pub enum Body<'a> { + Data(Data<'a>), + Status(Status<'a>), +} + +pub struct ResponseBuilder<'a> { + tag: Option<Tag<'a>>, + code: Option<Code<'a>>, + text: String, + body: Vec<Body<'a>>, +} + +impl<'a> ResponseBuilder<'a> { + pub fn to_req(mut self, cmd: &Command<'a>) -> Self { + self.tag = Some(cmd.tag.clone()); + self + } + pub fn tag(mut self, tag: Tag<'a>) -> Self { + self.tag = Some(tag); + self + } + + pub fn message(mut self, txt: impl Into<String>) -> Self { + self.text = txt.into(); + self + } + + pub fn code(mut self, code: Code<'a>) -> Self { + self.code = Some(code); + self + } + + pub fn data(mut self, data: Data<'a>) -> Self { + self.body.push(Body::Data(data)); + self + } + + pub fn many_data(mut self, data: Vec<Data<'a>>) -> Self { + for d in data.into_iter() { + self = self.data(d); + } + self + } + + #[allow(dead_code)] + pub fn info(mut self, status: Status<'a>) -> Self { + self.body.push(Body::Status(status)); + self + } + + #[allow(dead_code)] + pub fn many_info(mut self, status: Vec<Status<'a>>) -> Self { + for d in status.into_iter() { + self = self.info(d); + } + self + } + + pub fn set_body(mut self, body: Vec<Body<'a>>) -> Self { + self.body = body; + self + } + + pub fn ok(self) -> Result<Response<'a>> { + Ok(Response { + completion: Status::ok(self.tag, self.code, self.text)?, + body: self.body, + }) + } + + pub fn no(self) -> Result<Response<'a>> { + Ok(Response { + completion: Status::no(self.tag, self.code, self.text)?, + body: self.body, + }) + } + + pub fn bad(self) -> Result<Response<'a>> { + Ok(Response { + completion: Status::bad(self.tag, self.code, self.text)?, + body: self.body, + }) + } +} + +#[derive(Debug)] +pub struct Response<'a> { + pub body: Vec<Body<'a>>, + pub completion: Status<'a>, +} + +impl<'a> Response<'a> { + pub fn build() -> ResponseBuilder<'a> { + ResponseBuilder { + tag: None, + code: None, + text: "".to_string(), + body: vec![], + } + } + + pub fn bye() -> Result<Response<'a>> { + Ok(Response { + completion: Status::bye(None, "bye")?, + body: vec![], + }) + } +} + +#[derive(Debug)] +pub enum ResponseOrIdle { + Response(Response<'static>), + IdleAccept(Arc<Notify>), + IdleReject(Response<'static>), + IdleEvent(Vec<Body<'static>>), +} diff --git a/aero-proto/src/imap/search.rs b/aero-proto/src/imap/search.rs new file mode 100644 index 0000000..3634a3a --- /dev/null +++ b/aero-proto/src/imap/search.rs @@ -0,0 +1,478 @@ +use std::num::{NonZeroU32, NonZeroU64}; + +use imap_codec::imap_types::core::Vec1; +use imap_codec::imap_types::search::{MetadataItemSearch, SearchKey}; +use imap_codec::imap_types::sequence::{SeqOrUid, Sequence, SequenceSet}; + +use aero_collections::mail::query::QueryScope; + +use crate::imap::index::MailIndex; +use crate::imap::mail_view::MailView; + +pub enum SeqType { + Undefined, + NonUid, + Uid, +} +impl SeqType { + pub fn is_uid(&self) -> bool { + matches!(self, Self::Uid) + } +} + +pub struct Criteria<'a>(pub &'a SearchKey<'a>); +impl<'a> Criteria<'a> { + /// Returns a set of email identifiers that is greater or equal + /// to the set of emails to return + pub fn to_sequence_set(&self) -> (SequenceSet, SeqType) { + match self.0 { + SearchKey::All => (sequence_set_all(), SeqType::Undefined), + SearchKey::SequenceSet(seq_set) => (seq_set.clone(), SeqType::NonUid), + SearchKey::Uid(seq_set) => (seq_set.clone(), SeqType::Uid), + SearchKey::Not(_inner) => { + tracing::debug!( + "using NOT in a search request is slow: it selects all identifiers" + ); + (sequence_set_all(), SeqType::Undefined) + } + SearchKey::Or(left, right) => { + tracing::debug!("using OR in a search request is slow: no deduplication is done"); + let (base, base_seqtype) = Self(&left).to_sequence_set(); + let (ext, ext_seqtype) = Self(&right).to_sequence_set(); + + // Check if we have a UID/ID conflict in fetching: now we don't know how to handle them + match (base_seqtype, ext_seqtype) { + (SeqType::Uid, SeqType::NonUid) | (SeqType::NonUid, SeqType::Uid) => { + (sequence_set_all(), SeqType::Undefined) + } + (SeqType::Undefined, x) | (x, _) => { + let mut new_vec = base.0.into_inner(); + new_vec.extend_from_slice(ext.0.as_ref()); + let seq = SequenceSet( + Vec1::try_from(new_vec) + .expect("merging non empty vec lead to non empty vec"), + ); + (seq, x) + } + } + } + SearchKey::And(search_list) => { + tracing::debug!( + "using AND in a search request is slow: no intersection is performed" + ); + // As we perform no intersection, we don't care if we mix uid or id. + // We only keep the smallest range, being it ID or UID, depending of + // which one has the less items. This is an approximation as UID ranges + // can have holes while ID ones can't. + search_list + .as_ref() + .iter() + .map(|crit| Self(&crit).to_sequence_set()) + .min_by(|(x, _), (y, _)| { + let x_size = approx_sequence_set_size(x); + let y_size = approx_sequence_set_size(y); + x_size.cmp(&y_size) + }) + .unwrap_or((sequence_set_all(), SeqType::Undefined)) + } + _ => (sequence_set_all(), SeqType::Undefined), + } + } + + /// Not really clever as we can have cases where we filter out + /// the email before needing to inspect its meta. + /// But for now we are seeking the most basic/stupid algorithm. + pub fn query_scope(&self) -> QueryScope { + use SearchKey::*; + match self.0 { + // Combinators + And(and_list) => and_list + .as_ref() + .iter() + .fold(QueryScope::Index, |prev, sk| { + prev.union(&Criteria(sk).query_scope()) + }), + Not(inner) => Criteria(inner).query_scope(), + Or(left, right) => Criteria(left) + .query_scope() + .union(&Criteria(right).query_scope()), + All => QueryScope::Index, + + // IMF Headers + Bcc(_) | Cc(_) | From(_) | Header(..) | SentBefore(_) | SentOn(_) | SentSince(_) + | Subject(_) | To(_) => QueryScope::Partial, + // Internal Date is also stored in MailMeta + Before(_) | On(_) | Since(_) => QueryScope::Partial, + // Message size is also stored in MailMeta + Larger(_) | Smaller(_) => QueryScope::Partial, + // Text and Body require that we fetch the full content! + Text(_) | Body(_) => QueryScope::Full, + + _ => QueryScope::Index, + } + } + + pub fn is_modseq(&self) -> bool { + use SearchKey::*; + match self.0 { + And(and_list) => and_list + .as_ref() + .iter() + .any(|child| Criteria(child).is_modseq()), + Or(left, right) => Criteria(left).is_modseq() || Criteria(right).is_modseq(), + Not(child) => Criteria(child).is_modseq(), + ModSeq { .. } => true, + _ => false, + } + } + + /// Returns emails that we now for sure we want to keep + /// but also a second list of emails we need to investigate further by + /// fetching some remote data + pub fn filter_on_idx<'b>( + &self, + midx_list: &[&'b MailIndex<'b>], + ) -> (Vec<&'b MailIndex<'b>>, Vec<&'b MailIndex<'b>>) { + let (p1, p2): (Vec<_>, Vec<_>) = midx_list + .iter() + .map(|x| (x, self.is_keep_on_idx(x))) + .filter(|(_midx, decision)| decision.is_keep()) + .map(|(midx, decision)| (*midx, decision)) + .partition(|(_midx, decision)| matches!(decision, PartialDecision::Keep)); + + let to_keep = p1.into_iter().map(|(v, _)| v).collect(); + let to_fetch = p2.into_iter().map(|(v, _)| v).collect(); + (to_keep, to_fetch) + } + + // ---- + + /// Here we are doing a partial filtering: we do not have access + /// to the headers or to the body, so every time we encounter a rule + /// based on them, we need to keep it. + /// + /// @TODO Could be optimized on a per-email basis by also returning the QueryScope + /// when more information is needed! + fn is_keep_on_idx(&self, midx: &MailIndex) -> PartialDecision { + use SearchKey::*; + match self.0 { + // Combinator logic + And(expr_list) => expr_list + .as_ref() + .iter() + .fold(PartialDecision::Keep, |acc, cur| { + acc.and(&Criteria(cur).is_keep_on_idx(midx)) + }), + Or(left, right) => { + let left_decision = Criteria(left).is_keep_on_idx(midx); + let right_decision = Criteria(right).is_keep_on_idx(midx); + left_decision.or(&right_decision) + } + Not(expr) => Criteria(expr).is_keep_on_idx(midx).not(), + All => PartialDecision::Keep, + + // Sequence logic + maybe_seq if is_sk_seq(maybe_seq) => is_keep_seq(maybe_seq, midx).into(), + maybe_flag if is_sk_flag(maybe_flag) => is_keep_flag(maybe_flag, midx).into(), + ModSeq { + metadata_item, + modseq, + } => is_keep_modseq(metadata_item, modseq, midx).into(), + + // All the stuff we can't evaluate yet + Bcc(_) | Cc(_) | From(_) | Header(..) | SentBefore(_) | SentOn(_) | SentSince(_) + | Subject(_) | To(_) | Before(_) | On(_) | Since(_) | Larger(_) | Smaller(_) + | Text(_) | Body(_) => PartialDecision::Postpone, + + unknown => { + tracing::error!("Unknown filter {:?}", unknown); + PartialDecision::Discard + } + } + } + + /// @TODO we re-eveluate twice the same logic. The correct way would be, on each pass, + /// to simplify the searck query, by removing the elements that were already checked. + /// For example if we have AND(OR(seqid(X), body(Y)), body(X)), we can't keep for sure + /// the email, as body(x) might be false. So we need to check it. But as seqid(x) is true, + /// we could simplify the request to just body(x) and truncate the first OR. Today, we are + /// not doing that, and thus we reevaluate everything. + pub fn is_keep_on_query(&self, mail_view: &MailView) -> bool { + use SearchKey::*; + match self.0 { + // Combinator logic + And(expr_list) => expr_list + .as_ref() + .iter() + .all(|cur| Criteria(cur).is_keep_on_query(mail_view)), + Or(left, right) => { + Criteria(left).is_keep_on_query(mail_view) + || Criteria(right).is_keep_on_query(mail_view) + } + Not(expr) => !Criteria(expr).is_keep_on_query(mail_view), + All => true, + + //@FIXME Reevaluating our previous logic... + maybe_seq if is_sk_seq(maybe_seq) => is_keep_seq(maybe_seq, &mail_view.in_idx), + maybe_flag if is_sk_flag(maybe_flag) => is_keep_flag(maybe_flag, &mail_view.in_idx), + ModSeq { + metadata_item, + modseq, + } => is_keep_modseq(metadata_item, modseq, &mail_view.in_idx).into(), + + // Filter on mail meta + Before(search_naive) => match mail_view.stored_naive_date() { + Ok(msg_naive) => &msg_naive < search_naive.as_ref(), + _ => false, + }, + On(search_naive) => match mail_view.stored_naive_date() { + Ok(msg_naive) => &msg_naive == search_naive.as_ref(), + _ => false, + }, + Since(search_naive) => match mail_view.stored_naive_date() { + Ok(msg_naive) => &msg_naive > search_naive.as_ref(), + _ => false, + }, + + // Message size is also stored in MailMeta + Larger(size_ref) => { + mail_view + .query_result + .metadata() + .expect("metadata were fetched") + .rfc822_size + > *size_ref as usize + } + Smaller(size_ref) => { + mail_view + .query_result + .metadata() + .expect("metadata were fetched") + .rfc822_size + < *size_ref as usize + } + + // Filter on well-known headers + Bcc(txt) => mail_view.is_header_contains_pattern(&b"bcc"[..], txt.as_ref()), + Cc(txt) => mail_view.is_header_contains_pattern(&b"cc"[..], txt.as_ref()), + From(txt) => mail_view.is_header_contains_pattern(&b"from"[..], txt.as_ref()), + Subject(txt) => mail_view.is_header_contains_pattern(&b"subject"[..], txt.as_ref()), + To(txt) => mail_view.is_header_contains_pattern(&b"to"[..], txt.as_ref()), + Header(hdr, txt) => mail_view.is_header_contains_pattern(hdr.as_ref(), txt.as_ref()), + + // Filter on Date header + SentBefore(search_naive) => mail_view + .imf() + .map(|imf| imf.naive_date().ok()) + .flatten() + .map(|msg_naive| &msg_naive < search_naive.as_ref()) + .unwrap_or(false), + SentOn(search_naive) => mail_view + .imf() + .map(|imf| imf.naive_date().ok()) + .flatten() + .map(|msg_naive| &msg_naive == search_naive.as_ref()) + .unwrap_or(false), + SentSince(search_naive) => mail_view + .imf() + .map(|imf| imf.naive_date().ok()) + .flatten() + .map(|msg_naive| &msg_naive > search_naive.as_ref()) + .unwrap_or(false), + + // Filter on the full content of the email + Text(txt) => mail_view + .content + .as_msg() + .map(|msg| { + msg.raw_part + .windows(txt.as_ref().len()) + .any(|win| win == txt.as_ref()) + }) + .unwrap_or(false), + Body(txt) => mail_view + .content + .as_msg() + .map(|msg| { + msg.raw_body + .windows(txt.as_ref().len()) + .any(|win| win == txt.as_ref()) + }) + .unwrap_or(false), + + unknown => { + tracing::error!("Unknown filter {:?}", unknown); + false + } + } + } +} + +// ---- Sequence things ---- +fn sequence_set_all() -> SequenceSet { + SequenceSet::from(Sequence::Range( + SeqOrUid::Value(NonZeroU32::MIN), + SeqOrUid::Asterisk, + )) +} + +// This is wrong as sequences can overlap +fn approx_sequence_set_size(seq_set: &SequenceSet) -> u64 { + seq_set.0.as_ref().iter().fold(0u64, |acc, seq| { + acc.saturating_add(approx_sequence_size(seq)) + }) +} + +// This is wrong as sequence UID can have holes, +// as we don't know the number of messages in the mailbox also +// we gave to guess +fn approx_sequence_size(seq: &Sequence) -> u64 { + match seq { + Sequence::Single(_) => 1, + Sequence::Range(SeqOrUid::Asterisk, _) | Sequence::Range(_, SeqOrUid::Asterisk) => u64::MAX, + Sequence::Range(SeqOrUid::Value(x1), SeqOrUid::Value(x2)) => { + let x2 = x2.get() as i64; + let x1 = x1.get() as i64; + (x2 - x1).abs().try_into().unwrap_or(1) + } + } +} + +// --- Partial decision things ---- + +enum PartialDecision { + Keep, + Discard, + Postpone, +} +impl From<bool> for PartialDecision { + fn from(x: bool) -> Self { + match x { + true => PartialDecision::Keep, + _ => PartialDecision::Discard, + } + } +} +impl PartialDecision { + fn not(&self) -> Self { + match self { + Self::Keep => Self::Discard, + Self::Discard => Self::Keep, + Self::Postpone => Self::Postpone, + } + } + + fn or(&self, other: &Self) -> Self { + match (self, other) { + (Self::Keep, _) | (_, Self::Keep) => Self::Keep, + (Self::Postpone, _) | (_, Self::Postpone) => Self::Postpone, + (Self::Discard, Self::Discard) => Self::Discard, + } + } + + fn and(&self, other: &Self) -> Self { + match (self, other) { + (Self::Discard, _) | (_, Self::Discard) => Self::Discard, + (Self::Postpone, _) | (_, Self::Postpone) => Self::Postpone, + (Self::Keep, Self::Keep) => Self::Keep, + } + } + + fn is_keep(&self) -> bool { + !matches!(self, Self::Discard) + } +} + +// ----- Search Key things --- +fn is_sk_flag(sk: &SearchKey) -> bool { + use SearchKey::*; + match sk { + Answered | Deleted | Draft | Flagged | Keyword(..) | New | Old | Recent | Seen + | Unanswered | Undeleted | Undraft | Unflagged | Unkeyword(..) | Unseen => true, + _ => false, + } +} + +fn is_keep_flag(sk: &SearchKey, midx: &MailIndex) -> bool { + use SearchKey::*; + match sk { + Answered => midx.is_flag_set("\\Answered"), + Deleted => midx.is_flag_set("\\Deleted"), + Draft => midx.is_flag_set("\\Draft"), + Flagged => midx.is_flag_set("\\Flagged"), + Keyword(kw) => midx.is_flag_set(kw.inner()), + New => { + let is_recent = midx.is_flag_set("\\Recent"); + let is_seen = midx.is_flag_set("\\Seen"); + is_recent && !is_seen + } + Old => { + let is_recent = midx.is_flag_set("\\Recent"); + !is_recent + } + Recent => midx.is_flag_set("\\Recent"), + Seen => midx.is_flag_set("\\Seen"), + Unanswered => { + let is_answered = midx.is_flag_set("\\Recent"); + !is_answered + } + Undeleted => { + let is_deleted = midx.is_flag_set("\\Deleted"); + !is_deleted + } + Undraft => { + let is_draft = midx.is_flag_set("\\Draft"); + !is_draft + } + Unflagged => { + let is_flagged = midx.is_flag_set("\\Flagged"); + !is_flagged + } + Unkeyword(kw) => { + let is_keyword_set = midx.is_flag_set(kw.inner()); + !is_keyword_set + } + Unseen => { + let is_seen = midx.is_flag_set("\\Seen"); + !is_seen + } + + // Not flag logic + _ => unreachable!(), + } +} + +fn is_sk_seq(sk: &SearchKey) -> bool { + use SearchKey::*; + match sk { + SequenceSet(..) | Uid(..) => true, + _ => false, + } +} +fn is_keep_seq(sk: &SearchKey, midx: &MailIndex) -> bool { + use SearchKey::*; + match sk { + SequenceSet(seq_set) => seq_set + .0 + .as_ref() + .iter() + .any(|seq| midx.is_in_sequence_i(seq)), + Uid(seq_set) => seq_set + .0 + .as_ref() + .iter() + .any(|seq| midx.is_in_sequence_uid(seq)), + _ => unreachable!(), + } +} + +fn is_keep_modseq( + filter: &Option<MetadataItemSearch>, + modseq: &NonZeroU64, + midx: &MailIndex, +) -> bool { + if filter.is_some() { + tracing::warn!(filter=?filter, "Ignoring search metadata filter as it's not supported yet"); + } + modseq <= &midx.modseq +} diff --git a/aero-proto/src/imap/session.rs b/aero-proto/src/imap/session.rs new file mode 100644 index 0000000..92b5eb6 --- /dev/null +++ b/aero-proto/src/imap/session.rs @@ -0,0 +1,175 @@ +use anyhow::{anyhow, bail, Context, Result}; +use imap_codec::imap_types::{command::Command, core::Tag}; + +use aero_user::login::ArcLoginProvider; + +use crate::imap::capability::{ClientCapability, ServerCapability}; +use crate::imap::command::{anonymous, authenticated, selected}; +use crate::imap::flow; +use crate::imap::request::Request; +use crate::imap::response::{Response, ResponseOrIdle}; + +//----- +pub struct Instance { + pub login_provider: ArcLoginProvider, + pub server_capabilities: ServerCapability, + pub client_capabilities: ClientCapability, + pub state: flow::State, +} +impl Instance { + pub fn new(login_provider: ArcLoginProvider, cap: ServerCapability) -> Self { + let client_cap = ClientCapability::new(&cap); + Self { + login_provider, + state: flow::State::NotAuthenticated, + server_capabilities: cap, + client_capabilities: client_cap, + } + } + + pub async fn request(&mut self, req: Request) -> ResponseOrIdle { + match req { + Request::IdleStart(tag) => self.idle_init(tag), + Request::IdlePoll => self.idle_poll().await, + Request::ImapCommand(cmd) => self.command(cmd).await, + } + } + + pub fn idle_init(&mut self, tag: Tag<'static>) -> ResponseOrIdle { + // Build transition + //@FIXME the notifier should be hidden inside the state and thus not part of the transition! + let transition = flow::Transition::Idle(tag.clone(), tokio::sync::Notify::new()); + + // Try to apply the transition and get the stop notifier + let maybe_stop = self + .state + .apply(transition) + .context("IDLE transition failed") + .and_then(|_| { + self.state + .notify() + .ok_or(anyhow!("IDLE state has no Notify object")) + }); + + // Build an appropriate response + match maybe_stop { + Ok(stop) => ResponseOrIdle::IdleAccept(stop), + Err(e) => { + tracing::error!(err=?e, "unable to init idle due to a transition error"); + //ResponseOrIdle::IdleReject(tag) + let no = Response::build() + .tag(tag) + .message( + "Internal error, processing command triggered an illegal IMAP state transition", + ) + .no() + .unwrap(); + ResponseOrIdle::IdleReject(no) + } + } + } + + pub async fn idle_poll(&mut self) -> ResponseOrIdle { + match self.idle_poll_happy().await { + Ok(r) => r, + Err(e) => { + tracing::error!(err=?e, "something bad happened in idle"); + ResponseOrIdle::Response(Response::bye().unwrap()) + } + } + } + + pub async fn idle_poll_happy(&mut self) -> Result<ResponseOrIdle> { + let (mbx, tag, stop) = match &mut self.state { + flow::State::Idle(_, ref mut mbx, _, tag, stop) => (mbx, tag.clone(), stop.clone()), + _ => bail!("Invalid session state, can't idle"), + }; + + tokio::select! { + _ = stop.notified() => { + self.state.apply(flow::Transition::UnIdle)?; + return Ok(ResponseOrIdle::Response(Response::build() + .tag(tag.clone()) + .message("IDLE completed") + .ok()?)) + }, + change = mbx.idle_sync() => { + tracing::debug!("idle event"); + return Ok(ResponseOrIdle::IdleEvent(change?)); + } + } + } + + pub async fn command(&mut self, cmd: Command<'static>) -> ResponseOrIdle { + // Command behavior is modulated by the state. + // To prevent state error, we handle the same command in separate code paths. + let (resp, tr) = match &mut self.state { + flow::State::NotAuthenticated => { + let ctx = anonymous::AnonymousContext { + req: &cmd, + login_provider: &self.login_provider, + server_capabilities: &self.server_capabilities, + }; + anonymous::dispatch(ctx).await + } + flow::State::Authenticated(ref user) => { + let ctx = authenticated::AuthenticatedContext { + req: &cmd, + server_capabilities: &self.server_capabilities, + client_capabilities: &mut self.client_capabilities, + user, + }; + authenticated::dispatch(ctx).await + } + flow::State::Selected(ref user, ref mut mailbox, ref perm) => { + let ctx = selected::SelectedContext { + req: &cmd, + server_capabilities: &self.server_capabilities, + client_capabilities: &mut self.client_capabilities, + user, + mailbox, + perm, + }; + selected::dispatch(ctx).await + } + flow::State::Idle(..) => Err(anyhow!("can not receive command while idling")), + flow::State::Logout => Response::build() + .tag(cmd.tag.clone()) + .message("No commands are allowed in the LOGOUT state.") + .bad() + .map(|r| (r, flow::Transition::None)), + } + .unwrap_or_else(|err| { + tracing::error!("Command error {:?} occured while processing {:?}", err, cmd); + ( + Response::build() + .to_req(&cmd) + .message("Internal error while processing command") + .bad() + .unwrap(), + flow::Transition::None, + ) + }); + + if let Err(e) = self.state.apply(tr) { + tracing::error!( + "Transition error {:?} occured while processing on command {:?}", + e, + cmd + ); + return ResponseOrIdle::Response(Response::build() + .to_req(&cmd) + .message( + "Internal error, processing command triggered an illegal IMAP state transition", + ) + .bad() + .unwrap()); + } + ResponseOrIdle::Response(resp) + + /*match &self.state { + flow::State::Idle(_, _, _, _, n) => ResponseOrIdle::StartIdle(n.clone()), + _ => ResponseOrIdle::Response(resp), + }*/ + } +} diff --git a/aero-proto/src/lib.rs b/aero-proto/src/lib.rs new file mode 100644 index 0000000..d5154cd --- /dev/null +++ b/aero-proto/src/lib.rs @@ -0,0 +1,6 @@ +#![feature(async_closure)] + +pub mod dav; +pub mod imap; +pub mod lmtp; +pub mod sasl; diff --git a/aero-proto/src/lmtp.rs b/aero-proto/src/lmtp.rs new file mode 100644 index 0000000..a82a783 --- /dev/null +++ b/aero-proto/src/lmtp.rs @@ -0,0 +1,219 @@ +use std::net::SocketAddr; +use std::{pin::Pin, sync::Arc}; + +use anyhow::Result; +use async_trait::async_trait; +use duplexify::Duplex; +use futures::{io, AsyncRead, AsyncReadExt, AsyncWrite}; +use futures::{ + stream, + stream::{FuturesOrdered, FuturesUnordered}, + StreamExt, +}; +use smtp_message::{DataUnescaper, Email, EscapedDataReader, Reply, ReplyCode}; +use smtp_server::{reply, Config, ConnectionMetadata, Decision, MailMetadata}; +use tokio::net::TcpListener; +use tokio::select; +use tokio::sync::watch; +use tokio_util::compat::*; + +use aero_collections::mail::incoming::EncryptedMessage; +use aero_user::config::*; +use aero_user::login::*; + +pub struct LmtpServer { + bind_addr: SocketAddr, + hostname: String, + login_provider: Arc<dyn LoginProvider + Send + Sync>, +} + +impl LmtpServer { + pub fn new( + config: LmtpConfig, + login_provider: Arc<dyn LoginProvider + Send + Sync>, + ) -> Arc<Self> { + Arc::new(Self { + bind_addr: config.bind_addr, + hostname: config.hostname, + login_provider, + }) + } + + pub async fn run(self: &Arc<Self>, mut must_exit: watch::Receiver<bool>) -> Result<()> { + let tcp = TcpListener::bind(self.bind_addr).await?; + tracing::info!("LMTP server listening on {:#}", self.bind_addr); + + let mut connections = FuturesUnordered::new(); + + while !*must_exit.borrow() { + let wait_conn_finished = async { + if connections.is_empty() { + futures::future::pending().await + } else { + connections.next().await + } + }; + let (socket, remote_addr) = select! { + a = tcp.accept() => a?, + _ = wait_conn_finished => continue, + _ = must_exit.changed() => continue, + }; + tracing::info!("LMTP: accepted connection from {}", remote_addr); + + let conn = tokio::spawn(smtp_server::interact( + socket.compat(), + smtp_server::IsAlreadyTls::No, + (), + self.clone(), + )); + + connections.push(conn); + } + drop(tcp); + + tracing::info!("LMTP server shutting down, draining remaining connections..."); + while connections.next().await.is_some() {} + + Ok(()) + } +} + +// ---- + +pub struct Message { + to: Vec<PublicCredentials>, +} + +#[async_trait] +impl Config for LmtpServer { + type Protocol = smtp_server::protocol::Lmtp; + + type ConnectionUserMeta = (); + type MailUserMeta = Message; + + fn hostname(&self, _conn_meta: &ConnectionMetadata<()>) -> &str { + &self.hostname + } + + async fn new_mail(&self, _conn_meta: &mut ConnectionMetadata<()>) -> Message { + Message { to: vec![] } + } + + async fn tls_accept<IO>( + &self, + _io: IO, + _conn_meta: &mut ConnectionMetadata<()>, + ) -> io::Result<Duplex<Pin<Box<dyn Send + AsyncRead>>, Pin<Box<dyn Send + AsyncWrite>>>> + where + IO: Send + AsyncRead + AsyncWrite, + { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "TLS not implemented for LMTP server", + )) + } + + async fn filter_from( + &self, + from: Option<Email>, + _meta: &mut MailMetadata<Message>, + _conn_meta: &mut ConnectionMetadata<()>, + ) -> Decision<Option<Email>> { + Decision::Accept { + reply: reply::okay_from().convert(), + res: from, + } + } + + async fn filter_to( + &self, + to: Email, + meta: &mut MailMetadata<Message>, + _conn_meta: &mut ConnectionMetadata<()>, + ) -> Decision<Email> { + let to_str = match to.hostname.as_ref() { + Some(h) => format!("{}@{}", to.localpart, h), + None => to.localpart.to_string(), + }; + match self.login_provider.public_login(&to_str).await { + Ok(creds) => { + meta.user.to.push(creds); + Decision::Accept { + reply: reply::okay_to().convert(), + res: to, + } + } + Err(e) => Decision::Reject { + reply: Reply { + code: ReplyCode::POLICY_REASON, + ecode: None, + text: vec![smtp_message::MaybeUtf8::Utf8(e.to_string())], + }, + }, + } + } + + async fn handle_mail<'resp, R>( + &'resp self, + reader: &mut EscapedDataReader<'_, R>, + meta: MailMetadata<Message>, + _conn_meta: &'resp mut ConnectionMetadata<()>, + ) -> Pin<Box<dyn futures::Stream<Item = Decision<()>> + Send + 'resp>> + where + R: Send + Unpin + AsyncRead, + { + let err_response_stream = |meta: MailMetadata<Message>, msg: String| { + Box::pin( + stream::iter(meta.user.to.into_iter()).map(move |_| Decision::Reject { + reply: Reply { + code: ReplyCode::POLICY_REASON, + ecode: None, + text: vec![smtp_message::MaybeUtf8::Utf8(msg.clone())], + }, + }), + ) + }; + + let mut text = Vec::new(); + if let Err(e) = reader.read_to_end(&mut text).await { + return err_response_stream(meta, format!("io error: {}", e)); + } + reader.complete(); + let raw_size = text.len(); + + // Unescape email, shrink it also to remove last dot + let unesc_res = DataUnescaper::new(true).unescape(&mut text); + text.truncate(unesc_res.written); + tracing::debug!(prev_sz = raw_size, new_sz = text.len(), "unescaped"); + + let encrypted_message = match EncryptedMessage::new(text) { + Ok(x) => Arc::new(x), + Err(e) => return err_response_stream(meta, e.to_string()), + }; + + Box::pin( + meta.user + .to + .into_iter() + .map(move |creds| { + let encrypted_message = encrypted_message.clone(); + async move { + match encrypted_message.deliver_to(creds).await { + Ok(()) => Decision::Accept { + reply: reply::okay_mail().convert(), + res: (), + }, + Err(e) => Decision::Reject { + reply: Reply { + code: ReplyCode::POLICY_REASON, + ecode: None, + text: vec![smtp_message::MaybeUtf8::Utf8(e.to_string())], + }, + }, + } + } + }) + .collect::<FuturesOrdered<_>>(), + ) + } +} diff --git a/aero-proto/src/sasl.rs b/aero-proto/src/sasl.rs new file mode 100644 index 0000000..48c0815 --- /dev/null +++ b/aero-proto/src/sasl.rs @@ -0,0 +1,142 @@ +use std::net::SocketAddr; + +use anyhow::{anyhow, bail, Result}; +use futures::stream::{FuturesUnordered, StreamExt}; +use tokio::io::BufStream; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; +use tokio_util::bytes::BytesMut; + +use aero_sasl::{decode::client_command, encode::Encode, flow::State}; +use aero_user::config::AuthConfig; +use aero_user::login::ArcLoginProvider; + +pub struct AuthServer { + login_provider: ArcLoginProvider, + bind_addr: SocketAddr, +} + +impl AuthServer { + pub fn new(config: AuthConfig, login_provider: ArcLoginProvider) -> Self { + Self { + bind_addr: config.bind_addr, + login_provider, + } + } + + pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> { + let tcp = TcpListener::bind(self.bind_addr).await?; + tracing::info!( + "SASL Authentication Protocol listening on {:#}", + self.bind_addr + ); + + let mut connections = FuturesUnordered::new(); + + while !*must_exit.borrow() { + let wait_conn_finished = async { + if connections.is_empty() { + futures::future::pending().await + } else { + connections.next().await + } + }; + + let (socket, remote_addr) = tokio::select! { + a = tcp.accept() => a?, + _ = wait_conn_finished => continue, + _ = must_exit.changed() => continue, + }; + + tracing::info!("AUTH: accepted connection from {}", remote_addr); + let conn = tokio::spawn( + NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error(), + ); + + connections.push(conn); + } + drop(tcp); + + tracing::info!("AUTH server shutting down, draining remaining connections..."); + while connections.next().await.is_some() {} + + Ok(()) + } +} + +struct NetLoop { + login: ArcLoginProvider, + stream: BufStream<TcpStream>, + stop: watch::Receiver<bool>, + state: State, + read_buf: Vec<u8>, + write_buf: BytesMut, +} + +impl NetLoop { + fn new(stream: TcpStream, login: ArcLoginProvider, stop: watch::Receiver<bool>) -> Self { + Self { + login, + stream: BufStream::new(stream), + state: State::Init, + stop, + read_buf: Vec::new(), + write_buf: BytesMut::new(), + } + } + + async fn run_error(self) { + match self.run().await { + Ok(()) => tracing::info!("Auth session succeeded"), + Err(e) => tracing::error!(err=?e, "Auth session failed"), + } + } + + async fn run(mut self) -> Result<()> { + loop { + tokio::select! { + read_res = self.stream.read_until(b'\n', &mut self.read_buf) => { + // Detect EOF / socket close + let bread = read_res?; + if bread == 0 { + tracing::info!("Reading buffer empty, connection has been closed. Exiting AUTH session."); + return Ok(()) + } + + // Parse command + let (_, cmd) = client_command(&self.read_buf).map_err(|_| anyhow!("Unable to parse command"))?; + tracing::trace!(cmd=?cmd, "Received command"); + + // Make some progress in our local state + let login = async |user: String, pass: String| self.login.login(user.as_str(), pass.as_str()).await.is_ok(); + self.state.progress(cmd, login).await; + if matches!(self.state, State::Error) { + bail!("Internal state is in error, previous logs explain what went wrong"); + } + + // Build response + let srv_cmds = self.state.response(); + srv_cmds.iter().try_for_each(|r| { + tracing::trace!(cmd=?r, "Sent command"); + r.encode(&mut self.write_buf) + })?; + + // Send responses if at least one command response has been generated + if !srv_cmds.is_empty() { + self.stream.write_all(&self.write_buf).await?; + self.stream.flush().await?; + } + + // Reset buffers + self.read_buf.clear(); + self.write_buf.clear(); + }, + _ = self.stop.changed() => { + tracing::debug!("Server is stopping, quitting this runner"); + return Ok(()) + } + } + } + } +} |