From 1a43ce5ac7033c148f64a033f2b1d335e95e11d5 Mon Sep 17 00:00:00 2001 From: Quentin Dufour Date: Fri, 8 Mar 2024 08:17:03 +0100 Subject: WIP refactor --- aero-bayou/src/lib.rs | 517 ++++++++++++++++++++++++++++++++++++++++++++ aero-bayou/src/timestamp.rs | 66 ++++++ 2 files changed, 583 insertions(+) create mode 100644 aero-bayou/src/lib.rs create mode 100644 aero-bayou/src/timestamp.rs (limited to 'aero-bayou/src') diff --git a/aero-bayou/src/lib.rs b/aero-bayou/src/lib.rs new file mode 100644 index 0000000..7756964 --- /dev/null +++ b/aero-bayou/src/lib.rs @@ -0,0 +1,517 @@ +mod timestamp + +use std::sync::{Arc, Weak}; +use std::time::{Duration, Instant}; + +use anyhow::{anyhow, bail, Result}; +use log::error; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; +use tokio::sync::{watch, Notify}; + +use aero_foundations::cryptoblob::*; +use aero_foundations::login::Credentials; +use aero_foundations::storage; + +use crate::timestamp::*; + +const KEEP_STATE_EVERY: usize = 64; + +// Checkpointing interval constants: a checkpoint is not made earlier +// than CHECKPOINT_INTERVAL time after the last one, and is not made +// if there are less than CHECKPOINT_MIN_OPS new operations since last one. +const CHECKPOINT_INTERVAL: Duration = Duration::from_secs(6 * 3600); +const CHECKPOINT_MIN_OPS: usize = 16; +// HYPOTHESIS: processes are able to communicate in a synchronous +// fashion in times that are small compared to CHECKPOINT_INTERVAL. +// More precisely, if a process tried to save an operation within the last +// CHECKPOINT_INTERVAL, we are sure to read it from storage if it was +// successfully saved (and if we don't read it, it means it has been +// definitely discarded due to an error). + +// Keep at least two checkpoints, here three, to avoid race conditions +// between processes doing .checkpoint() and those doing .sync() +const CHECKPOINTS_TO_KEEP: usize = 3; + +const WATCH_SK: &str = "watch"; + +pub trait BayouState: + Default + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static +{ + type Op: Clone + Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + Send + Sync + 'static; + + fn apply(&self, op: &Self::Op) -> Self; +} + +pub struct Bayou { + path: String, + key: Key, + + storage: storage::Store, + + checkpoint: (Timestamp, S), + history: Vec<(Timestamp, S::Op, Option)>, + + last_sync: Option, + last_try_checkpoint: Option, + + watch: Arc, + last_sync_watch_ct: storage::RowRef, +} + +impl Bayou { + pub async fn new(creds: &Credentials, path: String) -> Result { + let storage = creds.storage.build().await?; + + //let target = k2v_client.row(&path, WATCH_SK); + let target = storage::RowRef::new(&path, WATCH_SK); + let watch = K2vWatch::new(creds, target.clone()).await?; + + Ok(Self { + path, + storage, + key: creds.keys.master.clone(), + checkpoint: (Timestamp::zero(), S::default()), + history: vec![], + last_sync: None, + last_try_checkpoint: None, + watch, + last_sync_watch_ct: target, + }) + } + + /// Re-reads the state from persistent storage backend + pub async fn sync(&mut self) -> Result<()> { + let new_last_sync = Some(Instant::now()); + let new_last_sync_watch_ct = self.watch.rx.borrow().clone(); + + // 1. List checkpoints + let checkpoints = self.list_checkpoints().await?; + tracing::debug!("(sync) listed checkpoints: {:?}", checkpoints); + + // 2. Load last checkpoint if different from currently used one + let checkpoint = if let Some((ts, key)) = checkpoints.last() { + if *ts == self.checkpoint.0 { + (*ts, None) + } else { + tracing::debug!("(sync) loading checkpoint: {}", key); + + let buf = self + .storage + .blob_fetch(&storage::BlobRef(key.to_string())) + .await? + .value; + tracing::debug!("(sync) checkpoint body length: {}", buf.len()); + + let ck = open_deserialize::(&buf, &self.key)?; + (*ts, Some(ck)) + } + } else { + (Timestamp::zero(), None) + }; + + if self.checkpoint.0 > checkpoint.0 { + bail!("Loaded checkpoint is more recent than stored one"); + } + + if let Some(ck) = checkpoint.1 { + tracing::debug!( + "(sync) updating checkpoint to loaded state at {:?}", + checkpoint.0 + ); + self.checkpoint = (checkpoint.0, ck); + }; + + // remove from history events before checkpoint + self.history = std::mem::take(&mut self.history) + .into_iter() + .skip_while(|(ts, _, _)| *ts < self.checkpoint.0) + .collect(); + + // 3. List all operations starting from checkpoint + let ts_ser = self.checkpoint.0.to_string(); + tracing::debug!("(sync) looking up operations starting at {}", ts_ser); + let ops_map = self + .storage + .row_fetch(&storage::Selector::Range { + shard: &self.path, + sort_begin: &ts_ser, + sort_end: WATCH_SK, + }) + .await?; + + let mut ops = vec![]; + for row_value in ops_map { + let row = row_value.row_ref; + let sort_key = row.uid.sort; + let ts = sort_key + .parse::() + .map_err(|_| anyhow!("Invalid operation timestamp: {}", sort_key))?; + + let val = row_value.value; + if val.len() != 1 { + bail!("Invalid operation, has {} values", val.len()); + } + match &val[0] { + storage::Alternative::Value(v) => { + let op = open_deserialize::(v, &self.key)?; + tracing::trace!("(sync) operation {}: {:?}", sort_key, op); + ops.push((ts, op)); + } + storage::Alternative::Tombstone => { + continue; + } + } + } + ops.sort_by_key(|(ts, _)| *ts); + tracing::debug!("(sync) {} operations", ops.len()); + + if ops.len() < self.history.len() { + bail!("Some operations have disappeared from storage!"); + } + + // 4. Check that first operation has same timestamp as checkpoint (if not zero) + if self.checkpoint.0 != Timestamp::zero() && ops[0].0 != self.checkpoint.0 { + bail!( + "First operation in listing doesn't have timestamp that corresponds to checkpoint" + ); + } + + // 5. Apply all operations in order + // Hypothesis: before the loaded checkpoint, operations haven't changed + // between what's on storage and what we used to calculate the state in RAM here. + let i0 = self + .history + .iter() + .zip(ops.iter()) + .take_while(|((ts1, _, _), (ts2, _))| ts1 == ts2) + .count(); + + if ops.len() > i0 { + // Remove operations from first position where histories differ + self.history.truncate(i0); + + // Look up last calculated state which we have saved and start from there. + let mut last_state = (0, &self.checkpoint.1); + for (i, (_, _, state_opt)) in self.history.iter().enumerate().rev() { + if let Some(state) = state_opt { + last_state = (i + 1, state); + break; + } + } + + // Calculate state at the end of this common part of the history + let mut state = last_state.1.clone(); + for (_, op, _) in self.history[last_state.0..].iter() { + state = state.apply(op); + } + + // Now, apply all operations retrieved from storage after the common part + for (ts, op) in ops.drain(i0..) { + state = state.apply(&op); + if (self.history.len() + 1) % KEEP_STATE_EVERY == 0 { + self.history.push((ts, op, Some(state.clone()))); + } else { + self.history.push((ts, op, None)); + } + } + + // Always save final state as result of last operation + self.history.last_mut().unwrap().2 = Some(state); + } + + // Save info that sync has been done + self.last_sync = new_last_sync; + self.last_sync_watch_ct = new_last_sync_watch_ct; + Ok(()) + } + + /// Does a sync() if either of the two conditions is met: + /// - last sync was more than CHECKPOINT_INTERVAL/5 ago + /// - a change was detected + pub async fn opportunistic_sync(&mut self) -> Result<()> { + let too_old = match self.last_sync { + Some(t) => Instant::now() > t + (CHECKPOINT_INTERVAL / 5), + _ => true, + }; + let changed = self.last_sync_watch_ct != *self.watch.rx.borrow(); + if too_old || changed { + self.sync().await?; + } + Ok(()) + } + + pub fn notifier(&self) -> std::sync::Weak { + Arc::downgrade(&self.watch.learnt_remote_update) + } + + /// Applies a new operation on the state. Once this function returns, + /// the operation has been safely persisted to storage backend. + /// Make sure to call `.opportunistic_sync()` before doing this, + /// and even before calculating the `op` argument given here. + pub async fn push(&mut self, op: S::Op) -> Result<()> { + tracing::debug!("(push) add operation: {:?}", op); + + let ts = Timestamp::after( + self.history + .last() + .map(|(ts, _, _)| ts) + .unwrap_or(&self.checkpoint.0), + ); + + let row_val = storage::RowVal::new( + storage::RowRef::new(&self.path, &ts.to_string()), + seal_serialize(&op, &self.key)?, + ); + self.storage.row_insert(vec![row_val]).await?; + self.watch.propagate_local_update.notify_one(); + + let new_state = self.state().apply(&op); + self.history.push((ts, op, Some(new_state))); + + // Clear previously saved state in history if not required + let hlen = self.history.len(); + if hlen >= 2 && (hlen - 1) % KEEP_STATE_EVERY != 0 { + self.history[hlen - 2].2 = None; + } + + self.checkpoint().await?; + + Ok(()) + } + + /// Save a new checkpoint if previous checkpoint is too old + pub async fn checkpoint(&mut self) -> Result<()> { + match self.last_try_checkpoint { + Some(ts) if Instant::now() - ts < CHECKPOINT_INTERVAL / 5 => Ok(()), + _ => { + let res = self.checkpoint_internal().await; + if res.is_ok() { + self.last_try_checkpoint = Some(Instant::now()); + } + res + } + } + } + + async fn checkpoint_internal(&mut self) -> Result<()> { + self.sync().await?; + + // Check what would be the possible time for a checkpoint in the history we have + let now = now_msec() as i128; + let i_cp = match self + .history + .iter() + .enumerate() + .rev() + .skip_while(|(_, (ts, _, _))| { + (now - ts.msec as i128) < CHECKPOINT_INTERVAL.as_millis() as i128 + }) + .map(|(i, _)| i) + .next() + { + Some(i) => i, + None => { + tracing::debug!("(cp) Oldest operation is too recent to trigger checkpoint"); + return Ok(()); + } + }; + + if i_cp < CHECKPOINT_MIN_OPS { + tracing::debug!("(cp) Not enough old operations to trigger checkpoint"); + return Ok(()); + } + + let ts_cp = self.history[i_cp].0; + tracing::debug!( + "(cp) we could checkpoint at time {} (index {} in history)", + ts_cp.to_string(), + i_cp + ); + + // Check existing checkpoints: if last one is too recent, don't checkpoint again. + let existing_checkpoints = self.list_checkpoints().await?; + tracing::debug!("(cp) listed checkpoints: {:?}", existing_checkpoints); + + if let Some(last_cp) = existing_checkpoints.last() { + if (ts_cp.msec as i128 - last_cp.0.msec as i128) + < CHECKPOINT_INTERVAL.as_millis() as i128 + { + tracing::debug!( + "(cp) last checkpoint is too recent: {}, not checkpointing", + last_cp.0.to_string() + ); + return Ok(()); + } + } + + tracing::debug!("(cp) saving checkpoint at {}", ts_cp.to_string()); + + // Calculate state at time of checkpoint + let mut last_known_state = (0, &self.checkpoint.1); + for (i, (_, _, st)) in self.history[..i_cp].iter().enumerate() { + if let Some(s) = st { + last_known_state = (i + 1, s); + } + } + let mut state_cp = last_known_state.1.clone(); + for (_, op, _) in self.history[last_known_state.0..i_cp].iter() { + state_cp = state_cp.apply(op); + } + + // Serialize and save checkpoint + let cryptoblob = seal_serialize(&state_cp, &self.key)?; + tracing::debug!("(cp) checkpoint body length: {}", cryptoblob.len()); + + let blob_val = storage::BlobVal::new( + storage::BlobRef(format!("{}/checkpoint/{}", self.path, ts_cp.to_string())), + cryptoblob.into(), + ); + self.storage.blob_insert(blob_val).await?; + + // Drop old checkpoints (but keep at least CHECKPOINTS_TO_KEEP of them) + let ecp_len = existing_checkpoints.len(); + if ecp_len + 1 > CHECKPOINTS_TO_KEEP { + let last_to_keep = ecp_len + 1 - CHECKPOINTS_TO_KEEP; + + // Delete blobs + for (_ts, key) in existing_checkpoints[..last_to_keep].iter() { + tracing::debug!("(cp) drop old checkpoint {}", key); + self.storage + .blob_rm(&storage::BlobRef(key.to_string())) + .await?; + } + + // Delete corresponding range of operations + let ts_ser = existing_checkpoints[last_to_keep].0.to_string(); + self.storage + .row_rm(&storage::Selector::Range { + shard: &self.path, + sort_begin: "", + sort_end: &ts_ser, + }) + .await? + } + + Ok(()) + } + + pub fn state(&self) -> &S { + if let Some(last) = self.history.last() { + last.2.as_ref().unwrap() + } else { + &self.checkpoint.1 + } + } + + // ---- INTERNAL ---- + + async fn list_checkpoints(&self) -> Result> { + let prefix = format!("{}/checkpoint/", self.path); + + let checkpoints_res = self.storage.blob_list(&prefix).await?; + + let mut checkpoints = vec![]; + for object in checkpoints_res { + let key = object.0; + if let Some(ckid) = key.strip_prefix(&prefix) { + if let Ok(ts) = ckid.parse::() { + checkpoints.push((ts, key.into())); + } + } + } + checkpoints.sort_by_key(|(ts, _)| *ts); + Ok(checkpoints) + } +} + +// ---- Bayou watch in K2V ---- + +struct K2vWatch { + target: storage::RowRef, + rx: watch::Receiver, + propagate_local_update: Notify, + learnt_remote_update: Arc, +} + +impl K2vWatch { + /// Creates a new watch and launches subordinate threads. + /// These threads hold Weak pointers to the struct; + /// they exit when the Arc is dropped. + async fn new(creds: &Credentials, target: storage::RowRef) -> Result> { + let storage = creds.storage.build().await?; + + let (tx, rx) = watch::channel::(target.clone()); + let propagate_local_update = Notify::new(); + let learnt_remote_update = Arc::new(Notify::new()); + + let watch = Arc::new(K2vWatch { + target, + rx, + propagate_local_update, + learnt_remote_update, + }); + + tokio::spawn(Self::background_task(Arc::downgrade(&watch), storage, tx)); + + Ok(watch) + } + + async fn background_task( + self_weak: Weak, + storage: storage::Store, + tx: watch::Sender, + ) { + let (mut row, remote_update) = match Weak::upgrade(&self_weak) { + Some(this) => (this.target.clone(), this.learnt_remote_update.clone()), + None => return, + }; + + while let Some(this) = Weak::upgrade(&self_weak) { + tracing::debug!( + "bayou k2v watch bg loop iter ({}, {})", + this.target.uid.shard, + this.target.uid.sort + ); + tokio::select!( + // Needed to exit: will force a loop iteration every minutes, + // that will stop the loop if other Arc references have been dropped + // and free resources. Otherwise we would be blocked waiting forever... + _ = tokio::time::sleep(Duration::from_secs(60)) => continue, + + // Watch if another instance has modified the log + update = storage.row_poll(&row) => { + match update { + Err(e) => { + error!("Error in bayou k2v wait value changed: {}", e); + tokio::time::sleep(Duration::from_secs(30)).await; + } + Ok(new_value) => { + row = new_value.row_ref; + if let Err(e) = tx.send(row.clone()) { + tracing::warn!(err=?e, "(watch) can't record the new log ref"); + break; + } + tracing::debug!(row=?row, "(watch) learnt remote update"); + this.learnt_remote_update.notify_waiters(); + } + } + } + + // It appears we have modified the log, informing other people + _ = this.propagate_local_update.notified() => { + let rand = u128::to_be_bytes(thread_rng().gen()).to_vec(); + let row_val = storage::RowVal::new(row.clone(), rand); + if let Err(e) = storage.row_insert(vec![row_val]).await + { + tracing::error!("Error in bayou k2v watch updater loop: {}", e); + tokio::time::sleep(Duration::from_secs(30)).await; + } + } + ); + } + // unblock listeners + remote_update.notify_waiters(); + tracing::info!("bayou k2v watch bg loop exiting"); + } +} diff --git a/aero-bayou/src/timestamp.rs b/aero-bayou/src/timestamp.rs new file mode 100644 index 0000000..4aa5399 --- /dev/null +++ b/aero-bayou/src/timestamp.rs @@ -0,0 +1,66 @@ +use std::str::FromStr; +use std::time::{SystemTime, UNIX_EPOCH}; + +use rand::prelude::*; + +/// Returns milliseconds since UNIX Epoch +pub fn now_msec() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Fix your clock :o") + .as_millis() as u64 +} + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub struct Timestamp { + pub msec: u64, + pub rand: u64, +} + +impl Timestamp { + #[allow(dead_code)] + // 2023-05-15 try to make clippy happy and not sure if this fn will be used in the future. + pub fn now() -> Self { + let mut rng = thread_rng(); + Self { + msec: now_msec(), + rand: rng.gen::(), + } + } + + pub fn after(other: &Self) -> Self { + let mut rng = thread_rng(); + Self { + msec: std::cmp::max(now_msec(), other.msec + 1), + rand: rng.gen::(), + } + } + + pub fn zero() -> Self { + Self { msec: 0, rand: 0 } + } +} + +impl ToString for Timestamp { + fn to_string(&self) -> String { + let mut bytes = [0u8; 16]; + bytes[0..8].copy_from_slice(&u64::to_be_bytes(self.msec)); + bytes[8..16].copy_from_slice(&u64::to_be_bytes(self.rand)); + hex::encode(bytes) + } +} + +impl FromStr for Timestamp { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + let bytes = hex::decode(s).map_err(|_| "invalid hex")?; + if bytes.len() != 16 { + return Err("bad length"); + } + Ok(Self { + msec: u64::from_be_bytes(bytes[0..8].try_into().unwrap()), + rand: u64::from_be_bytes(bytes[8..16].try_into().unwrap()), + }) + } +} -- cgit v1.2.3