diff options
Diffstat (limited to 'src/login')
-rw-r--r-- | src/login/static_provider.rs | 73 |
1 files changed, 44 insertions, 29 deletions
diff --git a/src/login/static_provider.rs b/src/login/static_provider.rs index 85d55ef..4a8d484 100644 --- a/src/login/static_provider.rs +++ b/src/login/static_provider.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; use std::path::PathBuf; +use tokio::sync::watch; +use tokio::signal::unix::{signal, SignalKind}; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; @@ -9,48 +11,59 @@ use crate::config::*; use crate::login::*; use crate::storage; -pub struct StaticLoginProvider { - user_list: PathBuf, +#[derive(Default)] +pub struct UserDatabase { users: HashMap<String, Arc<UserEntry>>, users_by_email: HashMap<String, Arc<UserEntry>>, } -impl StaticLoginProvider { - pub fn new(config: LoginStaticConfig) -> Result<Self> { - let mut lp = Self { - user_list: config.user_list.clone(), - users: HashMap::new(), - users_by_email: HashMap::new(), - }; - - lp - .update_user_list() - .context( - format!( - "failed to read {:?}, make sure it exists and it's correctly formatted", - config.user_list))?; +pub struct StaticLoginProvider { + user_db: watch::Receiver<UserDatabase>, +} - Ok(lp) - } +pub async fn update_user_list(config: PathBuf, up: watch::Sender<UserDatabase>) -> Result<()> { + let mut stream = signal(SignalKind::user_defined1()).expect("failed to install SIGUSR1 signal hander for reload"); - pub fn update_user_list(&mut self) -> Result<()> { - let ulist: UserList = read_config(self.user_list.clone())?; + loop { + let ulist: UserList = match read_config(config.clone()) { + Ok(x) => x, + Err(e) => { + tracing::warn!(path=%config.as_path().to_string_lossy(), error=%e, "Unable to load config"); + continue; + } + }; - self.users = ulist + let users = ulist .into_iter() .map(|(k, v)| (k, Arc::new(v))) .collect::<HashMap<_, _>>(); - self.users_by_email.clear(); - for (_, u) in self.users.iter() { + let mut users_by_email = HashMap::new(); + for (_, u) in users.iter() { for m in u.email_addresses.iter() { - if self.users_by_email.contains_key(m) { - bail!("Several users have same email address: {}", m); + if users_by_email.contains_key(m) { + tracing::warn!("Several users have the same email address: {}", m); + continue } - self.users_by_email.insert(m.clone(), u.clone()); + users_by_email.insert(m.clone(), u.clone()); } } - Ok(()) + + tracing::info!("{} users loaded", users.len()); + up.send(UserDatabase { users, users_by_email }).context("update user db config")?; + stream.recv().await; + tracing::info!("Received SIGUSR1, reloading"); + } +} + +impl StaticLoginProvider { + pub async fn new(config: LoginStaticConfig) -> Result<Self> { + let (tx, mut rx) = watch::channel(UserDatabase::default()); + + tokio::spawn(update_user_list(config.user_list, tx)); + rx.changed().await?; + + Ok(Self { user_db: rx }) } } @@ -58,7 +71,8 @@ impl StaticLoginProvider { impl LoginProvider for StaticLoginProvider { async fn login(&self, username: &str, password: &str) -> Result<Credentials> { tracing::debug!(user=%username, "login"); - let user = match self.users.get(username) { + let user_db = self.user_db.borrow(); + let user = match user_db.users.get(username) { None => bail!("User {} does not exist", username), Some(u) => u, }; @@ -89,7 +103,8 @@ impl LoginProvider for StaticLoginProvider { } async fn public_login(&self, email: &str) -> Result<PublicCredentials> { - let user = match self.users_by_email.get(email) { + let user_db = self.user_db.borrow(); + let user = match user_db.users_by_email.get(email) { None => bail!("No user for email address {}", email), Some(u) => u, }; |