diff options
Diffstat (limited to 'aero-proto/src/sasl.rs')
-rw-r--r-- | aero-proto/src/sasl.rs | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/aero-proto/src/sasl.rs b/aero-proto/src/sasl.rs new file mode 100644 index 0000000..48c0815 --- /dev/null +++ b/aero-proto/src/sasl.rs @@ -0,0 +1,142 @@ +use std::net::SocketAddr; + +use anyhow::{anyhow, bail, Result}; +use futures::stream::{FuturesUnordered, StreamExt}; +use tokio::io::BufStream; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; +use tokio_util::bytes::BytesMut; + +use aero_sasl::{decode::client_command, encode::Encode, flow::State}; +use aero_user::config::AuthConfig; +use aero_user::login::ArcLoginProvider; + +pub struct AuthServer { + login_provider: ArcLoginProvider, + bind_addr: SocketAddr, +} + +impl AuthServer { + pub fn new(config: AuthConfig, login_provider: ArcLoginProvider) -> Self { + Self { + bind_addr: config.bind_addr, + login_provider, + } + } + + pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> { + let tcp = TcpListener::bind(self.bind_addr).await?; + tracing::info!( + "SASL Authentication Protocol listening on {:#}", + self.bind_addr + ); + + let mut connections = FuturesUnordered::new(); + + while !*must_exit.borrow() { + let wait_conn_finished = async { + if connections.is_empty() { + futures::future::pending().await + } else { + connections.next().await + } + }; + + let (socket, remote_addr) = tokio::select! { + a = tcp.accept() => a?, + _ = wait_conn_finished => continue, + _ = must_exit.changed() => continue, + }; + + tracing::info!("AUTH: accepted connection from {}", remote_addr); + let conn = tokio::spawn( + NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error(), + ); + + connections.push(conn); + } + drop(tcp); + + tracing::info!("AUTH server shutting down, draining remaining connections..."); + while connections.next().await.is_some() {} + + Ok(()) + } +} + +struct NetLoop { + login: ArcLoginProvider, + stream: BufStream<TcpStream>, + stop: watch::Receiver<bool>, + state: State, + read_buf: Vec<u8>, + write_buf: BytesMut, +} + +impl NetLoop { + fn new(stream: TcpStream, login: ArcLoginProvider, stop: watch::Receiver<bool>) -> Self { + Self { + login, + stream: BufStream::new(stream), + state: State::Init, + stop, + read_buf: Vec::new(), + write_buf: BytesMut::new(), + } + } + + async fn run_error(self) { + match self.run().await { + Ok(()) => tracing::info!("Auth session succeeded"), + Err(e) => tracing::error!(err=?e, "Auth session failed"), + } + } + + async fn run(mut self) -> Result<()> { + loop { + tokio::select! { + read_res = self.stream.read_until(b'\n', &mut self.read_buf) => { + // Detect EOF / socket close + let bread = read_res?; + if bread == 0 { + tracing::info!("Reading buffer empty, connection has been closed. Exiting AUTH session."); + return Ok(()) + } + + // Parse command + let (_, cmd) = client_command(&self.read_buf).map_err(|_| anyhow!("Unable to parse command"))?; + tracing::trace!(cmd=?cmd, "Received command"); + + // Make some progress in our local state + let login = async |user: String, pass: String| self.login.login(user.as_str(), pass.as_str()).await.is_ok(); + self.state.progress(cmd, login).await; + if matches!(self.state, State::Error) { + bail!("Internal state is in error, previous logs explain what went wrong"); + } + + // Build response + let srv_cmds = self.state.response(); + srv_cmds.iter().try_for_each(|r| { + tracing::trace!(cmd=?r, "Sent command"); + r.encode(&mut self.write_buf) + })?; + + // Send responses if at least one command response has been generated + if !srv_cmds.is_empty() { + self.stream.write_all(&self.write_buf).await?; + self.stream.flush().await?; + } + + // Reset buffers + self.read_buf.clear(); + self.write_buf.clear(); + }, + _ = self.stop.changed() => { + tracing::debug!("Server is stopping, quitting this runner"); + return Ok(()) + } + } + } + } +} |