diff options
Diffstat (limited to 'src/util')
-rw-r--r-- | src/util/Cargo.toml | 35 | ||||
-rw-r--r-- | src/util/background.rs | 124 | ||||
-rw-r--r-- | src/util/config.rs | 66 | ||||
-rw-r--r-- | src/util/data.rs | 124 | ||||
-rw-r--r-- | src/util/error.rs | 112 | ||||
-rw-r--r-- | src/util/lib.rs | 7 |
6 files changed, 468 insertions, 0 deletions
diff --git a/src/util/Cargo.toml b/src/util/Cargo.toml new file mode 100644 index 00000000..6f61a586 --- /dev/null +++ b/src/util/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "garage_util" +version = "0.1.0" +authors = ["Alex Auvolat <alex@adnab.me>"] +edition = "2018" + +[lib] +path = "lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rand = "0.7" +hex = "0.3" +sha2 = "0.8" +err-derive = "0.2.3" +log = "0.4" + +sled = "0.31" + +toml = "0.5" +rmp-serde = "0.14.3" +serde = { version = "1.0", default-features = false, features = ["derive", "rc"] } +serde_json = "1.0" + +futures = "0.3" +futures-util = "0.3" +tokio = { version = "0.2", default-features = false, features = ["rt-core", "rt-threaded", "io-driver", "net", "tcp", "time", "macros", "sync", "signal", "fs"] } + +http = "0.2" +hyper = "0.13" +rustls = "0.17" +webpki = "0.21" + + diff --git a/src/util/background.rs b/src/util/background.rs new file mode 100644 index 00000000..937062dd --- /dev/null +++ b/src/util/background.rs @@ -0,0 +1,124 @@ +use core::future::Future; +use std::pin::Pin; + +use futures::future::join_all; +use futures::select; +use futures_util::future::*; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::sync::{mpsc, watch, Notify}; + +use crate::error::Error; + +type JobOutput = Result<(), Error>; +type Job = Pin<Box<dyn Future<Output = JobOutput> + Send>>; + +pub struct BackgroundRunner { + n_runners: usize, + pub stop_signal: watch::Receiver<bool>, + + queue_in: mpsc::UnboundedSender<(Job, bool)>, + queue_out: Mutex<mpsc::UnboundedReceiver<(Job, bool)>>, + job_notify: Notify, + + workers: Mutex<Vec<tokio::task::JoinHandle<()>>>, +} + +impl BackgroundRunner { + pub fn new(n_runners: usize, stop_signal: watch::Receiver<bool>) -> Arc<Self> { + let (queue_in, queue_out) = mpsc::unbounded_channel(); + Arc::new(Self { + n_runners, + stop_signal, + queue_in, + queue_out: Mutex::new(queue_out), + job_notify: Notify::new(), + workers: Mutex::new(Vec::new()), + }) + } + + pub async fn run(self: Arc<Self>) { + let mut workers = self.workers.lock().await; + for i in 0..self.n_runners { + workers.push(tokio::spawn(self.clone().runner(i))); + } + drop(workers); + + let mut stop_signal = self.stop_signal.clone(); + while let Some(exit_now) = stop_signal.recv().await { + if exit_now { + let mut workers = self.workers.lock().await; + let workers_vec = workers.drain(..).collect::<Vec<_>>(); + join_all(workers_vec).await; + return; + } + } + } + + pub fn spawn<T>(&self, job: T) + where + T: Future<Output = JobOutput> + Send + 'static, + { + let boxed: Job = Box::pin(job); + let _: Result<_, _> = self.queue_in.clone().send((boxed, false)); + self.job_notify.notify(); + } + + pub fn spawn_cancellable<T>(&self, job: T) + where + T: Future<Output = JobOutput> + Send + 'static, + { + let boxed: Job = Box::pin(job); + let _: Result<_, _> = self.queue_in.clone().send((boxed, true)); + self.job_notify.notify(); + } + + pub async fn spawn_worker<F, T>(&self, name: String, worker: F) + where + F: FnOnce(watch::Receiver<bool>) -> T + Send + 'static, + T: Future<Output = JobOutput> + Send + 'static, + { + let mut workers = self.workers.lock().await; + let stop_signal = self.stop_signal.clone(); + workers.push(tokio::spawn(async move { + if let Err(e) = worker(stop_signal).await { + error!("Worker stopped with error: {}, error: {}", name, e); + } else { + info!("Worker exited successfully: {}", name); + } + })); + } + + async fn runner(self: Arc<Self>, i: usize) { + let mut stop_signal = self.stop_signal.clone(); + loop { + let must_exit: bool = *stop_signal.borrow(); + if let Some(job) = self.dequeue_job(must_exit).await { + if let Err(e) = job.await { + error!("Job failed: {}", e) + } + } else { + if must_exit { + info!("Background runner {} exiting", i); + return; + } + select! { + _ = self.job_notify.notified().fuse() => (), + _ = stop_signal.recv().fuse() => (), + } + } + } + } + + async fn dequeue_job(&self, must_exit: bool) -> Option<Job> { + let mut queue = self.queue_out.lock().await; + while let Ok((job, cancellable)) = queue.try_recv() { + if cancellable && must_exit { + continue; + } else { + return Some(job); + } + } + None + } +} diff --git a/src/util/config.rs b/src/util/config.rs new file mode 100644 index 00000000..cb871562 --- /dev/null +++ b/src/util/config.rs @@ -0,0 +1,66 @@ +use std::io::Read; +use std::net::SocketAddr; +use std::path::PathBuf; + +use serde::Deserialize; + +use crate::error::Error; + +#[derive(Deserialize, Debug, Clone)] +pub struct Config { + pub metadata_dir: PathBuf, + pub data_dir: PathBuf, + + pub api_bind_addr: SocketAddr, + pub rpc_bind_addr: SocketAddr, + + pub bootstrap_peers: Vec<SocketAddr>, + + #[serde(default = "default_max_concurrent_rpc_requests")] + pub max_concurrent_rpc_requests: usize, + + #[serde(default = "default_block_size")] + pub block_size: usize, + + #[serde(default = "default_replication_factor")] + pub meta_replication_factor: usize, + + #[serde(default = "default_epidemic_factor")] + pub meta_epidemic_factor: usize, + + #[serde(default = "default_replication_factor")] + pub data_replication_factor: usize, + + pub rpc_tls: Option<TlsConfig>, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct TlsConfig { + pub ca_cert: String, + pub node_cert: String, + pub node_key: String, +} + +fn default_max_concurrent_rpc_requests() -> usize { + 12 +} +fn default_block_size() -> usize { + 1048576 +} +fn default_replication_factor() -> usize { + 3 +} +fn default_epidemic_factor() -> usize { + 3 +} + +pub fn read_config(config_file: PathBuf) -> Result<Config, Error> { + let mut file = std::fs::OpenOptions::new() + .read(true) + .open(config_file.as_path())?; + + let mut config = String::new(); + file.read_to_string(&mut config)?; + + Ok(toml::from_str(&config)?) +} diff --git a/src/util/data.rs b/src/util/data.rs new file mode 100644 index 00000000..8f976f71 --- /dev/null +++ b/src/util/data.rs @@ -0,0 +1,124 @@ +use rand::Rng; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use sha2::{Digest, Sha256}; +use std::fmt; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[derive(Default, PartialOrd, Ord, Clone, Hash, PartialEq, Copy)] +pub struct FixedBytes32([u8; 32]); + +impl From<[u8; 32]> for FixedBytes32 { + fn from(x: [u8; 32]) -> FixedBytes32 { + FixedBytes32(x) + } +} + +impl std::convert::AsRef<[u8]> for FixedBytes32 { + fn as_ref(&self) -> &[u8] { + &self.0[..] + } +} + +impl Eq for FixedBytes32 {} + +impl fmt::Debug for FixedBytes32 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}…", hex::encode(&self.0[..8])) + } +} + +struct FixedBytes32Visitor; +impl<'de> Visitor<'de> for FixedBytes32Visitor { + type Value = FixedBytes32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a byte slice of size 32") + } + + fn visit_bytes<E: de::Error>(self, value: &[u8]) -> Result<Self::Value, E> { + if value.len() == 32 { + let mut res = [0u8; 32]; + res.copy_from_slice(value); + Ok(res.into()) + } else { + Err(E::custom(format!( + "Invalid byte string length {}, expected 32", + value.len() + ))) + } + } +} + +impl<'de> Deserialize<'de> for FixedBytes32 { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<FixedBytes32, D::Error> { + deserializer.deserialize_bytes(FixedBytes32Visitor) + } +} + +impl Serialize for FixedBytes32 { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_bytes(&self.0[..]) + } +} + +impl FixedBytes32 { + pub fn as_slice(&self) -> &[u8] { + &self.0[..] + } + pub fn as_slice_mut(&mut self) -> &mut [u8] { + &mut self.0[..] + } + pub fn to_vec(&self) -> Vec<u8> { + self.0.to_vec() + } +} + +pub type UUID = FixedBytes32; +pub type Hash = FixedBytes32; + +pub fn hash(data: &[u8]) -> Hash { + let mut hasher = Sha256::new(); + hasher.input(data); + let mut hash = [0u8; 32]; + hash.copy_from_slice(&hasher.result()[..]); + hash.into() +} + +pub fn gen_uuid() -> UUID { + rand::thread_rng().gen::<[u8; 32]>().into() +} + +pub fn now_msec() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Fix your clock :o") + .as_millis() as u64 +} + +// RMP serialization with names of fields and variants + +pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> +where + T: Serialize + ?Sized, +{ + let mut wr = Vec::with_capacity(128); + let mut se = rmp_serde::Serializer::new(&mut wr) + .with_struct_map() + .with_string_variants(); + val.serialize(&mut se)?; + Ok(wr) +} + +pub fn debug_serialize<T: Serialize>(x: T) -> String { + match serde_json::to_string(&x) { + Ok(ss) => { + if ss.len() > 100 { + ss[..100].to_string() + } else { + ss + } + } + Err(e) => format!("<JSON serialization error: {}>", e), + } +} diff --git a/src/util/error.rs b/src/util/error.rs new file mode 100644 index 00000000..f73d6915 --- /dev/null +++ b/src/util/error.rs @@ -0,0 +1,112 @@ +use err_derive::Error; +use hyper::StatusCode; +use std::io; + +use crate::data::*; + +#[derive(Debug, Error)] +pub enum RPCError { + #[error(display = "Node is down: {:?}.", _0)] + NodeDown(UUID), + #[error(display = "Timeout: {}", _0)] + Timeout(#[error(source)] tokio::time::Elapsed), + #[error(display = "HTTP error: {}", _0)] + HTTP(#[error(source)] http::Error), + #[error(display = "Hyper error: {}", _0)] + Hyper(#[error(source)] hyper::Error), + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), + #[error(display = "Too many errors: {:?}", _0)] + TooManyErrors(Vec<String>), +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(display = "IO error: {}", _0)] + Io(#[error(source)] io::Error), + + #[error(display = "Hyper error: {}", _0)] + Hyper(#[error(source)] hyper::Error), + + #[error(display = "HTTP error: {}", _0)] + HTTP(#[error(source)] http::Error), + + #[error(display = "Invalid HTTP header value: {}", _0)] + HTTPHeader(#[error(source)] http::header::ToStrError), + + #[error(display = "TLS error: {}", _0)] + TLS(#[error(source)] rustls::TLSError), + + #[error(display = "PKI error: {}", _0)] + PKI(#[error(source)] webpki::Error), + + #[error(display = "Sled error: {}", _0)] + Sled(#[error(source)] sled::Error), + + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), + #[error(display = "JSON error: {}", _0)] + JSON(#[error(source)] serde_json::error::Error), + #[error(display = "TOML decode error: {}", _0)] + TomlDecode(#[error(source)] toml::de::Error), + + #[error(display = "Timeout: {}", _0)] + RPCTimeout(#[error(source)] tokio::time::Elapsed), + + #[error(display = "Tokio join error: {}", _0)] + TokioJoin(#[error(source)] tokio::task::JoinError), + + #[error(display = "RPC call error: {}", _0)] + RPC(#[error(source)] RPCError), + + #[error(display = "Remote error: {} (status code {})", _0, _1)] + RemoteError(String, StatusCode), + + #[error(display = "Bad request: {}", _0)] + BadRequest(String), + + #[error(display = "Not found")] + NotFound, + + #[error(display = "Corrupt data: does not match hash {:?}", _0)] + CorruptData(Hash), + + #[error(display = "{}", _0)] + Message(String), +} + +impl Error { + pub fn http_status_code(&self) -> StatusCode { + match self { + Error::BadRequest(_) => StatusCode::BAD_REQUEST, + Error::NotFound => StatusCode::NOT_FOUND, + Error::RPC(_) => StatusCode::SERVICE_UNAVAILABLE, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl From<sled::TransactionError<Error>> for Error { + fn from(e: sled::TransactionError<Error>) -> Error { + match e { + sled::TransactionError::Abort(x) => x, + sled::TransactionError::Storage(x) => Error::Sled(x), + } + } +} + +impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { + fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error { + Error::Message(format!("Watch send error")) + } +} + +impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error { + fn from(_e: tokio::sync::mpsc::error::SendError<T>) -> Error { + Error::Message(format!("MPSC send error")) + } +} diff --git a/src/util/lib.rs b/src/util/lib.rs new file mode 100644 index 00000000..0bf09bf6 --- /dev/null +++ b/src/util/lib.rs @@ -0,0 +1,7 @@ +#[macro_use] +extern crate log; + +pub mod background; +pub mod config; +pub mod data; +pub mod error; |