aboutsummaryrefslogtreecommitdiff
path: root/aero-proto/src/sasl.rs
blob: dae89ebf3147bf7169f8b004116f5dc2e8080c90 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use std::net::SocketAddr;

use anyhow::{anyhow, bail, Result};
use futures::stream::{FuturesUnordered, StreamExt};
use tokio::io::BufStream;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;
use tokio_util::bytes::BytesMut;

use aero_user::config::AuthConfig;
use aero_user::login::ArcLoginProvider;
use aero_sasl::{flow::State, decode::client_command, encode::Encode};

/// Listener for the SASL authentication protocol.
///
/// Owns only what is needed to accept connections; each accepted client is
/// handed its own clone of the login provider.
pub struct AuthServer {
    /// Address the TCP listener binds to (taken from `AuthConfig::bind_addr`).
    bind_addr: SocketAddr,
    /// Credential checker shared (by clone) with every client session.
    login_provider: ArcLoginProvider,
}

impl AuthServer {
    /// Builds a server that will listen on `config.bind_addr` and check
    /// credentials through `login_provider`.
    pub fn new(config: AuthConfig, login_provider: ArcLoginProvider) -> Self {
        Self {
            bind_addr: config.bind_addr,
            login_provider,
        }
    }

    /// Binds the listener and serves clients until shutdown is requested.
    ///
    /// Each accepted connection runs as its own spawned task. Shutdown starts
    /// when `must_exit` reads `true`, or when its sender is dropped (after which
    /// no exit signal could ever arrive). The listener is then closed and the
    /// method waits for every in-flight connection task to finish.
    ///
    /// # Errors
    /// Returns an error if binding the listener fails or if `accept` fails.
    pub async fn run(self, mut must_exit: watch::Receiver<bool>) -> Result<()> {
        let tcp = TcpListener::bind(self.bind_addr).await?;
        tracing::info!(
            "SASL Authentication Protocol listening on {:#}",
            self.bind_addr
        );

        let mut connections = FuturesUnordered::new();

        while !*must_exit.borrow() {
            // Reap finished connection tasks; never resolves while there are none,
            // so an empty set does not make `select!` spin.
            let wait_conn_finished = async {
                if connections.is_empty() {
                    futures::future::pending().await
                } else {
                    connections.next().await
                }
            };

            let (socket, remote_addr) = tokio::select! {
                a = tcp.accept() => a?,
                _ = wait_conn_finished => continue,
                changed = must_exit.changed() => {
                    // `changed()` errors only once the sender is gone; it would then
                    // be ready on every iteration (busy loop) while `borrow()` stays
                    // at its last value. No exit signal can come anymore, so stop.
                    if changed.is_err() {
                        tracing::warn!("AUTH server exit signal sender dropped, shutting down");
                        break;
                    }
                    continue
                }
            };

            tracing::info!("AUTH: accepted connection from {}", remote_addr);
            let conn = tokio::spawn(
                NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error(),
            );

            connections.push(conn);
        }
        drop(tcp);

        tracing::info!("AUTH server shutting down, draining remaining connections...");
        while connections.next().await.is_some() {}

        Ok(())
    }
}

/// Per-connection state for one SASL client session.
struct NetLoop {
    // --- connection plumbing ---
    /// Buffered client socket; commands are read and responses written here.
    stream: BufStream<TcpStream>,
    /// Shutdown signal; the session ends as soon as it reports a change.
    stop: watch::Receiver<bool>,

    // --- protocol ---
    /// SASL flow state machine, advanced once per received command.
    state: State,
    /// Credential checker invoked when the state machine needs to verify a login.
    login: ArcLoginProvider,

    // --- scratch buffers, cleared after every command ---
    /// Raw bytes of the command line currently being read.
    read_buf: Vec<u8>,
    /// Encoded responses for the current command, flushed in one write.
    write_buf: BytesMut,
}

impl NetLoop {
    fn new(stream: TcpStream, login: ArcLoginProvider, stop: watch::Receiver<bool>) -> Self {
        Self {
            login,
            stream: BufStream::new(stream),
            state: State::Init,
            stop,
            read_buf: Vec::new(),
            write_buf: BytesMut::new(),
        }
    }

    async fn run_error(self) {
        match self.run().await {
            Ok(()) => tracing::info!("Auth session succeeded"),
            Err(e) => tracing::error!(err=?e, "Auth session failed"),
        }
    }

    async fn run(mut self) -> Result<()> {
        loop {
            tokio::select! {
                read_res = self.stream.read_until(b'\n', &mut self.read_buf) => {
                    // Detect EOF / socket close
                    let bread = read_res?;
                    if bread == 0 {
                        tracing::info!("Reading buffer empty, connection has been closed. Exiting AUTH session.");
                        return Ok(())
                    }

                    // Parse command
                    let (_, cmd) = client_command(&self.read_buf).map_err(|_| anyhow!("Unable to parse command"))?;
                    tracing::trace!(cmd=?cmd, "Received command");

                    // Make some progress in our local state
                    let login = async |user: String, pass: String| self.login.login(user.as_str(), pass.as_str()).await.is_ok();
                    self.state.progress(cmd, login).await;
                    if matches!(self.state, State::Error) {
                        bail!("Internal state is in error, previous logs explain what went wrong");
                    }

                    // Build response
                    let srv_cmds = self.state.response();
                    srv_cmds.iter().try_for_each(|r| {
                        tracing::trace!(cmd=?r, "Sent command");
                        r.encode(&mut self.write_buf)
                    })?;

                    // Send responses if at least one command response has been generated
                    if !srv_cmds.is_empty() {
                        self.stream.write_all(&self.write_buf).await?;
                        self.stream.flush().await?;
                    }

                    // Reset buffers
                    self.read_buf.clear();
                    self.write_buf.clear();
                },
                _ = self.stop.changed() => {
                    tracing::debug!("Server is stopping, quitting this runner");
                    return Ok(())
                }
            }
        }
    }
}