aboutsummaryrefslogtreecommitdiff
path: root/src/login/static_provider.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/login/static_provider.rs')
-rw-r--r--src/login/static_provider.rs73
1 files changed, 44 insertions, 29 deletions
diff --git a/src/login/static_provider.rs b/src/login/static_provider.rs
index 85d55ef..4a8d484 100644
--- a/src/login/static_provider.rs
+++ b/src/login/static_provider.rs
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::path::PathBuf;
+use tokio::sync::watch;
+use tokio::signal::unix::{signal, SignalKind};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@@ -9,48 +11,59 @@ use crate::config::*;
use crate::login::*;
use crate::storage;
-pub struct StaticLoginProvider {
- user_list: PathBuf,
+#[derive(Default)]
+pub struct UserDatabase {
users: HashMap<String, Arc<UserEntry>>,
users_by_email: HashMap<String, Arc<UserEntry>>,
}
-impl StaticLoginProvider {
- pub fn new(config: LoginStaticConfig) -> Result<Self> {
- let mut lp = Self {
- user_list: config.user_list.clone(),
- users: HashMap::new(),
- users_by_email: HashMap::new(),
- };
-
- lp
- .update_user_list()
- .context(
- format!(
- "failed to read {:?}, make sure it exists and it's correctly formatted",
- config.user_list))?;
+pub struct StaticLoginProvider {
+ user_db: watch::Receiver<UserDatabase>,
+}
- Ok(lp)
- }
+pub async fn update_user_list(config: PathBuf, up: watch::Sender<UserDatabase>) -> Result<()> {
+ let mut stream = signal(SignalKind::user_defined1()).expect("failed to install SIGUSR1 signal hander for reload");
- pub fn update_user_list(&mut self) -> Result<()> {
- let ulist: UserList = read_config(self.user_list.clone())?;
+ loop {
+ let ulist: UserList = match read_config(config.clone()) {
+ Ok(x) => x,
+ Err(e) => {
+ tracing::warn!(path=%config.as_path().to_string_lossy(), error=%e, "Unable to load config");
+ continue;
+ }
+ };
- self.users = ulist
+ let users = ulist
.into_iter()
.map(|(k, v)| (k, Arc::new(v)))
.collect::<HashMap<_, _>>();
- self.users_by_email.clear();
- for (_, u) in self.users.iter() {
+ let mut users_by_email = HashMap::new();
+ for (_, u) in users.iter() {
for m in u.email_addresses.iter() {
- if self.users_by_email.contains_key(m) {
- bail!("Several users have same email address: {}", m);
+ if users_by_email.contains_key(m) {
+ tracing::warn!("Several users have the same email address: {}", m);
+ continue
}
- self.users_by_email.insert(m.clone(), u.clone());
+ users_by_email.insert(m.clone(), u.clone());
}
}
- Ok(())
+
+ tracing::info!("{} users loaded", users.len());
+ up.send(UserDatabase { users, users_by_email }).context("update user db config")?;
+ stream.recv().await;
+ tracing::info!("Received SIGUSR1, reloading");
+ }
+}
+
+impl StaticLoginProvider {
+ pub async fn new(config: LoginStaticConfig) -> Result<Self> {
+ let (tx, mut rx) = watch::channel(UserDatabase::default());
+
+ tokio::spawn(update_user_list(config.user_list, tx));
+ rx.changed().await?;
+
+ Ok(Self { user_db: rx })
}
}
@@ -58,7 +71,8 @@ impl StaticLoginProvider {
impl LoginProvider for StaticLoginProvider {
async fn login(&self, username: &str, password: &str) -> Result<Credentials> {
tracing::debug!(user=%username, "login");
- let user = match self.users.get(username) {
+ let user_db = self.user_db.borrow();
+ let user = match user_db.users.get(username) {
None => bail!("User {} does not exist", username),
Some(u) => u,
};
@@ -89,7 +103,8 @@ impl LoginProvider for StaticLoginProvider {
}
async fn public_login(&self, email: &str) -> Result<PublicCredentials> {
- let user = match self.users_by_email.get(email) {
+ let user_db = self.user_db.borrow();
+ let user = match user_db.users_by_email.get(email) {
None => bail!("No user for email address {}", email),
Some(u) => u,
};