aboutsummaryrefslogblamecommitdiff
path: root/src/lmtp.rs
blob: a0dafa561fb95a53b672835d4666e0c2e4c229f1 (plain) (tree)





















                                                                                       
                               


































































                                                                                            
                                                

































































                                                                                            
                                                           
















































































                                                                                       
                                                                  










                                                            
use std::collections::HashMap;
use std::net::SocketAddr;
use std::{pin::Pin, sync::Arc};

use anyhow::{bail, Result};
use async_trait::async_trait;
use duplexify::Duplex;
use futures::{io, AsyncRead, AsyncReadExt, AsyncWrite};
use futures::{stream, stream::FuturesUnordered, StreamExt};
use log::*;
use rusoto_s3::{PutObjectRequest, S3Client, S3};
use tokio::net::{TcpListener, TcpStream};
use tokio::select;
use tokio::sync::watch;
use tokio_util::compat::*;

use smtp_message::{Email, EscapedDataReader, Reply, ReplyCode};
use smtp_server::{reply, Config, ConnectionMetadata, Decision, MailMetadata, Protocol};

use crate::config::*;
use crate::cryptoblob::*;
use crate::login::*;
use crate::mail::mail_ident::*;

/// LMTP listener that encrypts each incoming message and delivers it to the
/// storage of every accepted recipient.
pub struct LmtpServer {
    /// Socket address the TCP listener is bound to.
    bind_addr: SocketAddr,
    /// Hostname announced to connecting LMTP clients.
    hostname: String,
    /// Resolves `RCPT TO` addresses into the recipients' public credentials.
    login_provider: Arc<dyn LoginProvider + Send + Sync>,
}

impl LmtpServer {
    pub fn new(
        config: LmtpConfig,
        login_provider: Arc<dyn LoginProvider + Send + Sync>,
    ) -> Arc<Self> {
        Arc::new(Self {
            bind_addr: config.bind_addr,
            hostname: config.hostname,
            login_provider,
        })
    }

    pub async fn run(self: &Arc<Self>, mut must_exit: watch::Receiver<bool>) -> Result<()> {
        let tcp = TcpListener::bind(self.bind_addr).await?;
        let mut connections = FuturesUnordered::new();

        while !*must_exit.borrow() {
            let wait_conn_finished = async {
                if connections.is_empty() {
                    futures::future::pending().await
                } else {
                    connections.next().await
                }
            };
            let (socket, remote_addr) = select! {
                a = tcp.accept() => a?,
                _ = wait_conn_finished => continue,
                _ = must_exit.changed() => continue,
            };

            let conn = tokio::spawn(smtp_server::interact(
                socket.compat(),
                smtp_server::IsAlreadyTls::No,
                Conn { remote_addr },
                self.clone(),
            ));

            connections.push(conn);
        }
        drop(tcp);

        info!("LMTP server shutting down, draining remaining connections...");
        while connections.next().await.is_some() {}

        Ok(())
    }
}

// ----

/// Per-connection state attached to each LMTP session.
pub struct Conn {
    /// Address of the connected client.
    // Nothing in this module reads the field yet, which would raise a `dead_code`
    // warning; it is kept so the session metadata carries the peer address.
    #[allow(dead_code)]
    remote_addr: SocketAddr,
}

/// Per-transaction state of one LMTP mail.
pub struct Message {
    /// Public credentials of each recipient accepted so far, in acceptance order.
    to: Vec<PublicCredentials>,
}

#[async_trait]
impl Config for LmtpServer {
    type Protocol = smtp_server::protocol::Lmtp;

    type ConnectionUserMeta = Conn;
    type MailUserMeta = Message;

    fn hostname(&self, _conn_meta: &ConnectionMetadata<Conn>) -> &str {
        &self.hostname
    }

    async fn new_mail(&self, _conn_meta: &mut ConnectionMetadata<Conn>) -> Message {
        Message { to: vec![] }
    }

