aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/login/static_provider.rs73
-rw-r--r--src/main.rs2
-rw-r--r--src/server.rs8
3 files changed, 49 insertions, 34 deletions
diff --git a/src/login/static_provider.rs b/src/login/static_provider.rs
index 85d55ef..4a8d484 100644
--- a/src/login/static_provider.rs
+++ b/src/login/static_provider.rs
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::path::PathBuf;
+use tokio::sync::watch;
+use tokio::signal::unix::{signal, SignalKind};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@@ -9,48 +11,59 @@ use crate::config::*;
use crate::login::*;
use crate::storage;
-pub struct StaticLoginProvider {
- user_list: PathBuf,
+#[derive(Default)]
+pub struct UserDatabase {
users: HashMap<String, Arc<UserEntry>>,
users_by_email: HashMap<String, Arc<UserEntry>>,
}
-impl StaticLoginProvider {
- pub fn new(config: LoginStaticConfig) -> Result<Self> {
- let mut lp = Self {
- user_list: config.user_list.clone(),
- users: HashMap::new(),
- users_by_email: HashMap::new(),
- };
-
- lp
- .update_user_list()
- .context(
- format!(
- "failed to read {:?}, make sure it exists and it's correctly formatted",
- config.user_list))?;
+pub struct StaticLoginProvider {
+ user_db: watch::Receiver<UserDatabase>,
+}
- Ok(lp)
- }
+pub async fn update_user_list(config: PathBuf, up: watch::Sender<UserDatabase>) -> Result<()> {
+ let mut stream = signal(SignalKind::user_defined1()).expect("failed to install SIGUSR1 signal hander for reload");
- pub fn update_user_list(&mut self) -> Result<()> {
- let ulist: UserList = read_config(self.user_list.clone())?;
+ loop {
+ let ulist: UserList = match read_config(config.clone()) {
+ Ok(x) => x,
+ Err(e) => {
+ tracing::warn!(path=%config.as_path().to_string_lossy(), error=%e, "Unable to load config");
+ continue;
+ }
+ };
- self.users = ulist
+ let users = ulist
.into_iter()
.map(|(k, v)| (k, Arc::new(v)))
.collect::<HashMap<_, _>>();
- self.users_by_email.clear();
- for (_, u) in self.users.iter() {
+ let mut users_by_email = HashMap::new();
+ for (_, u) in users.iter() {
for m in u.email_addresses.iter() {
- if self.users_by_email.contains_key(m) {
- bail!("Several users have same email address: {}", m);
+ if users_by_email.contains_key(m) {
+ tracing::warn!("Several users have the same email address: {}", m);
+ continue
}
- self.users_by_email.insert(m.clone(), u.clone());
+ users_by_email.insert(m.clone(), u.clone());
}
}
- Ok(())
+
+ tracing::info!("{} users loaded", users.len());
+ up.send(UserDatabase { users, users_by_email }).context("update user db config")?;
+ stream.recv().await;
+ tracing::info!("Received SIGUSR1, reloading");
+ }
+}
+
+impl StaticLoginProvider {
+ pub async fn new(config: LoginStaticConfig) -> Result<Self> {
+ let (tx, mut rx) = watch::channel(UserDatabase::default());
+
+ tokio::spawn(update_user_list(config.user_list, tx));
+ rx.changed().await?;
+
+ Ok(Self { user_db: rx })
}
}
@@ -58,7 +71,8 @@ impl StaticLoginProvider {
impl LoginProvider for StaticLoginProvider {
async fn login(&self, username: &str, password: &str) -> Result<Credentials> {
tracing::debug!(user=%username, "login");
- let user = match self.users.get(username) {
+ let user_db = self.user_db.borrow();
+ let user = match user_db.users.get(username) {
None => bail!("User {} does not exist", username),
Some(u) => u,
};
@@ -89,7 +103,8 @@ impl LoginProvider for StaticLoginProvider {
}
async fn public_login(&self, email: &str) -> Result<PublicCredentials> {
- let user = match self.users_by_email.get(email) {
+ let user_db = self.user_db.borrow();
+ let user = match user_db.users_by_email.get(email) {
None => bail!("No user for email address {}", email),
Some(u) => u,
};
diff --git a/src/main.rs b/src/main.rs
index 3d87d11..02ba5e4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -42,7 +42,7 @@ enum Command {
Provider(ProviderCommand),
#[clap(subcommand)]
- /// Specific tooling, should not be part of a normal workflow, for debug & experimenting only
+ /// Specific tooling, should not be part of a normal workflow, for debug & experimentation only
Tools(ToolsCommand),
//Test,
}
diff --git a/src/server.rs b/src/server.rs
index 2321da8..8abdb86 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -18,21 +18,21 @@ pub struct Server {
impl Server {
pub async fn from_companion_config(config: CompanionConfig) -> Result<Self> {
- let login = Arc::new(StaticLoginProvider::new(config.users)?);
+ let login = Arc::new(StaticLoginProvider::new(config.users).await?);
let lmtp_server = None;
- let imap_server = Some(imap::new(config.imap, login).await?);
+ let imap_server = Some(imap::new(config.imap, login.clone()).await?);
Ok(Self { lmtp_server, imap_server })
}
pub async fn from_provider_config(config: ProviderConfig) -> Result<Self> {
let login: ArcLoginProvider = match config.users {
- UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x)?),
+ UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x).await?),
UserManagement::Ldap(x) => Arc::new(LdapLoginProvider::new(x)?),
};
let lmtp_server = Some(LmtpServer::new(config.lmtp, login.clone()));
- let imap_server = Some(imap::new(config.imap, login).await?);
+ let imap_server = Some(imap::new(config.imap, login.clone()).await?);
Ok(Self { lmtp_server, imap_server })
}