aboutsummaryrefslogtreecommitdiff
path: root/aero-sasl/src/flow.rs
diff options
context:
space:
mode:
Diffstat (limited to 'aero-sasl/src/flow.rs')
-rw-r--r--aero-sasl/src/flow.rs201
1 files changed, 201 insertions, 0 deletions
diff --git a/aero-sasl/src/flow.rs b/aero-sasl/src/flow.rs
new file mode 100644
index 0000000..6cc698a
--- /dev/null
+++ b/aero-sasl/src/flow.rs
@@ -0,0 +1,201 @@
+use futures::Future;
+use rand::prelude::*;
+
+use super::types::*;
+use super::decode::auth_plain;
+
+#[derive(Debug)]
+pub enum AuthRes {
+ Success(String),
+ Failed(Option<String>, Option<FailCode>),
+}
+
+#[derive(Debug)]
+pub enum State {
+ Error,
+ Init,
+ HandshakePart(Version),
+ HandshakeDone,
+ AuthPlainProgress { id: u64 },
+ AuthDone { id: u64, res: AuthRes },
+}
+
+const SERVER_MAJOR: u64 = 1;
+const SERVER_MINOR: u64 = 2;
+const EMPTY_AUTHZ: &[u8] = &[];
+impl State {
+ pub fn new() -> Self {
+ Self::Init
+ }
+
+ async fn try_auth_plain<'a, X, F>(&self, data: &'a [u8], login: X) -> AuthRes
+ where
+ X: FnOnce(&'a str, &'a str) -> F,
+ F: Future<Output=bool>,
+ {
+ // Check that we can extract user's login+pass
+ let (ubin, pbin) = match auth_plain(&data) {
+ Ok(([], (authz, user, pass))) if authz == user || authz == EMPTY_AUTHZ => (user, pass),
+ Ok(_) => {
+ tracing::error!("Impersonating user is not supported");
+ return AuthRes::Failed(None, None);
+ }
+ Err(e) => {
+ tracing::error!(err=?e, "Could not parse the SASL PLAIN data chunk");
+ return AuthRes::Failed(None, None);
+ }
+ };
+
+ // Try to convert it to UTF-8
+ let (user, password) = match (std::str::from_utf8(ubin), std::str::from_utf8(pbin)) {
+ (Ok(u), Ok(p)) => (u, p),
+ _ => {
+ tracing::error!("Username or password contain invalid UTF-8 characters");
+ return AuthRes::Failed(None, None);
+ }
+ };
+
+ // Try to connect user
+ match login(user, password).await {
+ true => AuthRes::Success(user.to_string()),
+ false => {
+ tracing::warn!("login failed");
+ AuthRes::Failed(Some(user.to_string()), None)
+ }
+ }
+ }
+
+ pub async fn progress<F,X>(&mut self, cmd: ClientCommand, login: X)
+ where
+ X: FnOnce(&str, &str) -> F,
+ F: Future<Output=bool>,
+ {
+ let new_state = 'state: {
+ match (std::mem::replace(self, State::Error), cmd) {
+ (Self::Init, ClientCommand::Version(v)) => Self::HandshakePart(v),
+ (Self::HandshakePart(version), ClientCommand::Cpid(_cpid)) => {
+ if version.major != SERVER_MAJOR {
+ tracing::error!(
+ client_major = version.major,
+ server_major = SERVER_MAJOR,
+ "Unsupported client major version"
+ );
+ break 'state Self::Error;
+ }
+
+ Self::HandshakeDone
+ }
+ (
+ Self::HandshakeDone { .. },
+ ClientCommand::Auth {
+ id, mech, options, ..
+ },
+ )
+ | (
+ Self::AuthDone { .. },
+ ClientCommand::Auth {
+ id, mech, options, ..
+ },
+ ) => {
+ if mech != Mechanism::Plain {
+ tracing::error!(mechanism=?mech, "Unsupported Authentication Mechanism");
+ break 'state Self::AuthDone {
+ id,
+ res: AuthRes::Failed(None, None),
+ };
+ }
+
+ match options.last() {
+ Some(AuthOption::Resp(data)) => Self::AuthDone {
+ id,
+ res: self.try_auth_plain(&data, login).await,
+ },
+ _ => Self::AuthPlainProgress { id },
+ }
+ }
+ (Self::AuthPlainProgress { id }, ClientCommand::Cont { id: cid, data }) => {
+ // Check that ID matches
+ if cid != id {
+ tracing::error!(
+ auth_id = id,
+ cont_id = cid,
+ "CONT id does not match AUTH id"
+ );
+ break 'state Self::AuthDone {
+ id,
+ res: AuthRes::Failed(None, None),
+ };
+ }
+
+ Self::AuthDone {
+ id,
+ res: self.try_auth_plain(&data, login).await,
+ }
+ }
+ _ => {
+ tracing::error!("This command is not valid in this context");
+ Self::Error
+ }
+ }
+ };
+ tracing::debug!(state=?new_state, "Made progress");
+ *self = new_state;
+ }
+
+ pub fn response(&self) -> Vec<ServerCommand> {
+ let mut srv_cmd: Vec<ServerCommand> = Vec::new();
+
+ match self {
+ Self::HandshakeDone { .. } => {
+ srv_cmd.push(ServerCommand::Version(Version {
+ major: SERVER_MAJOR,
+ minor: SERVER_MINOR,
+ }));
+
+ srv_cmd.push(ServerCommand::Mech {
+ kind: Mechanism::Plain,
+ parameters: vec![MechanismParameters::PlainText],
+ });
+
+ srv_cmd.push(ServerCommand::Spid(15u64));
+ srv_cmd.push(ServerCommand::Cuid(19350u64));
+
+ let mut cookie = [0u8; 16];
+ thread_rng().fill(&mut cookie);
+ srv_cmd.push(ServerCommand::Cookie(cookie));
+
+ srv_cmd.push(ServerCommand::Done);
+ }
+ Self::AuthPlainProgress { id } => {
+ srv_cmd.push(ServerCommand::Cont {
+ id: *id,
+ data: None,
+ });
+ }
+ Self::AuthDone {
+ id,
+ res: AuthRes::Success(user),
+ } => {
+ srv_cmd.push(ServerCommand::Ok {
+ id: *id,
+ user_id: Some(user.to_string()),
+ extra_parameters: vec![],
+ });
+ }
+ Self::AuthDone {
+ id,
+ res: AuthRes::Failed(maybe_user, maybe_failcode),
+ } => {
+ srv_cmd.push(ServerCommand::Fail {
+ id: *id,
+ user_id: maybe_user.clone(),
+ code: maybe_failcode.clone(),
+ extra_parameters: vec![],
+ });
+ }
+ _ => (),
+ };
+
+ srv_cmd
+ }
+}