    async fn tls_accept<IO>(
        &self,
        _io: IO,
        _conn_meta: &mut ConnectionMetadata<Conn>,
    ) -> io::Result<Duplex<Pin<Box<dyn Send + AsyncRead>>, Pin<Box<dyn Send + AsyncWrite>>>>
    where
        IO: Send + AsyncRead + AsyncWrite,
    {
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "TLS not implemented for LMTP server",
        ))
    }

    async fn filter_from(
        &self,
        from: Option<Email>,
        meta: &mut MailMetadata<Message>,
        _conn_meta: &mut ConnectionMetadata<Conn>,
    ) -> Decision<Option<Email>> {
        Decision::Accept {
            reply: reply::okay_from().convert(),
            res: from,
        }
    }

    async fn filter_to(
        &self,
        to: Email,
        meta: &mut MailMetadata<Message>,
        _conn_meta: &mut ConnectionMetadata<Conn>,
    ) -> Decision<Email> {
        let to_str = match to.hostname.as_ref() {
            Some(h) => format!("{}@{}", to.localpart, h),
            None => to.localpart.to_string(),
        };
        match self.login_provider.public_login(&to_str).await {
            Ok(creds) => {
                meta.user.to.push(creds);
                Decision::Accept {
                    reply: reply::okay_to().convert(),
                    res: to,
                }
            }
            Err(e) => Decision::Reject {
                reply: Reply {
                    code: ReplyCode::POLICY_REASON,
                    ecode: None,
                    text: vec![smtp_message::MaybeUtf8::Utf8(e.to_string())],
                },
            },
        }
    }

    async fn handle_mail<'a, 'slife0, 'slife1, 'stream, R>(
        &'slife0 self,
        reader: &mut EscapedDataReader<'a, R>,
        meta: MailMetadata<Message>,
        conn_meta: &'slife1 mut ConnectionMetadata<Conn>,
    ) -> Pin<Box<dyn futures::Stream<Item = Decision<()>> + Send + 'stream>>
    where
        R: Send + Unpin + AsyncRead,
        'slife0: 'stream,
        'slife1: 'stream,
        Self: 'stream,
    {
        let err_response_stream = |meta: MailMetadata<Message>, msg: String| {
            Box::pin(
                stream::iter(meta.user.to.into_iter()).map(move |_| Decision::Reject {
                    reply: Reply {
                        code: ReplyCode::POLICY_REASON,
                        ecode: None,
                        text: vec![smtp_message::MaybeUtf8::Utf8(msg.clone())],
                    },
                }),
            )
        };

        let mut text = Vec::new();
        if reader.read_to_end(&mut text).await.is_err() {
            return err_response_stream(meta, "io error".into());
        }
        reader.complete();

        let encrypted_message = match EncryptedMessage::new(text) {
            Ok(x) => Arc::new(x),
            Err(e) => return err_response_stream(meta, e.to_string()),
        };

        Box::pin(stream::iter(meta.user.to.into_iter()).then(move |creds| {
            let encrypted_message = encrypted_message.clone();
            async move {
                match encrypted_message.deliver_to(creds).await {
                    Ok(()) => Decision::Accept {
                        reply: reply::okay_mail().convert(),
                        res: (),
                    },
                    Err(e) => Decision::Reject {
                        reply: Reply {
                            code: ReplyCode::POLICY_REASON,
                            ecode: None,
                            text: vec![smtp_message::MaybeUtf8::Utf8(e.to_string())],
                        },
                    },
                }
            }
        }))
    }
}

// ----

/// A message body sealed once under a random symmetric key, ready to be
/// delivered to any number of recipients.
struct EncryptedMessage {
    /// Key produced by `gen_key()`; wrapped separately for each recipient.
    key: Key,
    /// Ciphertext returned by `seal(&body, &key)`.
    encrypted_body: Vec<u8>,
}

impl EncryptedMessage {
    fn new(body: Vec<u8>) -> Result<Self> {
        let key = gen_key();
        let encrypted_body = seal(&body, &key)?;
        Ok(Self {
            key,
            encrypted_body,
        })
    }

    async fn deliver_to(self: Arc<Self>, creds: PublicCredentials) -> Result<()> {
        let s3_client = creds.storage.s3_client()?;

        let encrypted_key =
            sodiumoxide::crypto::sealedbox::seal(self.key.as_ref(), &creds.public_key);
        let key_header = base64::encode(&encrypted_key);

        let mut por = PutObjectRequest::default();
        por.bucket = creds.storage.bucket.clone();
        por.key = format!("incoming/{}", gen_ident().to_string());
        por.metadata = Some(
            [("Message-Key".to_string(), key_header)]
                .into_iter()
                .collect::<HashMap<_, _>>(),
        );
        por.body = Some(self.encrypted_body.clone().into());
        s3_client.put_object(por).await?;

        Ok(())
    }
}