aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-04-07 18:10:20 +0200
committerAlex Auvolat <alex@adnab.me>2020-04-07 18:10:20 +0200
commit90cdffb425c6222f4234db54a16c079d8c058724 (patch)
tree2a836af28b5bce3969560fa4f45973b8392629c2 /src
parent82b7fcd280d416aabc4f94a99a17c8d9e99888be (diff)
downloadgarage-90cdffb425c6222f4234db54a16c079d8c058724.tar.gz
garage-90cdffb425c6222f4234db54a16c079d8c058724.zip
custom data type for hashes and identifiers
Diffstat (limited to 'src')
-rw-r--r--src/data.rs71
-rw-r--r--src/error.rs18
-rw-r--r--src/main.rs8
-rw-r--r--src/membership.rs33
-rw-r--r--src/proto.rs2
-rw-r--r--src/server.rs8
6 files changed, 103 insertions, 37 deletions
diff --git a/src/data.rs b/src/data.rs
index c649b289..f54c4cc1 100644
--- a/src/data.rs
+++ b/src/data.rs
@@ -1,8 +1,73 @@
+use std::fmt;
use std::collections::HashMap;
-use serde::{Serialize, Deserialize};
+use serde::{Serializer, Deserializer, Serialize, Deserialize};
+use serde::de::{self, Visitor};
-pub type UUID = [u8; 32];
-pub type Hash = [u8; 32];
+#[derive(Default, PartialOrd, Ord, Clone, Hash, PartialEq)]
+pub struct FixedBytes32([u8; 32]);
+
+impl From<[u8; 32]> for FixedBytes32 {
+ fn from(x: [u8; 32]) -> FixedBytes32 {
+ FixedBytes32(x)
+ }
+}
+
+impl std::convert::AsRef<[u8]> for FixedBytes32 {
+ fn as_ref(&self) -> &[u8] {
+ &self.0[..]
+ }
+}
+
+impl Eq for FixedBytes32 {}
+
+impl fmt::Debug for FixedBytes32 {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", hex::encode(self.0))
+ }
+}
+
+struct FixedBytes32Visitor;
+impl<'de> Visitor<'de> for FixedBytes32Visitor {
+ type Value = FixedBytes32;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a byte slice of size 32")
+ }
+
+ fn visit_bytes<E: de::Error>(self, value: &[u8]) -> Result<Self::Value, E> {
+ if value.len() == 32 {
+ let mut res = [0u8; 32];
+ res.copy_from_slice(value);
+ Ok(res.into())
+ } else {
+ Err(E::custom(format!("Invalid byte string length {}, expected 32", value.len())))
+ }
+ }
+}
+
+impl<'de> Deserialize<'de> for FixedBytes32 {
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<FixedBytes32, D::Error> {
+ deserializer.deserialize_bytes(FixedBytes32Visitor)
+ }
+}
+
+impl Serialize for FixedBytes32 {
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+ serializer.serialize_bytes(&self.0[..])
+ }
+}
+
+impl FixedBytes32 {
+ pub fn as_slice(&self) -> &[u8] {
+ &self.0[..]
+ }
+ pub fn as_slice_mut(&mut self) -> &mut [u8] {
+ &mut self.0[..]
+ }
+}
+
+pub type UUID = FixedBytes32;
+pub type Hash = FixedBytes32;
// Network management
diff --git a/src/error.rs b/src/error.rs
index 1e611adb..fd717638 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -3,29 +3,29 @@ use std::io;
#[derive(Debug, Error)]
pub enum Error {
- #[error(display = "IO error")]
+ #[error(display = "IO error: {}", _0)]
Io(#[error(source)] io::Error),
- #[error(display = "Hyper error")]
+ #[error(display = "Hyper error: {}", _0)]
Hyper(#[error(source)] hyper::Error),
- #[error(display = "HTTP error")]
+ #[error(display = "HTTP error: {}", _0)]
HTTP(#[error(source)] http::Error),
- #[error(display = "Messagepack encode error")]
+ #[error(display = "Messagepack encode error: {}", _0)]
RMPEncode(#[error(source)] rmp_serde::encode::Error),
- #[error(display = "Messagepack decode error")]
+ #[error(display = "Messagepack decode error: {}", _0)]
RMPDecode(#[error(source)] rmp_serde::decode::Error),
- #[error(display = "TOML decode error")]
+ #[error(display = "TOML decode error: {}", _0)]
TomlDecode(#[error(source)] toml::de::Error),
- #[error(display = "Timeout")]
+ #[error(display = "Timeout: {}", _0)]
RPCTimeout(#[error(source)] tokio::time::Elapsed),
- #[error(display = "RPC error")]
+ #[error(display = "RPC error: {}", _0)]
RPCError(String),
- #[error(display = "")]
+ #[error(display = "{}", _0)]
Message(String),
}
diff --git a/src/main.rs b/src/main.rs
index 2cb4b720..1e4107c2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -103,7 +103,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro
println!("Healthy nodes:");
for adv in status.iter() {
if let Some(cfg) = config.members.get(&adv.id) {
- println!("{}\t{}\t{}\t{}", hex::encode(adv.id), cfg.datacenter, cfg.n_tokens, adv.addr);
+ println!("{}\t{}\t{}\t{}", hex::encode(&adv.id), cfg.datacenter, cfg.n_tokens, adv.addr);
}
}
@@ -112,7 +112,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro
println!("\nFailed nodes:");
for (id, cfg) in config.members.iter() {
if !status.iter().any(|x| x.id == *id) {
- println!("{}\t{}\t{}", hex::encode(id), cfg.datacenter, cfg.n_tokens);
+ println!("{}\t{}\t{}", hex::encode(&id), cfg.datacenter, cfg.n_tokens);
}
}
}
@@ -121,7 +121,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro
println!("\nUnconfigured nodes:");
for adv in status.iter() {
if !config.members.contains_key(&adv.id) {
- println!("{}\t{}", hex::encode(adv.id), adv.addr);
+ println!("{}\t{}", hex::encode(&adv.id), adv.addr);
}
}
}
@@ -139,7 +139,7 @@ async fn cmd_configure(rpc_cli: RpcClient, rpc_host: SocketAddr, args: Configure
let mut candidates = vec![];
for adv in status.iter() {
- if hex::encode(adv.id).starts_with(&args.node_id) {
+ if hex::encode(&adv.id).starts_with(&args.node_id) {
candidates.push(adv.id.clone());
}
}
diff --git a/src/membership.rs b/src/membership.rs
index 1ce567a7..b7b99bb1 100644
--- a/src/membership.rs
+++ b/src/membership.rs
@@ -61,7 +61,7 @@ impl Members {
});
match old_status {
None => {
- eprintln!("Newly pingable node: {}", hex::encode(info.id));
+ eprintln!("Newly pingable node: {}", hex::encode(&info.id));
true
}
Some(x) => x.addr != addr,
@@ -70,16 +70,16 @@ impl Members {
fn recalculate_status_hash(&mut self) {
let mut nodes = self.status.iter().collect::<Vec<_>>();
- nodes.sort_by_key(|(id, _status)| *id);
+ nodes.sort_unstable_by_key(|(id, _status)| *id);
let mut hasher = Sha256::new();
eprintln!("Current set of pingable nodes: --");
for (id, status) in nodes {
- eprintln!("{} {}", hex::encode(id), status.addr);
- hasher.input(format!("{} {}\n", hex::encode(id), status.addr));
+ eprintln!("{} {}", hex::encode(&id), status.addr);
+ hasher.input(format!("{} {}\n", hex::encode(&id), status.addr));
}
eprintln!("END --");
- self.status_hash.copy_from_slice(&hasher.result()[..]);
+ self.status_hash.as_slice_mut().copy_from_slice(&hasher.result()[..]);
}
fn rebuild_ring(&mut self) {
@@ -97,19 +97,19 @@ impl Members {
for i in 0..config.n_tokens {
let mut location_hasher = Sha256::new();
- location_hasher.input(format!("{} {}", hex::encode(id), i));
+ location_hasher.input(format!("{} {}", hex::encode(&id), i));
let mut location = [0u8; 32];
location.copy_from_slice(&location_hasher.result()[..]);
new_ring.push(RingEntry{
- location,
+ location: location.into(),
node: id.clone(),
datacenter,
})
}
}
- new_ring.sort_by_key(|x| x.location);
+ new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location));
self.ring = new_ring;
self.n_datacenters = datacenters.len();
}
@@ -119,7 +119,7 @@ impl Members {
return self.config.members.keys().cloned().collect::<Vec<_>>();
}
- let start = match self.ring.binary_search_by_key(from, |x| x.location) {
+ let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) {
Ok(i) => i,
Err(i) => if i == 0 {
self.ring.len() - 1
@@ -178,7 +178,7 @@ impl System {
};
let mut members = Members{
status: HashMap::new(),
- status_hash: [0u8; 32],
+ status_hash: Hash::default(),
config: net_config,
ring: Vec::new(),
n_datacenters: 0,
@@ -193,7 +193,7 @@ impl System {
}
}
- pub async fn save_network_config(&self) {
+ async fn save_network_config(self: Arc<Self>) {
let mut path = self.config.metadata_dir.clone();
path.push("network_config");
@@ -211,7 +211,7 @@ impl System {
pub async fn make_ping(&self) -> Message {
let members = self.members.read().await;
Message::Ping(PingMessage{
- id: self.id,
+ id: self.id.clone(),
rpc_port: self.config.rpc_port,
status_hash: members.status_hash.clone(),
config_version: members.config.version,
@@ -271,8 +271,8 @@ impl System {
} else if let Some(id) = id_option {
let remaining_attempts = members.status.get(id).map(|x| x.remaining_ping_attempts).unwrap_or(0);
if remaining_attempts == 0 {
- eprintln!("Removing node {} after too many failed pings", hex::encode(id));
- members.status.remove(id);
+ eprintln!("Removing node {} after too many failed pings", hex::encode(&id));
+ members.status.remove(&id);
has_changes = true;
} else {
if let Some(st) = members.status.get_mut(id) {
@@ -376,11 +376,12 @@ impl System {
{
let mut members = self.members.write().await;
if adv.version > members.config.version {
- tokio::spawn(self.clone().broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT));
members.config = adv.clone();
- self.save_network_config().await;
members.rebuild_ring();
+
+ tokio::spawn(self.clone().broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT));
+ tokio::spawn(self.clone().save_network_config());
}
Ok(Message::Ok)
diff --git a/src/proto.rs b/src/proto.rs
index 18bc339e..d1d4fb59 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -4,7 +4,7 @@ use serde::{Serialize, Deserialize};
use crate::data::*;
-pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2);
+pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Debug, Serialize, Deserialize)]
pub enum Message {
diff --git a/src/server.rs b/src/server.rs
index 1450911b..5cac1c70 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -47,13 +47,13 @@ fn gen_node_id(metadata_dir: &PathBuf) -> Result<UUID, Error> {
let mut id = [0u8; 32];
id.copy_from_slice(&d[..]);
- Ok(id)
+ Ok(id.into())
} else {
- let id = rand::thread_rng().gen::<UUID>();
+ let id = rand::thread_rng().gen::<[u8; 32]>();
let mut f = std::fs::File::create(id_file.as_path())?;
f.write_all(&id[..])?;
- Ok(id)
+ Ok(id.into())
}
}
@@ -78,7 +78,7 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> {
let id = gen_node_id(&config.metadata_dir)
.expect("Unable to read or generate node ID");
- println!("Node ID: {}", hex::encode(id));
+ println!("Node ID: {}", hex::encode(&id));
let sys = Arc::new(System::new(config, id));