From c9c6b0dbd41e20d19b91c6615c46da6f45925bca Mon Sep 17 00:00:00 2001
From: Alex Auvolat <alex@adnab.me>
Date: Thu, 23 Apr 2020 17:05:46 +0000
Subject: Reorganize code

---
 src/admin_rpc.rs             |  15 +-
 src/api/api_server.rs        | 360 ++++++++++++++++++++
 src/api/http_util.rs         |  82 +++++
 src/api/mod.rs               |   2 +
 src/api_server.rs            | 358 --------------------
 src/block.rs                 | 500 ---------------------------
 src/block_ref_table.rs       |  67 ----
 src/bucket_table.rs          |  82 -----
 src/config.rs                |  66 ++++
 src/error.rs                 |   2 +-
 src/http_util.rs             |  82 -----
 src/main.rs                  |  28 +-
 src/membership.rs            | 674 ------------------------------------
 src/object_table.rs          | 133 --------
 src/rpc/membership.rs        | 692 +++++++++++++++++++++++++++++++++++++
 src/rpc/mod.rs               |   4 +
 src/rpc/rpc_client.rs        | 360 ++++++++++++++++++++
 src/rpc/rpc_server.rs        | 219 ++++++++++++
 src/rpc/tls_util.rs          | 139 ++++++++
 src/rpc_client.rs            | 358 --------------------
 src/rpc_server.rs            | 218 ------------
 src/server.rs                | 168 +++------
 src/store/block.rs           | 506 +++++++++++++++++++++++++++
 src/store/block_ref_table.rs |  68 ++++
 src/store/bucket_table.rs    |  82 +++++
 src/store/mod.rs             |   5 +
 src/store/object_table.rs    | 134 ++++++++
 src/store/version_table.rs   |  95 ++++++
 src/table.rs                 | 522 ----------------------------
 src/table/mod.rs             |   6 +
 src/table/table.rs           | 524 ++++++++++++++++++++++++++++
 src/table/table_fullcopy.rs  | 100 ++++++
 src/table/table_sharded.rs   |  55 +++
 src/table/table_sync.rs      | 791 +++++++++++++++++++++++++++++++++++++++++++
 src/table_fullcopy.rs        | 100 ------
 src/table_sharded.rs         |  55 ---
 src/table_sync.rs            | 791 -------------------------------------------
 src/tls_util.rs              | 139 --------
 src/version_table.rs         |  94 -----
 39 files changed, 4361 insertions(+), 4315 deletions(-)
 create mode 100644 src/api/api_server.rs
 create mode 100644 src/api/http_util.rs
 create mode 100644 src/api/mod.rs
 delete mode 100644 src/api_server.rs
 delete mode 100644 src/block.rs
 delete mode 100644 src/block_ref_table.rs
 delete mode 100644 src/bucket_table.rs
 create mode 100644 src/config.rs
 delete mode 100644 src/http_util.rs
 delete mode 100644 src/membership.rs
 delete mode 100644 src/object_table.rs
 create mode 100644 src/rpc/membership.rs
 create mode 100644 src/rpc/mod.rs
 create mode 100644 src/rpc/rpc_client.rs
 create mode 100644 src/rpc/rpc_server.rs
 create mode 100644 src/rpc/tls_util.rs
 delete mode 100644 src/rpc_client.rs
 delete mode 100644 src/rpc_server.rs
 create mode 100644 src/store/block.rs
 create mode 100644 src/store/block_ref_table.rs
 create mode 100644 src/store/bucket_table.rs
 create mode 100644 src/store/mod.rs
 create mode 100644 src/store/object_table.rs
 create mode 100644 src/store/version_table.rs
 delete mode 100644 src/table.rs
 create mode 100644 src/table/mod.rs
 create mode 100644 src/table/table.rs
 create mode 100644 src/table/table_fullcopy.rs
 create mode 100644 src/table/table_sharded.rs
 create mode 100644 src/table/table_sync.rs
 delete mode 100644 src/table_fullcopy.rs
 delete mode 100644 src/table_sharded.rs
 delete mode 100644 src/table_sync.rs
 delete mode 100644 src/tls_util.rs
 delete mode 100644 src/version_table.rs

diff --git a/src/admin_rpc.rs b/src/admin_rpc.rs
index 458df360..fe59f92e 100644
--- a/src/admin_rpc.rs
+++ b/src/admin_rpc.rs
@@ -5,15 +5,18 @@ use tokio::sync::watch;
 
 use crate::data::*;
 use crate::error::Error;
-use crate::rpc_client::*;
-use crate::rpc_server::*;
 use crate::server::Garage;
+
 use crate::table::*;
-use crate::*;
 
-use crate::block_ref_table::*;
-use crate::bucket_table::*;
-use crate::version_table::*;
+use crate::rpc::rpc_client::*;
+use crate::rpc::rpc_server::*;
+
+use crate::store::block_ref_table::*;
+use crate::store::bucket_table::*;
+use crate::store::version_table::*;
+
+use crate::*;
 
 pub const ADMIN_RPC_TIMEOUT: Duration = Duration::from_secs(30);
 pub const ADMIN_RPC_PATH: &str = "_admin";
diff --git a/src/api/api_server.rs b/src/api/api_server.rs
new file mode 100644
index 00000000..a80b2ea2
--- /dev/null
+++ b/src/api/api_server.rs
@@ -0,0 +1,360 @@
+use std::collections::VecDeque;
+use std::net::SocketAddr;
+use std::sync::Arc;
+
+use futures::future::Future;
+use futures::stream::*;
+use hyper::body::{Bytes, HttpBody};
+use hyper::server::conn::AddrStream;
+use hyper::service::{make_service_fn, service_fn};
+use hyper::{Body, Method, Request, Response, Server, StatusCode};
+
+use crate::data::*;
+use crate::error::Error;
+use crate::server::Garage;
+
+use crate::table::EmptyKey;
+
+use crate::store::block::INLINE_THRESHOLD;
+use crate::store::block_ref_table::*;
+use crate::store::object_table::*;
+use crate::store::version_table::*;
+
+use crate::api::http_util::*;
+
+type BodyType = Box<dyn HttpBody<Data = Bytes, Error = Error> + Send + Unpin>;
+
+pub async fn run_api_server(
+	garage: Arc<Garage>,
+	shutdown_signal: impl Future<Output = ()>,
+) -> Result<(), Error> {
+	let addr = &garage.config.api_bind_addr;
+
+	let service = make_service_fn(|conn: &AddrStream| {
+		let garage = garage.clone();
+		let client_addr = conn.remote_addr();
+		async move {
+			Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+				let garage = garage.clone();
+				handler(garage, req, client_addr)
+			}))
+		}
+	});
+
+	let server = Server::bind(&addr).serve(service);
+
+	let graceful = server.with_graceful_shutdown(shutdown_signal);
+	info!("API server listening on http://{}", addr);
+
+	graceful.await?;
+	Ok(())
+}
+
+async fn handler(
+	garage: Arc<Garage>,
+	req: Request<Body>,
+	addr: SocketAddr,
+) -> Result<Response<BodyType>, Error> {
+	match handler_inner(garage, req, addr).await {
+		Ok(x) => Ok(x),
+		Err(e) => {
+			let body: BodyType = Box::new(BytesBody::from(format!("{}\n", e)));
+			let mut http_error = Response::new(body);
+			*http_error.status_mut() = e.http_status_code();
+			Ok(http_error)
+		}
+	}
+}
+
+async fn handler_inner(
+	garage: Arc<Garage>,
+	req: Request<Body>,
+	addr: SocketAddr,
+) -> Result<Response<BodyType>, Error> {
+	info!("{} {} {}", addr, req.method(), req.uri());
+
+	let bucket = req
+		.headers()
+		.get(hyper::header::HOST)
+		.map(|x| x.to_str().map_err(Error::from))
+		.unwrap_or(Err(Error::BadRequest(format!("Host: header missing"))))?
+		.to_lowercase();
+	let key = req.uri().path().to_string();
+
+	match req.method() {
+		&Method::GET => Ok(handle_get(garage, &bucket, &key).await?),
+		&Method::PUT => {
+			let mime_type = req
+				.headers()
+				.get(hyper::header::CONTENT_TYPE)
+				.map(|x| x.to_str())
+				.unwrap_or(Ok("blob"))?
+				.to_string();
+			let version_uuid =
+				handle_put(garage, &mime_type, &bucket, &key, req.into_body()).await?;
+			let response = format!("{}\n", hex::encode(version_uuid,));
+			Ok(Response::new(Box::new(BytesBody::from(response))))
+		}
+		&Method::DELETE => {
+			let version_uuid = handle_delete(garage, &bucket, &key).await?;
+			let response = format!("{}\n", hex::encode(version_uuid,));
+			Ok(Response::new(Box::new(BytesBody::from(response))))
+		}
+		_ => Err(Error::BadRequest(format!("Invalid method"))),
+	}
+}
+
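+// PUT flow: the request body is consumed through BodyChunker, which cuts
+// it into blocks of at most `block_size` bytes. A body smaller than
+// INLINE_THRESHOLD is stored inline in the object table; otherwise the
+// object references the hash of its first block, and a Version entry
+// records the (offset, hash) pair of every block while the blocks
+// themselves are written out through the block manager.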
+async fn handle_put(
+	garage: Arc<Garage>,
+	mime_type: &str,
+	bucket: &str,
+	key: &str,
+	body: Body,
+) -> Result<UUID, Error> {
+	let version_uuid = gen_uuid();
+
+	let mut chunker = BodyChunker::new(body, garage.config.block_size);
+	let first_block = match chunker.next().await? {
+		Some(x) => x,
+		None => return Err(Error::BadRequest(format!("Empty body"))),
+	};
+
+	let mut object = Object {
+		bucket: bucket.into(),
+		key: key.into(),
+		versions: Vec::new(),
+	};
+	object.versions.push(Box::new(ObjectVersion {
+		uuid: version_uuid,
+		timestamp: now_msec(),
+		mime_type: mime_type.to_string(),
+		size: first_block.len() as u64,
+		is_complete: false,
+		data: ObjectVersionData::DeleteMarker,
+	}));
+
+	if first_block.len() < INLINE_THRESHOLD {
+		object.versions[0].data = ObjectVersionData::Inline(first_block);
+		object.versions[0].is_complete = true;
+		garage.object_table.insert(&object).await?;
+		return Ok(version_uuid);
+	}
+
+	let version = Version {
+		uuid: version_uuid,
+		deleted: false,
+		blocks: Vec::new(),
+		bucket: bucket.into(),
+		key: key.into(),
+	};
+
+	let first_block_hash = hash(&first_block[..]);
+	object.versions[0].data = ObjectVersionData::FirstBlock(first_block_hash);
+	garage.object_table.insert(&object).await?;
+
+	let mut next_offset = first_block.len();
+	let mut put_curr_version_block = put_block_meta(garage.clone(), &version, 0, first_block_hash);
+	let mut put_curr_block = garage
+		.block_manager
+		.rpc_put_block(first_block_hash, first_block);
+
+	loop {
+		let (_, _, next_block) =
+			futures::try_join!(put_curr_block, put_curr_version_block, chunker.next())?;
+		if let Some(block) = next_block {
+			let block_hash = hash(&block[..]);
+			let block_len = block.len();
+			put_curr_version_block =
+				put_block_meta(garage.clone(), &version, next_offset as u64, block_hash);
+			put_curr_block = garage.block_manager.rpc_put_block(block_hash, block);
+			next_offset += block_len;
+		} else {
+			break;
+		}
+	}
+
+	// TODO: if at any step we have an error, we should undo everything we did
+
+	object.versions[0].is_complete = true;
+	object.versions[0].size = next_offset as u64;
+	garage.object_table.insert(&object).await?;
+	Ok(version_uuid)
+}
+
+async fn put_block_meta(
+	garage: Arc<Garage>,
+	version: &Version,
+	offset: u64,
+	hash: Hash,
+) -> Result<(), Error> {
+	let mut version = version.clone();
+	version.blocks.push(VersionBlock { offset, hash: hash });
+
+	let block_ref = BlockRef {
+		block: hash,
+		version: version.uuid,
+		deleted: false,
+	};
+
+	futures::try_join!(
+		garage.version_table.insert(&version),
+		garage.block_ref_table.insert(&block_ref),
+	)?;
+	Ok(())
+}
+
+struct BodyChunker {
+	body: Body,
+	read_all: bool,
+	block_size: usize,
+	buf: VecDeque<u8>,
+}
+
+impl BodyChunker {
+	fn new(body: Body, block_size: usize) -> Self {
+		Self {
+			body,
+			read_all: false,
+			block_size,
+			buf: VecDeque::new(),
+		}
+	}
+	async fn next(&mut self) -> Result<Option<Vec<u8>>, Error> {
+		while !self.read_all && self.buf.len() < self.block_size {
+			if let Some(block) = self.body.next().await {
+				let bytes = block?;
+				trace!("Body next: {} bytes", bytes.len());
+				self.buf.extend(&bytes[..]);
+			} else {
+				self.read_all = true;
+			}
+		}
+		if self.buf.len() == 0 {
+			Ok(None)
+		} else if self.buf.len() <= self.block_size {
+			let block = self.buf.drain(..).collect::<Vec<u8>>();
+			Ok(Some(block))
+		} else {
+			let block = self.buf.drain(..self.block_size).collect::<Vec<u8>>();
+			Ok(Some(block))
+		}
+	}
+}
+
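+// Deletion writes a tombstone rather than erasing data: a new version
+// whose data is ObjectVersionData::DeleteMarker is inserted. If the
+// object does not exist or already carries only delete markers, an
+// all-zero UUID is returned and nothing is written.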
+async fn handle_delete(garage: Arc<Garage>, bucket: &str, key: &str) -> Result<UUID, Error> {
+	let exists = match garage
+		.object_table
+		.get(&bucket.to_string(), &key.to_string())
+		.await?
+	{
+		None => false,
+		Some(o) => {
+			let mut has_active_version = false;
+			for v in o.versions.iter() {
+				if v.data != ObjectVersionData::DeleteMarker {
+					has_active_version = true;
+					break;
+				}
+			}
+			has_active_version
+		}
+	};
+
+	if !exists {
+		// No need to delete
+		return Ok([0u8; 32].into());
+	}
+
+	let version_uuid = gen_uuid();
+
+	let mut object = Object {
+		bucket: bucket.into(),
+		key: key.into(),
+		versions: Vec::new(),
+	};
+	object.versions.push(Box::new(ObjectVersion {
+		uuid: version_uuid,
+		timestamp: now_msec(),
+		mime_type: "application/x-delete-marker".into(),
+		size: 0,
+		is_complete: true,
+		data: ObjectVersionData::DeleteMarker,
+	}));
+
+	garage.object_table.insert(&object).await?;
+	return Ok(version_uuid);
+}
+
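+// GET flow: the most recent complete version is selected. Inline data is
+// returned directly; for block-based objects, the first block and the
+// Version entry are fetched concurrently, and the remaining blocks are
+// streamed with a prefetch window of two (`buffered(2)`).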
+async fn handle_get(
+	garage: Arc<Garage>,
+	bucket: &str,
+	key: &str,
+) -> Result<Response<BodyType>, Error> {
+	let mut object = match garage
+		.object_table
+		.get(&bucket.to_string(), &key.to_string())
+		.await?
+	{
+		None => return Err(Error::NotFound),
+		Some(o) => o,
+	};
+
+	let last_v = match object
+		.versions
+		.drain(..)
+		.rev()
+		.filter(|v| v.is_complete)
+		.next()
+	{
+		Some(v) => v,
+		None => return Err(Error::NotFound),
+	};
+
+	let resp_builder = Response::builder()
+		.header("Content-Type", last_v.mime_type)
+		.status(StatusCode::OK);
+
+	match last_v.data {
+		ObjectVersionData::DeleteMarker => Err(Error::NotFound),
+		ObjectVersionData::Inline(bytes) => {
+			let body: BodyType = Box::new(BytesBody::from(bytes));
+			Ok(resp_builder.body(body)?)
+		}
+		ObjectVersionData::FirstBlock(first_block_hash) => {
+			let read_first_block = garage.block_manager.rpc_get_block(&first_block_hash);
+			let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptyKey);
+
+			let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?;
+			let version = match version {
+				Some(v) => v,
+				None => return Err(Error::NotFound),
+			};
+
+			let mut blocks = version
+				.blocks
+				.iter()
+				.map(|vb| (vb.hash, None))
+				.collect::<Vec<_>>();
+			blocks[0].1 = Some(first_block);
+
+			let body_stream = futures::stream::iter(blocks)
+				.map(move |(hash, data_opt)| {
+					let garage = garage.clone();
+					async move {
+						if let Some(data) = data_opt {
+							Ok(Bytes::from(data))
+						} else {
+							garage
+								.block_manager
+								.rpc_get_block(&hash)
+								.await
+								.map(Bytes::from)
+						}
+					}
+				})
+				.buffered(2);
+			let body: BodyType = Box::new(StreamBody::new(Box::pin(body_stream)));
+			Ok(resp_builder.body(body)?)
+		}
+	}
+}
diff --git a/src/api/http_util.rs b/src/api/http_util.rs
new file mode 100644
index 00000000..228448f0
--- /dev/null
+++ b/src/api/http_util.rs
@@ -0,0 +1,82 @@
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+use futures::ready;
+use futures::stream::*;
+use hyper::body::{Bytes, HttpBody};
+
+use crate::error::Error;
+
+type StreamType = Pin<Box<dyn Stream<Item = Result<Bytes, Error>> + Send>>;
+
+pub struct StreamBody {
+	stream: StreamType,
+}
+
+impl StreamBody {
+	pub fn new(stream: StreamType) -> Self {
+		Self { stream }
+	}
+}
+
+impl HttpBody for StreamBody {
+	type Data = Bytes;
+	type Error = Error;
+
+	fn poll_data(
+		mut self: Pin<&mut Self>,
+		cx: &mut Context,
+	) -> Poll<Option<Result<Self::Data, Self::Error>>> {
+		match ready!(self.stream.as_mut().poll_next(cx)) {
+			Some(res) => Poll::Ready(Some(res)),
+			None => Poll::Ready(None),
+		}
+	}
+
+	fn poll_trailers(
+		self: Pin<&mut Self>,
+		_cx: &mut Context,
+	) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
+		Poll::Ready(Ok(None))
+	}
+}
+
+pub struct BytesBody {
+	bytes: Option<Bytes>,
+}
+
+impl BytesBody {
+	pub fn new(bytes: Bytes) -> Self {
+		Self { bytes: Some(bytes) }
+	}
+}
+
+impl HttpBody for BytesBody {
+	type Data = Bytes;
+	type Error = Error;
+
+	fn poll_data(
+		mut self: Pin<&mut Self>,
+		_cx: &mut Context,
+	) -> Poll<Option<Result<Self::Data, Self::Error>>> {
+		Poll::Ready(self.bytes.take().map(Ok))
+	}
+
+	fn poll_trailers(
+		self: Pin<&mut Self>,
+		_cx: &mut Context,
+	) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
+		Poll::Ready(Ok(None))
+	}
+}
+
+impl From<String> for BytesBody {
+	fn from(x: String) -> BytesBody {
+		Self::new(Bytes::from(x))
+	}
+}
+impl From<Vec<u8>> for BytesBody {
+	fn from(x: Vec<u8>) -> BytesBody {
+		Self::new(Bytes::from(x))
+	}
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
new file mode 100644
index 00000000..8e62d1e7
--- /dev/null
+++ b/src/api/mod.rs
@@ -0,0 +1,2 @@
+pub mod api_server;
+pub mod http_util;
diff --git a/src/api_server.rs b/src/api_server.rs
deleted file mode 100644
index f4bb4177..00000000
--- a/src/api_server.rs
+++ /dev/null
@@ -1,358 +0,0 @@
-use std::collections::VecDeque;
-use std::net::SocketAddr;
-use std::sync::Arc;
-
-use futures::future::Future;
-use futures::stream::*;
-use hyper::body::{Bytes, HttpBody};
-use hyper::server::conn::AddrStream;
-use hyper::service::{make_service_fn, service_fn};
-use hyper::{Body, Method, Request, Response, Server, StatusCode};
-
-use crate::data::*;
-use crate::error::Error;
-use crate::http_util::*;
-use crate::table::EmptyKey;
-
-use crate::block::INLINE_THRESHOLD;
-use crate::block_ref_table::*;
-use crate::object_table::*;
-use crate::server::Garage;
-use crate::version_table::*;
-
-type BodyType = Box<dyn HttpBody<Data = Bytes, Error = Error> + Send + Unpin>;
-
-pub async fn run_api_server(
-	garage: Arc<Garage>,
-	shutdown_signal: impl Future<Output = ()>,
-) -> Result<(), Error> {
-	let addr = &garage.system.config.api_bind_addr;
-
-	let service = make_service_fn(|conn: &AddrStream| {
-		let garage = garage.clone();
-		let client_addr = conn.remote_addr();
-		async move {
-			Ok::<_, Error>(service_fn(move |req: Request<Body>| {
-				let garage = garage.clone();
-				handler(garage, req, client_addr)
-			}))
-		}
-	});
-
-	let server = Server::bind(&addr).serve(service);
-
-	let graceful = server.with_graceful_shutdown(shutdown_signal);
-	info!("API server listening on http://{}", addr);
-
-	graceful.await?;
-	Ok(())
-}
-
-async fn handler(
-	garage: Arc<Garage>,
-	req: Request<Body>,
-	addr: SocketAddr,
-) -> Result<Response<BodyType>, Error> {
-	match handler_inner(garage, req, addr).await {
-		Ok(x) => Ok(x),
-		Err(e) => {
-			let body: BodyType = Box::new(BytesBody::from(format!("{}\n", e)));
-			let mut http_error = Response::new(body);
-			*http_error.status_mut() = e.http_status_code();
-			Ok(http_error)
-		}
-	}
-}
-
-async fn handler_inner(
-	garage: Arc<Garage>,
-	req: Request<Body>,
-	addr: SocketAddr,
-) -> Result<Response<BodyType>, Error> {
-	info!("{} {} {}", addr, req.method(), req.uri());
-
-	let bucket = req
-		.headers()
-		.get(hyper::header::HOST)
-		.map(|x| x.to_str().map_err(Error::from))
-		.unwrap_or(Err(Error::BadRequest(format!("Host: header missing"))))?
-		.to_lowercase();
-	let key = req.uri().path().to_string();
-
-	match req.method() {
-		&Method::GET => Ok(handle_get(garage, &bucket, &key).await?),
-		&Method::PUT => {
-			let mime_type = req
-				.headers()
-				.get(hyper::header::CONTENT_TYPE)
-				.map(|x| x.to_str())
-				.unwrap_or(Ok("blob"))?
-				.to_string();
-			let version_uuid =
-				handle_put(garage, &mime_type, &bucket, &key, req.into_body()).await?;
-			let response = format!("{}\n", hex::encode(version_uuid,));
-			Ok(Response::new(Box::new(BytesBody::from(response))))
-		}
-		&Method::DELETE => {
-			let version_uuid = handle_delete(garage, &bucket, &key).await?;
-			let response = format!("{}\n", hex::encode(version_uuid,));
-			Ok(Response::new(Box::new(BytesBody::from(response))))
-		}
-		_ => Err(Error::BadRequest(format!("Invalid method"))),
-	}
-}
-
-async fn handle_put(
-	garage: Arc<Garage>,
-	mime_type: &str,
-	bucket: &str,
-	key: &str,
-	body: Body,
-) -> Result<UUID, Error> {
-	let version_uuid = gen_uuid();
-
-	let mut chunker = BodyChunker::new(body, garage.system.config.block_size);
-	let first_block = match chunker.next().await? {
-		Some(x) => x,
-		None => return Err(Error::BadRequest(format!("Empty body"))),
-	};
-
-	let mut object = Object {
-		bucket: bucket.into(),
-		key: key.into(),
-		versions: Vec::new(),
-	};
-	object.versions.push(Box::new(ObjectVersion {
-		uuid: version_uuid,
-		timestamp: now_msec(),
-		mime_type: mime_type.to_string(),
-		size: first_block.len() as u64,
-		is_complete: false,
-		data: ObjectVersionData::DeleteMarker,
-	}));
-
-	if first_block.len() < INLINE_THRESHOLD {
-		object.versions[0].data = ObjectVersionData::Inline(first_block);
-		object.versions[0].is_complete = true;
-		garage.object_table.insert(&object).await?;
-		return Ok(version_uuid);
-	}
-
-	let version = Version {
-		uuid: version_uuid,
-		deleted: false,
-		blocks: Vec::new(),
-		bucket: bucket.into(),
-		key: key.into(),
-	};
-
-	let first_block_hash = hash(&first_block[..]);
-	object.versions[0].data = ObjectVersionData::FirstBlock(first_block_hash);
-	garage.object_table.insert(&object).await?;
-
-	let mut next_offset = first_block.len();
-	let mut put_curr_version_block = put_block_meta(garage.clone(), &version, 0, first_block_hash);
-	let mut put_curr_block = garage
-		.block_manager
-		.rpc_put_block(first_block_hash, first_block);
-
-	loop {
-		let (_, _, next_block) =
-			futures::try_join!(put_curr_block, put_curr_version_block, chunker.next())?;
-		if let Some(block) = next_block {
-			let block_hash = hash(&block[..]);
-			let block_len = block.len();
-			put_curr_version_block =
-				put_block_meta(garage.clone(), &version, next_offset as u64, block_hash);
-			put_curr_block = garage.block_manager.rpc_put_block(block_hash, block);
-			next_offset += block_len;
-		} else {
-			break;
-		}
-	}
-
-	// TODO: if at any step we have an error, we should undo everything we did
-
-	object.versions[0].is_complete = true;
-	object.versions[0].size = next_offset as u64;
-	garage.object_table.insert(&object).await?;
-	Ok(version_uuid)
-}
-
-async fn put_block_meta(
-	garage: Arc<Garage>,
-	version: &Version,
-	offset: u64,
-	hash: Hash,
-) -> Result<(), Error> {
-	let mut version = version.clone();
-	version.blocks.push(VersionBlock { offset, hash: hash });
-
-	let block_ref = BlockRef {
-		block: hash,
-		version: version.uuid,
-		deleted: false,
-	};
-
-	futures::try_join!(
-		garage.version_table.insert(&version),
-		garage.block_ref_table.insert(&block_ref),
-	)?;
-	Ok(())
-}
-
-struct BodyChunker {
-	body: Body,
-	read_all: bool,
-	block_size: usize,
-	buf: VecDeque<u8>,
-}
-
-impl BodyChunker {
-	fn new(body: Body, block_size: usize) -> Self {
-		Self {
-			body,
-			read_all: false,
-			block_size,
-			buf: VecDeque::new(),
-		}
-	}
-	async fn next(&mut self) -> Result<Option<Vec<u8>>, Error> {
-		while !self.read_all && self.buf.len() < self.block_size {
-			if let Some(block) = self.body.next().await {
-				let bytes = block?;
-				trace!("Body next: {} bytes", bytes.len());
-				self.buf.extend(&bytes[..]);
-			} else {
-				self.read_all = true;
-			}
-		}
-		if self.buf.len() == 0 {
-			Ok(None)
-		} else if self.buf.len() <= self.block_size {
-			let block = self.buf.drain(..).collect::<Vec<u8>>();
-			Ok(Some(block))
-		} else {
-			let block = self.buf.drain(..self.block_size).collect::<Vec<u8>>();
-			Ok(Some(block))
-		}
-	}
-}
-
-async fn handle_delete(garage: Arc<Garage>, bucket: &str, key: &str) -> Result<UUID, Error> {
-	let exists = match garage
-		.object_table
-		.get(&bucket.to_string(), &key.to_string())
-		.await?
-	{
-		None => false,
-		Some(o) => {
-			let mut has_active_version = false;
-			for v in o.versions.iter() {
-				if v.data != ObjectVersionData::DeleteMarker {
-					has_active_version = true;
-					break;
-				}
-			}
-			has_active_version
-		}
-	};
-
-	if !exists {
-		// No need to delete
-		return Ok([0u8; 32].into());
-	}
-
-	let version_uuid = gen_uuid();
-
-	let mut object = Object {
-		bucket: bucket.into(),
-		key: key.into(),
-		versions: Vec::new(),
-	};
-	object.versions.push(Box::new(ObjectVersion {
-		uuid: version_uuid,
-		timestamp: now_msec(),
-		mime_type: "application/x-delete-marker".into(),
-		size: 0,
-		is_complete: true,
-		data: ObjectVersionData::DeleteMarker,
-	}));
-
-	garage.object_table.insert(&object).await?;
-	return Ok(version_uuid);
-}
-
-async fn handle_get(
-	garage: Arc<Garage>,
-	bucket: &str,
-	key: &str,
-) -> Result<Response<BodyType>, Error> {
-	let mut object = match garage
-		.object_table
-		.get(&bucket.to_string(), &key.to_string())
-		.await?
-	{
-		None => return Err(Error::NotFound),
-		Some(o) => o,
-	};
-
-	let last_v = match object
-		.versions
-		.drain(..)
-		.rev()
-		.filter(|v| v.is_complete)
-		.next()
-	{
-		Some(v) => v,
-		None => return Err(Error::NotFound),
-	};
-
-	let resp_builder = Response::builder()
-		.header("Content-Type", last_v.mime_type)
-		.status(StatusCode::OK);
-
-	match last_v.data {
-		ObjectVersionData::DeleteMarker => Err(Error::NotFound),
-		ObjectVersionData::Inline(bytes) => {
-			let body: BodyType = Box::new(BytesBody::from(bytes));
-			Ok(resp_builder.body(body)?)
-		}
-		ObjectVersionData::FirstBlock(first_block_hash) => {
-			let read_first_block = garage.block_manager.rpc_get_block(&first_block_hash);
-			let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptyKey);
-
-			let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?;
-			let version = match version {
-				Some(v) => v,
-				None => return Err(Error::NotFound),
-			};
-
-			let mut blocks = version
-				.blocks
-				.iter()
-				.map(|vb| (vb.hash, None))
-				.collect::<Vec<_>>();
-			blocks[0].1 = Some(first_block);
-
-			let body_stream = futures::stream::iter(blocks)
-				.map(move |(hash, data_opt)| {
-					let garage = garage.clone();
-					async move {
-						if let Some(data) = data_opt {
-							Ok(Bytes::from(data))
-						} else {
-							garage
-								.block_manager
-								.rpc_get_block(&hash)
-								.await
-								.map(Bytes::from)
-						}
-					}
-				})
-				.buffered(2);
-			let body: BodyType = Box::new(StreamBody::new(Box::pin(body_stream)));
-			Ok(resp_builder.body(body)?)
-		}
-	}
-}
diff --git a/src/block.rs b/src/block.rs
deleted file mode 100644
index 23222a7f..00000000
--- a/src/block.rs
+++ /dev/null
@@ -1,500 +0,0 @@
-use std::path::PathBuf;
-use std::sync::Arc;
-use std::time::Duration;
-
-use arc_swap::ArcSwapOption;
-use futures::future::*;
-use futures::select;
-use futures::stream::*;
-use serde::{Deserialize, Serialize};
-use tokio::fs;
-use tokio::prelude::*;
-use tokio::sync::{watch, Mutex, Notify};
-
-use crate::data;
-use crate::data::*;
-use crate::error::Error;
-use crate::membership::System;
-use crate::rpc_client::*;
-use crate::rpc_server::*;
-
-use crate::block_ref_table::*;
-use crate::server::Garage;
-
-pub const INLINE_THRESHOLD: usize = 3072;
-
-const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42);
-const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
-const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10);
-
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Message {
-	Ok,
-	GetBlock(Hash),
-	PutBlock(PutBlockMessage),
-	NeedBlockQuery(Hash),
-	NeedBlockReply(bool),
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct PutBlockMessage {
-	pub hash: Hash,
-
-	#[serde(with = "serde_bytes")]
-	pub data: Vec<u8>,
-}
-
-impl RpcMessage for Message {}
-
-pub struct BlockManager {
-	pub data_dir: PathBuf,
-	pub data_dir_lock: Mutex<()>,
-
-	pub rc: sled::Tree,
-
-	pub resync_queue: sled::Tree,
-	pub resync_notify: Notify,
-
-	pub system: Arc<System>,
-	rpc_client: Arc<RpcClient<Message>>,
-	pub garage: ArcSwapOption<Garage>,
-}
-
-impl BlockManager {
-	pub fn new(
-		db: &sled::Db,
-		data_dir: PathBuf,
-		system: Arc<System>,
-		rpc_server: &mut RpcServer,
-	) -> Arc<Self> {
-		let rc = db
-			.open_tree("block_local_rc")
-			.expect("Unable to open block_local_rc tree");
-		rc.set_merge_operator(rc_merge);
-
-		let resync_queue = db
-			.open_tree("block_local_resync_queue")
-			.expect("Unable to open block_local_resync_queue tree");
-
-		let rpc_path = "block_manager";
-		let rpc_client = system.rpc_client::<Message>(rpc_path);
-
-		let block_manager = Arc::new(Self {
-			data_dir,
-			data_dir_lock: Mutex::new(()),
-			rc,
-			resync_queue,
-			resync_notify: Notify::new(),
-			system,
-			rpc_client,
-			garage: ArcSwapOption::from(None),
-		});
-		block_manager
-			.clone()
-			.register_handler(rpc_server, rpc_path.into());
-		block_manager
-	}
-
-	fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
-		let self2 = self.clone();
-		rpc_server.add_handler::<Message, _, _>(path, move |msg, _addr| {
-			let self2 = self2.clone();
-			async move { self2.handle(&msg).await }
-		});
-
-		let self2 = self.clone();
-		self.rpc_client
-			.set_local_handler(self.system.id, move |msg| {
-				let self2 = self2.clone();
-				async move { self2.handle(&msg).await }
-			});
-	}
-
-	async fn handle(self: Arc<Self>, msg: &Message) -> Result<Message, Error> {
-		match msg {
-			Message::PutBlock(m) => self.write_block(&m.hash, &m.data).await,
-			Message::GetBlock(h) => self.read_block(h).await,
-			Message::NeedBlockQuery(h) => self.need_block(h).await.map(Message::NeedBlockReply),
-			_ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
-		}
-	}
-
-	pub async fn spawn_background_worker(self: Arc<Self>) {
-		// Launch 2 simultaneous workers for background resync loop preprocessing
-		for i in 0..2usize {
-			let bm2 = self.clone();
-			let background = self.system.background.clone();
-			tokio::spawn(async move {
-				tokio::time::delay_for(Duration::from_secs(10)).await;
-				background
-					.spawn_worker(format!("block resync worker {}", i), move |must_exit| {
-						bm2.resync_loop(must_exit)
-					})
-					.await;
-			});
-		}
-	}
-
-	pub async fn write_block(&self, hash: &Hash, data: &[u8]) -> Result<Message, Error> {
-		let _lock = self.data_dir_lock.lock().await;
-
-		let mut path = self.block_dir(hash);
-		fs::create_dir_all(&path).await?;
-
-		path.push(hex::encode(hash));
-		if fs::metadata(&path).await.is_ok() {
-			return Ok(Message::Ok);
-		}
-
-		let mut f = fs::File::create(path).await?;
-		f.write_all(data).await?;
-		drop(f);
-
-		Ok(Message::Ok)
-	}
-
-	pub async fn read_block(&self, hash: &Hash) -> Result<Message, Error> {
-		let path = self.block_path(hash);
-
-		let mut f = match fs::File::open(&path).await {
-			Ok(f) => f,
-			Err(e) => {
-				// Not found but maybe we should have had it ??
-				self.put_to_resync(hash, 0)?;
-				return Err(Into::into(e));
-			}
-		};
-		let mut data = vec![];
-		f.read_to_end(&mut data).await?;
-		drop(f);
-
-		if data::hash(&data[..]) != *hash {
-			let _lock = self.data_dir_lock.lock().await;
-			warn!("Block {:?} is corrupted. Deleting and resyncing.", hash);
-			fs::remove_file(path).await?;
-			self.put_to_resync(&hash, 0)?;
-			return Err(Error::CorruptData(*hash));
-		}
-
-		Ok(Message::PutBlock(PutBlockMessage { hash: *hash, data }))
-	}
-
-	pub async fn need_block(&self, hash: &Hash) -> Result<bool, Error> {
-		let needed = self
-			.rc
-			.get(hash.as_ref())?
-			.map(|x| u64_from_bytes(x.as_ref()) > 0)
-			.unwrap_or(false);
-		if needed {
-			let path = self.block_path(hash);
-			let exists = fs::metadata(&path).await.is_ok();
-			Ok(!exists)
-		} else {
-			Ok(false)
-		}
-	}
-
-	fn block_dir(&self, hash: &Hash) -> PathBuf {
-		let mut path = self.data_dir.clone();
-		path.push(hex::encode(&hash.as_slice()[0..1]));
-		path.push(hex::encode(&hash.as_slice()[1..2]));
-		path
-	}
-	fn block_path(&self, hash: &Hash) -> PathBuf {
-		let mut path = self.block_dir(hash);
-		path.push(hex::encode(hash.as_ref()));
-		path
-	}
-
-	pub fn block_incref(&self, hash: &Hash) -> Result<(), Error> {
-		let old_rc = self.rc.get(&hash)?;
-		self.rc.merge(&hash, vec![1])?;
-		if old_rc.map(|x| u64_from_bytes(&x[..]) == 0).unwrap_or(true) {
-			self.put_to_resync(&hash, BLOCK_RW_TIMEOUT.as_millis() as u64)?;
-		}
-		Ok(())
-	}
-
-	pub fn block_decref(&self, hash: &Hash) -> Result<(), Error> {
-		let new_rc = self.rc.merge(&hash, vec![0])?;
-		if new_rc.map(|x| u64_from_bytes(&x[..]) == 0).unwrap_or(true) {
-			self.put_to_resync(&hash, 0)?;
-		}
-		Ok(())
-	}
-
-	fn put_to_resync(&self, hash: &Hash, delay_millis: u64) -> Result<(), Error> {
-		let when = now_msec() + delay_millis;
-		trace!("Put resync_queue: {} {:?}", when, hash);
-		let mut key = u64::to_be_bytes(when).to_vec();
-		key.extend(hash.as_ref());
-		self.resync_queue.insert(key, hash.as_ref())?;
-		self.resync_notify.notify();
-		Ok(())
-	}
-
-	async fn resync_loop(
-		self: Arc<Self>,
-		mut must_exit: watch::Receiver<bool>,
-	) -> Result<(), Error> {
-		let mut n_failures = 0usize;
-		while !*must_exit.borrow() {
-			if let Some((time_bytes, hash_bytes)) = self.resync_queue.pop_min()? {
-				let time_msec = u64_from_bytes(&time_bytes[0..8]);
-				let now = now_msec();
-				if now >= time_msec {
-					let mut hash = [0u8; 32];
-					hash.copy_from_slice(hash_bytes.as_ref());
-					let hash = Hash::from(hash);
-
-					if let Err(e) = self.resync_iter(&hash).await {
-						warn!("Failed to resync block {:?}, retrying later: {}", hash, e);
-						self.put_to_resync(&hash, RESYNC_RETRY_TIMEOUT.as_millis() as u64)?;
-						n_failures += 1;
-						if n_failures >= 10 {
-							warn!("Too many resync failures, throttling.");
-							tokio::time::delay_for(Duration::from_secs(1)).await;
-						}
-					} else {
-						n_failures = 0;
-					}
-				} else {
-					self.resync_queue.insert(time_bytes, hash_bytes)?;
-					let delay = tokio::time::delay_for(Duration::from_millis(time_msec - now));
-					select! {
-						_ = delay.fuse() => (),
-						_ = self.resync_notify.notified().fuse() => (),
-						_ = must_exit.recv().fuse() => (),
-					}
-				}
-			} else {
-				select! {
-					_ = self.resync_notify.notified().fuse() => (),
-					_ = must_exit.recv().fuse() => (),
-				}
-			}
-		}
-		Ok(())
-	}
-
-	async fn resync_iter(&self, hash: &Hash) -> Result<(), Error> {
-		let path = self.block_path(hash);
-
-		let exists = fs::metadata(&path).await.is_ok();
-		let needed = self
-			.rc
-			.get(hash.as_ref())?
-			.map(|x| u64_from_bytes(x.as_ref()) > 0)
-			.unwrap_or(false);
-
-		if exists != needed {
-			info!(
-				"Resync block {:?}: exists {}, needed {}",
-				hash, exists, needed
-			);
-		}
-
-		if exists && !needed {
-			let garage = self.garage.load_full().unwrap();
-			let active_refs = garage
-				.block_ref_table
-				.get_range(&hash, None, Some(()), 1)
-				.await?;
-			let needed_by_others = !active_refs.is_empty();
-			if needed_by_others {
-				let ring = garage.system.ring.borrow().clone();
-				let who = ring.walk_ring(&hash, garage.system.config.data_replication_factor);
-				let msg = Arc::new(Message::NeedBlockQuery(*hash));
-				let who_needs_fut = who.iter().map(|to| {
-					self.rpc_client
-						.call_arc(*to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
-				});
-				let who_needs = join_all(who_needs_fut).await;
-
-				let mut need_nodes = vec![];
-				for (node, needed) in who.into_iter().zip(who_needs.iter()) {
-					match needed {
-						Ok(Message::NeedBlockReply(needed)) => {
-							if *needed {
-								need_nodes.push(node);
-							}
-						}
-						Err(e) => {
-							return Err(Error::Message(format!(
-								"Should delete block, but unable to confirm that all other nodes that need it have it: {}",
-								e
-							)));
-						}
-						Ok(_) => {
-							return Err(Error::Message(format!(
-								"Unexpected response to NeedBlockQuery RPC"
-							)));
-						}
-					}
-				}
-
-				if need_nodes.len() > 0 {
-					let put_block_message = self.read_block(hash).await?;
-					self.rpc_client
-						.try_call_many(
-							&need_nodes[..],
-							put_block_message,
-							RequestStrategy::with_quorum(need_nodes.len())
-								.with_timeout(BLOCK_RW_TIMEOUT),
-						)
-						.await?;
-				}
-			}
-			fs::remove_file(path).await?;
-			self.resync_queue.remove(&hash)?;
-		}
-
-		if needed && !exists {
-			// TODO find a way to not do this if they are sending it to us
-			// Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay
-			// between the RC being incremented and this part being called.
-			let block_data = self.rpc_get_block(&hash).await?;
-			self.write_block(hash, &block_data[..]).await?;
-		}
-
-		Ok(())
-	}
-
-	pub async fn rpc_get_block(&self, hash: &Hash) -> Result<Vec<u8>, Error> {
-		let ring = self.system.ring.borrow().clone();
-		let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
-		let resps = self
-			.rpc_client
-			.try_call_many(
-				&who[..],
-				Message::GetBlock(*hash),
-				RequestStrategy::with_quorum(1)
-					.with_timeout(BLOCK_RW_TIMEOUT)
-					.interrupt_after_quorum(true),
-			)
-			.await?;
-
-		for resp in resps {
-			if let Message::PutBlock(msg) = resp {
-				return Ok(msg.data);
-			}
-		}
-		Err(Error::Message(format!(
-			"Unable to read block {:?}: no valid blocks returned",
-			hash
-		)))
-	}
-
-	pub async fn rpc_put_block(&self, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
-		let ring = self.system.ring.borrow().clone();
-		let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
-		self.rpc_client
-			.try_call_many(
-				&who[..],
-				Message::PutBlock(PutBlockMessage { hash, data }),
-				RequestStrategy::with_quorum((self.system.config.data_replication_factor + 1) / 2)
-					.with_timeout(BLOCK_RW_TIMEOUT),
-			)
-			.await?;
-		Ok(())
-	}
-
-	pub async fn repair_data_store(&self, must_exit: &watch::Receiver<bool>) -> Result<(), Error> {
-		// 1. Repair blocks from RC table
-		let garage = self.garage.load_full().unwrap();
-		let mut last_hash = None;
-		let mut i = 0usize;
-		for entry in garage.block_ref_table.store.iter() {
-			let (_k, v_bytes) = entry?;
-			let block_ref = rmp_serde::decode::from_read_ref::<_, BlockRef>(v_bytes.as_ref())?;
-			if Some(&block_ref.block) == last_hash.as_ref() {
-				continue;
-			}
-			if !block_ref.deleted {
-				last_hash = Some(block_ref.block);
-				self.put_to_resync(&block_ref.block, 0)?;
-			}
-			i += 1;
-			if i & 0xFF == 0 && *must_exit.borrow() {
-				return Ok(());
-			}
-		}
-
-		// 2. Repair blocks actually on disk
-		let mut ls_data_dir = fs::read_dir(&self.data_dir).await?;
-		while let Some(data_dir_ent) = ls_data_dir.next().await {
-			let data_dir_ent = data_dir_ent?;
-			let dir_name = data_dir_ent.file_name();
-			let dir_name = match dir_name.into_string() {
-				Ok(x) => x,
-				Err(_) => continue,
-			};
-			if dir_name.len() != 2 || hex::decode(&dir_name).is_err() {
-				continue;
-			}
-
-			let mut ls_data_dir_2 = match fs::read_dir(data_dir_ent.path()).await {
-				Err(e) => {
-					warn!(
-						"Warning: could not list dir {:?}: {}",
-						data_dir_ent.path().to_str(),
-						e
-					);
-					continue;
-				}
-				Ok(x) => x,
-			};
-			while let Some(file) = ls_data_dir_2.next().await {
-				let file = file?;
-				let file_name = file.file_name();
-				let file_name = match file_name.into_string() {
-					Ok(x) => x,
-					Err(_) => continue,
-				};
-				if file_name.len() != 64 {
-					continue;
-				}
-				let hash_bytes = match hex::decode(&file_name) {
-					Ok(h) => h,
-					Err(_) => continue,
-				};
-				let mut hash = [0u8; 32];
-				hash.copy_from_slice(&hash_bytes[..]);
-				self.put_to_resync(&hash.into(), 0)?;
-
-				if *must_exit.borrow() {
-					return Ok(());
-				}
-			}
-		}
-		Ok(())
-	}
-}
-
-fn u64_from_bytes(bytes: &[u8]) -> u64 {
-	assert!(bytes.len() == 8);
-	let mut x8 = [0u8; 8];
-	x8.copy_from_slice(bytes);
-	u64::from_be_bytes(x8)
-}
-
-fn rc_merge(_key: &[u8], old: Option<&[u8]>, new: &[u8]) -> Option<Vec<u8>> {
-	let old = old.map(u64_from_bytes).unwrap_or(0);
-	assert!(new.len() == 1);
-	let new = match new[0] {
-		0 => {
-			if old > 0 {
-				old - 1
-			} else {
-				0
-			}
-		}
-		1 => old + 1,
-		_ => unreachable!(),
-	};
-	if new == 0 {
-		None
-	} else {
-		Some(u64::to_be_bytes(new).to_vec())
-	}
-}
diff --git a/src/block_ref_table.rs b/src/block_ref_table.rs
deleted file mode 100644
index 6a256aa3..00000000
--- a/src/block_ref_table.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-use async_trait::async_trait;
-use serde::{Deserialize, Serialize};
-use std::sync::Arc;
-
-use crate::background::*;
-use crate::data::*;
-use crate::error::Error;
-use crate::table::*;
-
-use crate::block::*;
-
-#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
-pub struct BlockRef {
-	// Primary key
-	pub block: Hash,
-
-	// Sort key
-	pub version: UUID,
-
-	// Keep track of deleted status
-	pub deleted: bool,
-}
-
-impl Entry<Hash, UUID> for BlockRef {
-	fn partition_key(&self) -> &Hash {
-		&self.block
-	}
-	fn sort_key(&self) -> &UUID {
-		&self.version
-	}
-
-	fn merge(&mut self, other: &Self) {
-		if other.deleted {
-			self.deleted = true;
-		}
-	}
-}
-
-pub struct BlockRefTable {
-	pub background: Arc<BackgroundRunner>,
-	pub block_manager: Arc<BlockManager>,
-}
-
-#[async_trait]
-impl TableSchema for BlockRefTable {
-	type P = Hash;
-	type S = UUID;
-	type E = BlockRef;
-	type Filter = ();
-
-	async fn updated(&self, old: Option<Self::E>, new: Option<Self::E>) -> Result<(), Error> {
-		let block = &old.as_ref().or(new.as_ref()).unwrap().block;
-		let was_before = old.as_ref().map(|x| !x.deleted).unwrap_or(false);
-		let is_after = new.as_ref().map(|x| !x.deleted).unwrap_or(false);
-		if is_after && !was_before {
-			self.block_manager.block_incref(block)?;
-		}
-		if was_before && !is_after {
-			self.block_manager.block_decref(block)?;
-		}
-		Ok(())
-	}
-
-	fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool {
-		!entry.deleted
-	}
-}
diff --git a/src/bucket_table.rs b/src/bucket_table.rs
deleted file mode 100644
index 5604049c..00000000
--- a/src/bucket_table.rs
+++ /dev/null
@@ -1,82 +0,0 @@
-use async_trait::async_trait;
-use serde::{Deserialize, Serialize};
-
-use crate::error::Error;
-use crate::table::*;
-
-#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
-pub struct Bucket {
-	// Primary key
-	pub name: String,
-
-	// Timestamp and deletion
-	// Upon version increment, all info is replaced
-	pub timestamp: u64,
-	pub deleted: bool,
-
-	// Authorized keys
-	pub authorized_keys: Vec<AllowedKey>,
-}
-
-#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
-pub struct AllowedKey {
-	pub access_key_id: String,
-	pub timestamp: u64,
-	pub allowed_read: bool,
-	pub allowed_write: bool,
-}
-
-impl Entry<EmptyKey, String> for Bucket {
-	fn partition_key(&self) -> &EmptyKey {
-		&EmptyKey
-	}
-	fn sort_key(&self) -> &String {
-		&self.name
-	}
-
-	fn merge(&mut self, other: &Self) {
-		if other.timestamp > self.timestamp {
-			*self = other.clone();
-			return;
-		}
-		if self.timestamp > other.timestamp {
-			return;
-		}
-		for ak in other.authorized_keys.iter() {
-			match self
-				.authorized_keys
-				.binary_search_by(|our_ak| our_ak.access_key_id.cmp(&ak.access_key_id))
-			{
-				Ok(i) => {
-					let our_ak = &mut self.authorized_keys[i];
-					if ak.timestamp > our_ak.timestamp {
-						our_ak.timestamp = ak.timestamp;
-						our_ak.allowed_read = ak.allowed_read;
-						our_ak.allowed_write = ak.allowed_write;
-					}
-				}
-				Err(i) => {
-					self.authorized_keys.insert(i, ak.clone());
-				}
-			}
-		}
-	}
-}
-
-pub struct BucketTable;
-
-#[async_trait]
-impl TableSchema for BucketTable {
-	type P = EmptyKey;
-	type S = String;
-	type E = Bucket;
-	type Filter = ();
-
-	async fn updated(&self, _old: Option<Self::E>, _new: Option<Self::E>) -> Result<(), Error> {
-		Ok(())
-	}
-
-	fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool {
-		!entry.deleted
-	}
-}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 00000000..7a6ae3f2
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,66 @@
+use std::io::Read;
+use std::net::SocketAddr;
+use std::path::PathBuf;
+
+use serde::Deserialize;
+
+use crate::error::Error;
+
+#[derive(Deserialize, Debug, Clone)]
+pub struct Config {
+	pub metadata_dir: PathBuf,
+	pub data_dir: PathBuf,
+
+	pub api_bind_addr: SocketAddr,
+	pub rpc_bind_addr: SocketAddr,
+
+	pub bootstrap_peers: Vec<SocketAddr>,
+
+	#[serde(default = "default_max_concurrent_rpc_requests")]
+	pub max_concurrent_rpc_requests: usize,
+
+	#[serde(default = "default_block_size")]
+	pub block_size: usize,
+
+	#[serde(default = "default_replication_factor")]
+	pub meta_replication_factor: usize,
+
+	#[serde(default = "default_epidemic_factor")]
+	pub meta_epidemic_factor: usize,
+
+	#[serde(default = "default_replication_factor")]
+	pub data_replication_factor: usize,
+
+	pub rpc_tls: Option<TlsConfig>,
+}
+
+fn default_max_concurrent_rpc_requests() -> usize {
+	12
+}
+fn default_block_size() -> usize {
+	1048576
+}
+fn default_replication_factor() -> usize {
+	3
+}
+fn default_epidemic_factor() -> usize {
+	3
+}
+
+#[derive(Deserialize, Debug, Clone)]
+pub struct TlsConfig {
+	pub ca_cert: String,
+	pub node_cert: String,
+	pub node_key: String,
+}
+
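+// The configuration file is TOML; a minimal file matching this schema
+// could look as follows (addresses and paths are illustrative only):
+//
+//   metadata_dir = "/var/lib/garage/meta"
+//   data_dir = "/var/lib/garage/data"
+//   api_bind_addr = "[::]:3900"
+//   rpc_bind_addr = "[::]:3901"
+//   bootstrap_peers = ["10.0.0.1:3901"]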
+pub fn read_config(config_file: PathBuf) -> Result<Config, Error> {
+	let mut file = std::fs::OpenOptions::new()
+		.read(true)
+		.open(config_file.as_path())?;
+
+	let mut config = String::new();
+	file.read_to_string(&mut config)?;
+
+	Ok(toml::from_str(&config)?)
+}
diff --git a/src/error.rs b/src/error.rs
index e217f9ae..6290dc24 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -3,7 +3,7 @@ use hyper::StatusCode;
 use std::io;
 
 use crate::data::Hash;
-use crate::rpc_client::RPCError;
+use crate::rpc::rpc_client::RPCError;
 
 #[derive(Debug, Error)]
 pub enum Error {
diff --git a/src/http_util.rs b/src/http_util.rs
deleted file mode 100644
index 228448f0..00000000
--- a/src/http_util.rs
+++ /dev/null
@@ -1,82 +0,0 @@
-use core::pin::Pin;
-use core::task::{Context, Poll};
-
-use futures::ready;
-use futures::stream::*;
-use hyper::body::{Bytes, HttpBody};
-
-use crate::error::Error;
-
-type StreamType = Pin<Box<dyn Stream<Item = Result<Bytes, Error>> + Send>>;
-
-pub struct StreamBody {
-	stream: StreamType,
-}
-
-impl StreamBody {
-	pub fn new(stream: StreamType) -> Self {
-		Self { stream }
-	}
-}
-
-impl HttpBody for StreamBody {
-	type Data = Bytes;
-	type Error = Error;
-
-	fn poll_data(
-		mut self: Pin<&mut Self>,
-		cx: &mut Context,
-	) -> Poll<Option<Result<Self::Data, Self::Error>>> {
-		match ready!(self.stream.as_mut().poll_next(cx)) {
-			Some(res) => Poll::Ready(Some(res)),
-			None => Poll::Ready(None),
-		}
-	}
-
-	fn poll_trailers(
-		self: Pin<&mut Self>,
-		_cx: &mut Context,
-	) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
-		Poll::Ready(Ok(None))
-	}
-}
-
-pub struct BytesBody {
-	bytes: Option<Bytes>,
-}
-
-impl BytesBody {
-	pub fn new(bytes: Bytes) -> Self {
-		Self { bytes: Some(bytes) }
-	}
-}
-
-impl HttpBody for BytesBody {
-	type Data = Bytes;
-	type Error = Error;
-
-	fn poll_data(
-		mut self: Pin<&mut Self>,
-		_cx: &mut Context,
-	) -> Poll<Option<Result<Self::Data, Self::Error>>> {
-		Poll::Ready(self.bytes.take().map(Ok))
-	}
-
-	fn poll_trailers(
-		self: Pin<&mut Self>,
-		_cx: &mut Context,
-	) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
-		Poll::Ready(Ok(None))
-	}
-}
-
-impl From<String> for BytesBody {
-	fn from(x: String) -> BytesBody {
-		Self::new(Bytes::from(x))
-	}
-}
-impl From<Vec<u8>> for BytesBody {
-	fn from(x: Vec<u8>) -> BytesBody {
-		Self::new(Bytes::from(x))
-	}
-}
diff --git a/src/main.rs b/src/main.rs
index 0b41805b..c693b12c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,29 +3,18 @@
 #[macro_use]
 extern crate log;
 
+mod background;
+mod config;
 mod data;
 mod error;
 
-mod background;
-mod membership;
+mod api;
+mod rpc;
+mod store;
 mod table;
-mod table_fullcopy;
-mod table_sharded;
-mod table_sync;
-
-mod block;
-mod block_ref_table;
-mod bucket_table;
-mod object_table;
-mod version_table;
 
 mod admin_rpc;
-mod api_server;
-mod http_util;
-mod rpc_client;
-mod rpc_server;
 mod server;
-mod tls_util;
 
 use std::collections::HashSet;
 use std::net::SocketAddr;
@@ -36,11 +25,12 @@ use std::time::Duration;
 use serde::{Deserialize, Serialize};
 use structopt::StructOpt;
 
+use config::TlsConfig;
 use data::*;
 use error::Error;
-use membership::*;
-use rpc_client::*;
-use server::TlsConfig;
+
+use rpc::membership::*;
+use rpc::rpc_client::*;
 
 use admin_rpc::*;
diff --git a/src/membership.rs b/src/membership.rs
deleted file mode 100644
index 87b065a7..00000000
--- a/src/membership.rs
+++ /dev/null
@@ -1,674 +0,0 @@
-use std::collections::HashMap;
-use std::hash::Hash as StdHash;
-use std::hash::Hasher;
-use std::io::Read;
-use std::net::{IpAddr, SocketAddr};
-use std::path::PathBuf;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::time::Duration;
-
-use futures::future::join_all;
-use futures::select;
-use futures_util::future::*;
-use serde::{Deserialize, Serialize};
-use sha2::{Digest, Sha256};
-use tokio::prelude::*;
-use tokio::sync::watch;
-use tokio::sync::Mutex;
-
-use crate::background::BackgroundRunner;
-use crate::data::*;
-use crate::error::Error;
-use crate::rpc_client::*;
-use crate::rpc_server::*;
-use crate::server::Config;
-
-const PING_INTERVAL: Duration = Duration::from_secs(10);
-const PING_TIMEOUT: Duration = Duration::from_secs(2);
-const MAX_FAILURES_BEFORE_CONSIDERED_DOWN: usize = 5;
-
-pub const MEMBERSHIP_RPC_PATH: &str = "_membership";
-
-#[derive(Debug, Serialize, Deserialize)]
-pub enum Message {
-	Ok,
-	Ping(PingMessage),
-	PullStatus,
-	PullConfig,
-	AdvertiseNodesUp(Vec<AdvertisedNode>),
-	AdvertiseConfig(NetworkConfig),
-}
-
-impl RpcMessage for Message {}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct PingMessage {
-	pub id: UUID,
-	pub rpc_port: u16,
-
-	pub status_hash: Hash,
-	pub config_version: u64,
-
-	pub state_info: StateInfo,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct AdvertisedNode {
-	pub id: UUID,
-	pub addr: SocketAddr,
-
-	pub is_up: bool,
-	pub last_seen: u64,
-
-	pub state_info: StateInfo,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct NetworkConfig {
-	pub members: HashMap<UUID, NetworkConfigEntry>,
-	pub version: u64,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct NetworkConfigEntry {
-	pub datacenter: String,
-	pub n_tokens: u32,
-	pub tag: String,
-}
-
-pub struct System {
-	pub config: Config,
-	pub id: UUID,
-
-	pub state_info: StateInfo,
-
-	pub rpc_http_client: Arc<RpcHttpClient>,
-	rpc_client: Arc<RpcClient<Message>>,
-
-	pub status: watch::Receiver<Arc<Status>>,
-	pub ring: watch::Receiver<Arc<Ring>>,
-
-	update_lock: Mutex<(watch::Sender<Arc<Status>>, watch::Sender<Arc<Ring>>)>,
-
-	pub background: Arc<BackgroundRunner>,
-}
-
-#[derive(Debug, Clone)]
-pub struct Status {
-	pub nodes: HashMap<UUID, Arc<StatusEntry>>,
-	pub hash: Hash,
-}
-
-#[derive(Debug)]
-pub struct StatusEntry {
-	pub addr: SocketAddr,
-	pub last_seen: u64,
-	pub num_failures: AtomicUsize,
-	pub state_info: StateInfo,
-}
-
-impl StatusEntry {
-	pub fn is_up(&self) -> bool {
-		self.num_failures.load(Ordering::SeqCst) < MAX_FAILURES_BEFORE_CONSIDERED_DOWN
-	}
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct StateInfo {
-	pub hostname: String,
-}
-
-#[derive(Clone)]
-pub struct Ring {
-	pub config: NetworkConfig,
-	pub ring: Vec<RingEntry>,
-	pub n_datacenters: usize,
-}
-
-#[derive(Clone, Debug)]
-pub struct RingEntry {
-	pub location: Hash,
-	pub node: UUID,
-	pub datacenter: u64,
-}
-
-impl Status {
-	fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool {
-		let addr = SocketAddr::new(ip, info.rpc_port);
-		let old_status = self.nodes.insert(
-			info.id,
-			Arc::new(StatusEntry {
-				addr,
-				last_seen: now_msec(),
-				num_failures: AtomicUsize::from(0),
-				state_info: info.state_info.clone(),
-			}),
-		);
-		match old_status {
-			None => {
-				info!("Newly pingable node: {}", hex::encode(&info.id));
-				true
-			}
-			Some(x) => x.addr != addr,
-		}
-	}
-
-	fn recalculate_hash(&mut self) {
-		let mut nodes = self.nodes.iter().collect::<Vec<_>>();
-		nodes.sort_unstable_by_key(|(id, _status)| *id);
-
-		let mut hasher = Sha256::new();
-		debug!("Current set of pingable nodes: --");
-		for (id, status) in nodes {
-			debug!("{} {}", hex::encode(&id), status.addr);
-			hasher.input(format!("{} {}\n", hex::encode(&id), status.addr));
-		}
-		debug!("END --");
-		self.hash
-			.as_slice_mut()
-			.copy_from_slice(&hasher.result()[..]);
-	}
-}
-
-impl Ring {
-	fn rebuild_ring(&mut self) {
-		let mut new_ring = vec![];
-		let mut datacenters = vec![];
-
-		for (id, config) in self.config.members.iter() {
-			let mut dc_hasher = std::collections::hash_map::DefaultHasher::new();
-			config.datacenter.hash(&mut dc_hasher);
-			let datacenter = dc_hasher.finish();
-
-			if !datacenters.contains(&datacenter) {
-				datacenters.push(datacenter);
-			}
-
-			for i in 0..config.n_tokens {
-				let location = hash(format!("{} {}", hex::encode(&id), i).as_bytes());
-
-				new_ring.push(RingEntry {
-					location: location.into(),
-					node: *id,
-					datacenter,
-				})
-			}
-		}
-
-		new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location));
-		self.ring = new_ring;
-		self.n_datacenters = datacenters.len();
-
-		// eprintln!("RING: --");
-		// for e in self.ring.iter() {
-		// 	eprintln!("{:?}", e);
-		// }
-		// eprintln!("END --");
-	}
-
-	pub fn walk_ring(&self, from: &Hash, n: usize) -> Vec<UUID> {
-		if n >= self.config.members.len() {
-			return self.config.members.keys().cloned().collect::<Vec<_>>();
-		}
-
-		let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) {
-			Ok(i) => i,
-			Err(i) => {
-				if i == 0 {
-					self.ring.len() - 1
-				} else {
-					i - 1
-				}
-			}
-		};
-
-		self.walk_ring_from_pos(start, n)
-	}
-
-	fn walk_ring_from_pos(&self, start: usize, n: usize) -> Vec<UUID> {
-		if n >= self.config.members.len() {
-			return self.config.members.keys().cloned().collect::<Vec<_>>();
-		}
-
-		let mut ret = vec![];
-		let mut datacenters = vec![];
-
-		let mut delta = 0;
-		while ret.len() < n {
-			let i = (start + delta) % self.ring.len();
-			delta += 1;
-
-			if !datacenters.contains(&self.ring[i].datacenter) {
-				ret.push(self.ring[i].node);
-				datacenters.push(self.ring[i].datacenter);
-			} else if datacenters.len() == self.n_datacenters && !ret.contains(&self.ring[i].node) {
-				ret.push(self.ring[i].node);
-			}
-		}
-
-		ret
-	}
-}
-
-fn read_network_config(metadata_dir: &PathBuf) -> Result<NetworkConfig, Error> {
-	let mut path = metadata_dir.clone();
-	path.push("network_config");
-
-	let mut file = std::fs::OpenOptions::new()
-		.read(true)
-		.open(path.as_path())?;
-
-	let mut net_config_bytes = vec![];
-	file.read_to_end(&mut net_config_bytes)?;
-
-	let net_config = rmp_serde::decode::from_read_ref(&net_config_bytes[..])
-		.expect("Unable to parse network configuration file (has version format changed?).");
-
-	Ok(net_config)
-}
-
-impl System {
-	pub fn new(
-		config: Config,
-		id: UUID,
-		background: Arc<BackgroundRunner>,
-		rpc_server: &mut RpcServer,
-	) -> Arc<Self> {
-		let net_config = match read_network_config(&config.metadata_dir) {
-			Ok(x) => x,
-			Err(e) => {
-				info!(
-					"No valid previous network configuration stored ({}), starting fresh.",
-					e
-				);
-				NetworkConfig {
-					members: HashMap::new(),
-					version: 0,
-				}
-			}
-		};
-		let mut status = Status {
-			nodes: HashMap::new(),
-			hash: Hash::default(),
-		};
-		status.recalculate_hash();
-		let (update_status, status) = watch::channel(Arc::new(status));
-
-		let state_info = StateInfo {
-			hostname: gethostname::gethostname()
-				.into_string()
-				.unwrap_or("<invalid utf-8>".to_string()),
-		};
-
-		let mut ring = Ring {
-			config: net_config,
-			ring: Vec::new(),
-			n_datacenters: 0,
-		};
-		ring.rebuild_ring();
-		let (update_ring, ring) = watch::channel(Arc::new(ring));
-
-		let rpc_http_client = Arc::new(
-			RpcHttpClient::new(config.max_concurrent_rpc_requests, &config.rpc_tls)
-				.expect("Could not create RPC client"),
-		);
-
-		let rpc_path = MEMBERSHIP_RPC_PATH.to_string();
-		let rpc_client = RpcClient::new(
-			RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.clone()),
-			background.clone(),
-			status.clone(),
-		);
-
-		let sys = Arc::new(System {
-			config,
-			id,
-			state_info,
-			rpc_http_client,
-			rpc_client,
-			status,
-			ring,
-			update_lock: Mutex::new((update_status, update_ring)),
-			background,
-		});
-		sys.clone().register_handler(rpc_server, rpc_path);
-		sys
-	}
-
-	fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
-		rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| {
-			let self2 = self.clone();
-			async move {
-				match msg {
-					Message::Ping(ping) => self2.handle_ping(&addr, &ping).await,
-
-					Message::PullStatus => self2.handle_pull_status(),
-					Message::PullConfig => self2.handle_pull_config(),
-					Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await,
-					Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await,
-
-					_ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
-				}
-			}
-		});
-	}
-
-	pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> {
-		RpcClient::new(
-			RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()),
-			self.background.clone(),
-			self.status.clone(),
-		)
-	}
-
-	async fn save_network_config(self: Arc<Self>) -> Result<(), Error> {
-		let mut path = self.config.metadata_dir.clone();
-		path.push("network_config");
-
-		let ring = self.ring.borrow().clone();
-		let data = rmp_to_vec_all_named(&ring.config)?;
-
-		let mut f = tokio::fs::File::create(path.as_path()).await?;
-		f.write_all(&data[..]).await?;
-		Ok(())
-	}
-
-	pub fn make_ping(&self) -> Message {
-		let status = self.status.borrow().clone();
-		let ring = self.ring.borrow().clone();
-		Message::Ping(PingMessage {
-			id: self.id,
-			rpc_port: self.config.rpc_bind_addr.port(),
-			status_hash: status.hash,
-			config_version: ring.config.version,
-			state_info: self.state_info.clone(),
-		})
-	}
-
-	pub async fn broadcast(self: Arc<Self>, msg: Message, timeout: Duration) {
-		let status = self.status.borrow().clone();
-		let to = status
-			.nodes
-			.keys()
-			.filter(|x| **x != self.id)
-			.cloned()
-			.collect::<Vec<_>>();
-		self.rpc_client.call_many(&to[..], msg, timeout).await;
-	}
-
-	pub async fn bootstrap(self: Arc<Self>) {
-		let bootstrap_peers = self
-			.config
-			.bootstrap_peers
-			.iter()
-			.map(|ip| (*ip, None))
-			.collect::<Vec<_>>();
-		self.clone().ping_nodes(bootstrap_peers).await;
-
-		self.clone()
-			.background
-			.spawn_worker(format!("ping loop"), |stop_signal| {
-				self.ping_loop(stop_signal).map(Ok)
-			})
-			.await;
-	}
-
-	async fn ping_nodes(self: Arc<Self>, peers: Vec<(SocketAddr, Option<UUID>)>) {
-		let ping_msg = self.make_ping();
-		let ping_resps = join_all(peers.iter().map(|(addr, id_option)| {
-			let sys = self.clone();
-			let ping_msg_ref = &ping_msg;
-			async move {
-				(
-					id_option,
-					addr,
-					sys.rpc_client
-						.by_addr()
-						.call(&addr, ping_msg_ref, PING_TIMEOUT)
-						.await,
-				)
-			}
-		}))
-		.await;
-
-		let update_locked = self.update_lock.lock().await;
-		let mut status: Status = self.status.borrow().as_ref().clone();
-		let ring = self.ring.borrow().clone();
-
-		let mut has_changes = false;
-		let mut to_advertise = vec![];
-
-		for (id_option, addr, ping_resp) in ping_resps {
-			if let Ok(Ok(Message::Ping(info))) = ping_resp {
-				let is_new = status.handle_ping(addr.ip(), &info);
-				if is_new {
-					has_changes = true;
-					to_advertise.push(AdvertisedNode {
-						id: info.id,
-						addr: *addr,
-						is_up: true,
-						last_seen: now_msec(),
-						state_info: info.state_info.clone(),
-					});
-				}
-				if is_new || status.hash != info.status_hash {
-					self.background
-						.spawn_cancellable(self.clone().pull_status(info.id).map(Ok));
-				}
-				if is_new || ring.config.version < info.config_version {
-					self.background
-						.spawn_cancellable(self.clone().pull_config(info.id).map(Ok));
-				}
-			} else if let Some(id) = id_option {
-				if let Some(st) = status.nodes.get_mut(id) {
-					st.num_failures.fetch_add(1, Ordering::SeqCst);
-					if !st.is_up() {
-						warn!("Node {:?} seems to be down.", id);
-						if !ring.config.members.contains_key(id) {
-							info!("Removing node {:?} from status (not in config and not responding to pings anymore)", id);
-							drop(st);
-							status.nodes.remove(&id);
-							has_changes = true;
-						}
-					}
-				}
-			}
-		}
-		if has_changes {
-			status.recalculate_hash();
-		}
-		if let Err(e) = update_locked.0.broadcast(Arc::new(status)) {
-			error!("In ping_nodes: could not save status update ({})", e);
-		}
-		drop(update_locked);
-
-		if to_advertise.len() > 0 {
-			self.broadcast(Message::AdvertiseNodesUp(to_advertise), PING_TIMEOUT)
-				.await;
-		}
-	}
-
-	pub async fn handle_ping(
-		self: Arc<Self>,
-		from: &SocketAddr,
-		ping: &PingMessage,
-	) -> Result<Message, Error> {
-		let update_locked = self.update_lock.lock().await;
-		let mut status: Status = self.status.borrow().as_ref().clone();
-
-		let is_new = status.handle_ping(from.ip(), ping);
-		if is_new {
-			status.recalculate_hash();
-		}
-		let status_hash = status.hash;
-		let config_version = self.ring.borrow().config.version;
-
-		update_locked.0.broadcast(Arc::new(status))?;
-		drop(update_locked);
-
-		if is_new || status_hash != ping.status_hash {
-			self.background
-				.spawn_cancellable(self.clone().pull_status(ping.id).map(Ok));
-		}
-		if is_new || config_version < ping.config_version {
-			self.background
-				.spawn_cancellable(self.clone().pull_config(ping.id).map(Ok));
-		}
-
-		Ok(self.make_ping())
-	}
-
-	pub fn handle_pull_status(&self) -> Result<Message, Error> {
-		let status = self.status.borrow().clone();
-		let mut mem = vec![];
-		for (node, status) in status.nodes.iter() {
-			let state_info = if *node == self.id {
-				self.state_info.clone()
-			} else {
-				status.state_info.clone()
-			};
-			mem.push(AdvertisedNode {
-				id: *node,
-				addr: status.addr,
-				is_up: status.is_up(),
-				last_seen: status.last_seen,
-				state_info,
-			});
-		}
-		Ok(Message::AdvertiseNodesUp(mem))
-	}
-
-	pub fn handle_pull_config(&self) -> Result<Message, Error> {
-		let ring = self.ring.borrow().clone();
-		Ok(Message::AdvertiseConfig(ring.config.clone()))
-	}
-
-	pub async fn handle_advertise_nodes_up(
-		self: Arc<Self>,
-		adv: &[AdvertisedNode],
-	) -> Result<Message, Error> {
-		let mut to_ping = vec![];
-
-		let update_lock = self.update_lock.lock().await;
-		let mut status: Status = self.status.borrow().as_ref().clone();
-		let mut has_changed = false;
-
-		for node in adv.iter() {
-			if node.id == self.id {
-				// learn our own ip address
-				let self_addr = SocketAddr::new(node.addr.ip(), self.config.rpc_bind_addr.port());
-				let old_self = status.nodes.insert(
-					node.id,
-					Arc::new(StatusEntry {
-						addr: self_addr,
-						last_seen: now_msec(),
-						num_failures: AtomicUsize::from(0),
-						state_info: self.state_info.clone(),
-					}),
-				);
-				has_changed = match old_self {
-					None => true,
-					Some(x) => x.addr != self_addr,
-				};
-			} else {
-				let ping_them = match status.nodes.get(&node.id) {
-					// Case 1: new node
-					None => true,
-					// Case 2: the node might have changed address
-					Some(our_node) => node.is_up && !our_node.is_up() && our_node.addr != node.addr,
-				};
-				if ping_them {
-					to_ping.push((node.addr, Some(node.id)));
-				}
-			}
-		}
-		if has_changed {
-			status.recalculate_hash();
-		}
-		update_lock.0.broadcast(Arc::new(status))?;
-		drop(update_lock);
-
-		if to_ping.len() > 0 {
-			self.background
-				.spawn_cancellable(self.clone().ping_nodes(to_ping).map(Ok));
-		}
-
-		Ok(Message::Ok)
-	}
-
-	pub async fn handle_advertise_config(
-		self: Arc<Self>,
-		adv: &NetworkConfig,
-	) -> Result<Message, Error> {
-		let update_lock = self.update_lock.lock().await;
-		let mut ring: Ring = self.ring.borrow().as_ref().clone();
-
-		if adv.version > ring.config.version {
-			ring.config = adv.clone();
ring.rebuild_ring(); - update_lock.1.broadcast(Arc::new(ring))?; - drop(update_lock); - - self.background.spawn_cancellable( - self.clone() - .broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT) - .map(Ok), - ); - self.background.spawn(self.clone().save_network_config()); - } - - Ok(Message::Ok) - } - - pub async fn ping_loop(self: Arc, mut stop_signal: watch::Receiver) { - loop { - let restart_at = tokio::time::delay_for(PING_INTERVAL); - - let status = self.status.borrow().clone(); - let ping_addrs = status - .nodes - .iter() - .filter(|(id, _)| **id != self.id) - .map(|(id, status)| (status.addr, Some(*id))) - .collect::>(); - - self.clone().ping_nodes(ping_addrs).await; - - select! { - _ = restart_at.fuse() => (), - must_exit = stop_signal.recv().fuse() => { - match must_exit { - None | Some(true) => return, - _ => (), - } - } - } - } - } - - pub fn pull_status( - self: Arc, - peer: UUID, - ) -> impl futures::future::Future + Send + 'static { - async move { - let resp = self - .rpc_client - .call(peer, Message::PullStatus, PING_TIMEOUT) - .await; - if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { - let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; - } - } - } - - pub async fn pull_config(self: Arc, peer: UUID) { - let resp = self - .rpc_client - .call(peer, Message::PullConfig, PING_TIMEOUT) - .await; - if let Ok(Message::AdvertiseConfig(config)) = resp { - let _: Result<_, _> = self.handle_advertise_config(&config).await; - } - } -} diff --git a/src/object_table.rs b/src/object_table.rs deleted file mode 100644 index edad4925..00000000 --- a/src/object_table.rs +++ /dev/null @@ -1,133 +0,0 @@ -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -use crate::background::BackgroundRunner; -use crate::data::*; -use crate::error::Error; -use crate::table::*; -use crate::table_sharded::*; - -use crate::version_table::*; - -#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] -pub struct Object { - // Primary key - pub bucket: String, - - // Sort key - pub key: String, - - // Data - pub versions: Vec>, -} - -#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] -pub struct ObjectVersion { - pub uuid: UUID, - pub timestamp: u64, - - pub mime_type: String, - pub size: u64, - pub is_complete: bool, - - pub data: ObjectVersionData, -} - -#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] -pub enum ObjectVersionData { - DeleteMarker, - Inline(#[serde(with = "serde_bytes")] Vec), - FirstBlock(Hash), -} - -impl ObjectVersion { - fn cmp_key(&self) -> (u64, &UUID) { - (self.timestamp, &self.uuid) - } -} - -impl Entry for Object { - fn partition_key(&self) -> &String { - &self.bucket - } - fn sort_key(&self) -> &String { - &self.key - } - - fn merge(&mut self, other: &Self) { - for other_v in other.versions.iter() { - match self - .versions - .binary_search_by(|v| v.cmp_key().cmp(&other_v.cmp_key())) - { - Ok(i) => { - let mut v = &mut self.versions[i]; - if other_v.size > v.size { - v.size = other_v.size; - } - if other_v.is_complete && !v.is_complete { - v.is_complete = true; - } - } - Err(i) => { - self.versions.insert(i, other_v.clone()); - } - } - } - let last_complete = self - .versions - .iter() - .enumerate() - .rev() - .filter(|(_, v)| v.is_complete) - .next() - .map(|(vi, _)| vi); - - if let Some(last_vi) = last_complete { - self.versions = self.versions.drain(last_vi..).collect::>(); - } - } -} - -pub struct ObjectTable { - pub background: Arc, - pub version_table: Arc>, -} - -#[async_trait] -impl 
TableSchema for ObjectTable { - type P = String; - type S = String; - type E = Object; - type Filter = (); - - async fn updated(&self, old: Option, new: Option) -> Result<(), Error> { - let version_table = self.version_table.clone(); - if let (Some(old_v), Some(new_v)) = (old, new) { - // Propagate deletion of old versions - for v in old_v.versions.iter() { - if new_v - .versions - .binary_search_by(|nv| nv.cmp_key().cmp(&v.cmp_key())) - .is_err() - { - let deleted_version = Version { - uuid: v.uuid, - deleted: true, - blocks: vec![], - bucket: old_v.bucket.clone(), - key: old_v.key.clone(), - }; - version_table.insert(&deleted_version).await?; - } - } - } - Ok(()) - } - - fn matches_filter(_entry: &Self::E, _filter: &Self::Filter) -> bool { - // TODO - true - } -} diff --git a/src/rpc/membership.rs b/src/rpc/membership.rs new file mode 100644 index 00000000..e0509536 --- /dev/null +++ b/src/rpc/membership.rs @@ -0,0 +1,692 @@ +use std::collections::HashMap; +use std::hash::Hash as StdHash; +use std::hash::Hasher; +use std::io::{Read, Write}; +use std::net::{IpAddr, SocketAddr}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use futures::future::join_all; +use futures::select; +use futures_util::future::*; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tokio::prelude::*; +use tokio::sync::watch; +use tokio::sync::Mutex; + +use crate::background::BackgroundRunner; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::rpc_client::*; +use crate::rpc::rpc_server::*; + +const PING_INTERVAL: Duration = Duration::from_secs(10); +const PING_TIMEOUT: Duration = Duration::from_secs(2); +const MAX_FAILURES_BEFORE_CONSIDERED_DOWN: usize = 5; + +pub const MEMBERSHIP_RPC_PATH: &str = "_membership"; + +#[derive(Debug, Serialize, Deserialize)] +pub enum Message { + Ok, + Ping(PingMessage), + PullStatus, + PullConfig, + AdvertiseNodesUp(Vec), + AdvertiseConfig(NetworkConfig), +} + +impl RpcMessage for Message {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PingMessage { + pub id: UUID, + pub rpc_port: u16, + + pub status_hash: Hash, + pub config_version: u64, + + pub state_info: StateInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AdvertisedNode { + pub id: UUID, + pub addr: SocketAddr, + + pub is_up: bool, + pub last_seen: u64, + + pub state_info: StateInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NetworkConfig { + pub members: HashMap, + pub version: u64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NetworkConfigEntry { + pub datacenter: String, + pub n_tokens: u32, + pub tag: String, +} + +pub struct System { + pub id: UUID, + pub data_dir: PathBuf, + pub rpc_local_port: u16, + + pub state_info: StateInfo, + + pub rpc_http_client: Arc, + rpc_client: Arc>, + + pub status: watch::Receiver>, + pub ring: watch::Receiver>, + + update_lock: Mutex<(watch::Sender>, watch::Sender>)>, + + pub background: Arc, +} + +#[derive(Debug, Clone)] +pub struct Status { + pub nodes: HashMap>, + pub hash: Hash, +} + +#[derive(Debug)] +pub struct StatusEntry { + pub addr: SocketAddr, + pub last_seen: u64, + pub num_failures: AtomicUsize, + pub state_info: StateInfo, +} + +impl StatusEntry { + pub fn is_up(&self) -> bool { + self.num_failures.load(Ordering::SeqCst) < MAX_FAILURES_BEFORE_CONSIDERED_DOWN + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateInfo { + pub hostname: String, +} + +#[derive(Clone)] +pub 
struct Ring { + pub config: NetworkConfig, + pub ring: Vec, + pub n_datacenters: usize, +} + +#[derive(Clone, Debug)] +pub struct RingEntry { + pub location: Hash, + pub node: UUID, + pub datacenter: u64, +} + +impl Status { + fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool { + let addr = SocketAddr::new(ip, info.rpc_port); + let old_status = self.nodes.insert( + info.id, + Arc::new(StatusEntry { + addr, + last_seen: now_msec(), + num_failures: AtomicUsize::from(0), + state_info: info.state_info.clone(), + }), + ); + match old_status { + None => { + info!("Newly pingable node: {}", hex::encode(&info.id)); + true + } + Some(x) => x.addr != addr, + } + } + + fn recalculate_hash(&mut self) { + let mut nodes = self.nodes.iter().collect::>(); + nodes.sort_unstable_by_key(|(id, _status)| *id); + + let mut hasher = Sha256::new(); + debug!("Current set of pingable nodes: --"); + for (id, status) in nodes { + debug!("{} {}", hex::encode(&id), status.addr); + hasher.input(format!("{} {}\n", hex::encode(&id), status.addr)); + } + debug!("END --"); + self.hash + .as_slice_mut() + .copy_from_slice(&hasher.result()[..]); + } +} + +impl Ring { + fn rebuild_ring(&mut self) { + let mut new_ring = vec![]; + let mut datacenters = vec![]; + + for (id, config) in self.config.members.iter() { + let mut dc_hasher = std::collections::hash_map::DefaultHasher::new(); + config.datacenter.hash(&mut dc_hasher); + let datacenter = dc_hasher.finish(); + + if !datacenters.contains(&datacenter) { + datacenters.push(datacenter); + } + + for i in 0..config.n_tokens { + let location = hash(format!("{} {}", hex::encode(&id), i).as_bytes()); + + new_ring.push(RingEntry { + location: location.into(), + node: *id, + datacenter, + }) + } + } + + new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location)); + self.ring = new_ring; + self.n_datacenters = datacenters.len(); + + // eprintln!("RING: --"); + // for e in self.ring.iter() { + // eprintln!("{:?}", e); + // } + // eprintln!("END --"); + } + + pub fn walk_ring(&self, from: &Hash, n: usize) -> Vec { + if n >= self.config.members.len() { + return self.config.members.keys().cloned().collect::>(); + } + + let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) { + Ok(i) => i, + Err(i) => { + if i == 0 { + self.ring.len() - 1 + } else { + i - 1 + } + } + }; + + self.walk_ring_from_pos(start, n) + } + + fn walk_ring_from_pos(&self, start: usize, n: usize) -> Vec { + if n >= self.config.members.len() { + return self.config.members.keys().cloned().collect::>(); + } + + let mut ret = vec![]; + let mut datacenters = vec![]; + + let mut delta = 0; + while ret.len() < n { + let i = (start + delta) % self.ring.len(); + delta += 1; + + if !datacenters.contains(&self.ring[i].datacenter) { + ret.push(self.ring[i].node); + datacenters.push(self.ring[i].datacenter); + } else if datacenters.len() == self.n_datacenters && !ret.contains(&self.ring[i].node) { + ret.push(self.ring[i].node); + } + } + + ret + } +} + +fn gen_node_id(metadata_dir: &PathBuf) -> Result { + let mut id_file = metadata_dir.clone(); + id_file.push("node_id"); + if id_file.as_path().exists() { + let mut f = std::fs::File::open(id_file.as_path())?; + let mut d = vec![]; + f.read_to_end(&mut d)?; + if d.len() != 32 { + return Err(Error::Message(format!("Corrupt node_id file"))); + } + + let mut id = [0u8; 32]; + id.copy_from_slice(&d[..]); + Ok(id.into()) + } else { + let id = gen_uuid(); + + let mut f = std::fs::File::create(id_file.as_path())?; + f.write_all(id.as_slice())?; + Ok(id) + } 
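// [Editor's note, not part of the patch: gen_node_id() above is idempotent --
// the 32-byte identity is written once and re-read from
// `<metadata_dir>/node_id` on every later start, and a wrong-length file is
// rejected rather than silently adopted as a new identity. A minimal sketch
// of that invariant, assuming a writable scratch directory:
//
//     let dir = std::path::PathBuf::from("/tmp/garage-meta-test");
//     std::fs::create_dir_all(&dir)?;
//     let first = gen_node_id(&dir)?;
//     let second = gen_node_id(&dir)?; // re-reads the persisted file
//     assert_eq!(first, second);
// ]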
+} + +fn read_network_config(metadata_dir: &PathBuf) -> Result { + let mut path = metadata_dir.clone(); + path.push("network_config"); + + let mut file = std::fs::OpenOptions::new() + .read(true) + .open(path.as_path())?; + + let mut net_config_bytes = vec![]; + file.read_to_end(&mut net_config_bytes)?; + + let net_config = rmp_serde::decode::from_read_ref(&net_config_bytes[..]) + .expect("Unable to parse network configuration file (has version format changed?)."); + + Ok(net_config) +} + +impl System { + pub fn new( + data_dir: PathBuf, + rpc_http_client: Arc, + background: Arc, + rpc_server: &mut RpcServer, + ) -> Arc { + let id = gen_node_id(&data_dir).expect("Unable to read or generate node ID"); + info!("Node ID: {}", hex::encode(&id)); + + let net_config = match read_network_config(&data_dir) { + Ok(x) => x, + Err(e) => { + info!( + "No valid previous network configuration stored ({}), starting fresh.", + e + ); + NetworkConfig { + members: HashMap::new(), + version: 0, + } + } + }; + let mut status = Status { + nodes: HashMap::new(), + hash: Hash::default(), + }; + status.recalculate_hash(); + let (update_status, status) = watch::channel(Arc::new(status)); + + let state_info = StateInfo { + hostname: gethostname::gethostname() + .into_string() + .unwrap_or("".to_string()), + }; + + let mut ring = Ring { + config: net_config, + ring: Vec::new(), + n_datacenters: 0, + }; + ring.rebuild_ring(); + let (update_ring, ring) = watch::channel(Arc::new(ring)); + + let rpc_path = MEMBERSHIP_RPC_PATH.to_string(); + let rpc_client = RpcClient::new( + RpcAddrClient::::new(rpc_http_client.clone(), rpc_path.clone()), + background.clone(), + status.clone(), + ); + + let sys = Arc::new(System { + id, + data_dir, + rpc_local_port: rpc_server.bind_addr.port(), + state_info, + rpc_http_client, + rpc_client, + status, + ring, + update_lock: Mutex::new((update_status, update_ring)), + background, + }); + sys.clone().register_handler(rpc_server, rpc_path); + sys + } + + fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { + rpc_server.add_handler::(path, move |msg, addr| { + let self2 = self.clone(); + async move { + match msg { + Message::Ping(ping) => self2.handle_ping(&addr, &ping).await, + + Message::PullStatus => self2.handle_pull_status(), + Message::PullConfig => self2.handle_pull_config(), + Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await, + Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await, + + _ => Err(Error::BadRequest(format!("Unexpected RPC message"))), + } + } + }); + } + + pub fn rpc_client(self: &Arc, path: &str) -> Arc> { + RpcClient::new( + RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()), + self.background.clone(), + self.status.clone(), + ) + } + + async fn save_network_config(self: Arc) -> Result<(), Error> { + let mut path = self.data_dir.clone(); + path.push("network_config"); + + let ring = self.ring.borrow().clone(); + let data = rmp_to_vec_all_named(&ring.config)?; + + let mut f = tokio::fs::File::create(path.as_path()).await?; + f.write_all(&data[..]).await?; + Ok(()) + } + + pub fn make_ping(&self) -> Message { + let status = self.status.borrow().clone(); + let ring = self.ring.borrow().clone(); + Message::Ping(PingMessage { + id: self.id, + rpc_port: self.rpc_local_port, + status_hash: status.hash, + config_version: ring.config.version, + state_info: self.state_info.clone(), + }) + } + + pub async fn broadcast(self: Arc, msg: Message, timeout: Duration) { + let status = 
self.status.borrow().clone(); + let to = status + .nodes + .keys() + .filter(|x| **x != self.id) + .cloned() + .collect::>(); + self.rpc_client.call_many(&to[..], msg, timeout).await; + } + + pub async fn bootstrap(self: Arc, peers: &[SocketAddr]) { + let bootstrap_peers = peers.iter().map(|ip| (*ip, None)).collect::>(); + self.clone().ping_nodes(bootstrap_peers).await; + + self.clone() + .background + .spawn_worker(format!("ping loop"), |stop_signal| { + self.ping_loop(stop_signal).map(Ok) + }) + .await; + } + + async fn ping_nodes(self: Arc, peers: Vec<(SocketAddr, Option)>) { + let ping_msg = self.make_ping(); + let ping_resps = join_all(peers.iter().map(|(addr, id_option)| { + let sys = self.clone(); + let ping_msg_ref = &ping_msg; + async move { + ( + id_option, + addr, + sys.rpc_client + .by_addr() + .call(&addr, ping_msg_ref, PING_TIMEOUT) + .await, + ) + } + })) + .await; + + let update_locked = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + let ring = self.ring.borrow().clone(); + + let mut has_changes = false; + let mut to_advertise = vec![]; + + for (id_option, addr, ping_resp) in ping_resps { + if let Ok(Ok(Message::Ping(info))) = ping_resp { + let is_new = status.handle_ping(addr.ip(), &info); + if is_new { + has_changes = true; + to_advertise.push(AdvertisedNode { + id: info.id, + addr: *addr, + is_up: true, + last_seen: now_msec(), + state_info: info.state_info.clone(), + }); + } + if is_new || status.hash != info.status_hash { + self.background + .spawn_cancellable(self.clone().pull_status(info.id).map(Ok)); + } + if is_new || ring.config.version < info.config_version { + self.background + .spawn_cancellable(self.clone().pull_config(info.id).map(Ok)); + } + } else if let Some(id) = id_option { + if let Some(st) = status.nodes.get_mut(id) { + st.num_failures.fetch_add(1, Ordering::SeqCst); + if !st.is_up() { + warn!("Node {:?} seems to be down.", id); + if !ring.config.members.contains_key(id) { + info!("Removing node {:?} from status (not in config and not responding to pings anymore)", id); + drop(st); + status.nodes.remove(&id); + has_changes = true; + } + } + } + } + } + if has_changes { + status.recalculate_hash(); + } + if let Err(e) = update_locked.0.broadcast(Arc::new(status)) { + error!("In ping_nodes: could not save status update ({})", e); + } + drop(update_locked); + + if to_advertise.len() > 0 { + self.broadcast(Message::AdvertiseNodesUp(to_advertise), PING_TIMEOUT) + .await; + } + } + + pub async fn handle_ping( + self: Arc, + from: &SocketAddr, + ping: &PingMessage, + ) -> Result { + let update_locked = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + + let is_new = status.handle_ping(from.ip(), ping); + if is_new { + status.recalculate_hash(); + } + let status_hash = status.hash; + let config_version = self.ring.borrow().config.version; + + update_locked.0.broadcast(Arc::new(status))?; + drop(update_locked); + + if is_new || status_hash != ping.status_hash { + self.background + .spawn_cancellable(self.clone().pull_status(ping.id).map(Ok)); + } + if is_new || config_version < ping.config_version { + self.background + .spawn_cancellable(self.clone().pull_config(ping.id).map(Ok)); + } + + Ok(self.make_ping()) + } + + pub fn handle_pull_status(&self) -> Result { + let status = self.status.borrow().clone(); + let mut mem = vec![]; + for (node, status) in status.nodes.iter() { + let state_info = if *node == self.id { + self.state_info.clone() + } else { + 
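// [Editor's note, not part of the patch: the local node is special-cased
// just above because `status` only holds the state_info learned from pings;
// for ourselves, the locally held state_info is the authoritative value to
// advertise.]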
status.state_info.clone() + }; + mem.push(AdvertisedNode { + id: *node, + addr: status.addr, + is_up: status.is_up(), + last_seen: status.last_seen, + state_info, + }); + } + Ok(Message::AdvertiseNodesUp(mem)) + } + + pub fn handle_pull_config(&self) -> Result { + let ring = self.ring.borrow().clone(); + Ok(Message::AdvertiseConfig(ring.config.clone())) + } + + pub async fn handle_advertise_nodes_up( + self: Arc, + adv: &[AdvertisedNode], + ) -> Result { + let mut to_ping = vec![]; + + let update_lock = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + let mut has_changed = false; + + for node in adv.iter() { + if node.id == self.id { + // learn our own ip address + let self_addr = SocketAddr::new(node.addr.ip(), self.rpc_local_port); + let old_self = status.nodes.insert( + node.id, + Arc::new(StatusEntry { + addr: self_addr, + last_seen: now_msec(), + num_failures: AtomicUsize::from(0), + state_info: self.state_info.clone(), + }), + ); + has_changed = match old_self { + None => true, + Some(x) => x.addr != self_addr, + }; + } else { + let ping_them = match status.nodes.get(&node.id) { + // Case 1: new node + None => true, + // Case 2: the node might have changed address + Some(our_node) => node.is_up && !our_node.is_up() && our_node.addr != node.addr, + }; + if ping_them { + to_ping.push((node.addr, Some(node.id))); + } + } + } + if has_changed { + status.recalculate_hash(); + } + update_lock.0.broadcast(Arc::new(status))?; + drop(update_lock); + + if to_ping.len() > 0 { + self.background + .spawn_cancellable(self.clone().ping_nodes(to_ping).map(Ok)); + } + + Ok(Message::Ok) + } + + pub async fn handle_advertise_config( + self: Arc, + adv: &NetworkConfig, + ) -> Result { + let update_lock = self.update_lock.lock().await; + let mut ring: Ring = self.ring.borrow().as_ref().clone(); + + if adv.version > ring.config.version { + ring.config = adv.clone(); + ring.rebuild_ring(); + update_lock.1.broadcast(Arc::new(ring))?; + drop(update_lock); + + self.background.spawn_cancellable( + self.clone() + .broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT) + .map(Ok), + ); + self.background.spawn(self.clone().save_network_config()); + } + + Ok(Message::Ok) + } + + pub async fn ping_loop(self: Arc, mut stop_signal: watch::Receiver) { + loop { + let restart_at = tokio::time::delay_for(PING_INTERVAL); + + let status = self.status.borrow().clone(); + let ping_addrs = status + .nodes + .iter() + .filter(|(id, _)| **id != self.id) + .map(|(id, status)| (status.addr, Some(*id))) + .collect::>(); + + self.clone().ping_nodes(ping_addrs).await; + + select! 
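// [Editor's note, not part of the patch: the select! below races the
// PING_INTERVAL timer against the shutdown channel -- each iteration pings
// every known peer, then either loops again when the delay fires or returns
// when stop_signal yields `true` (or the channel closes).]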
{ + _ = restart_at.fuse() => (), + must_exit = stop_signal.recv().fuse() => { + match must_exit { + None | Some(true) => return, + _ => (), + } + } + } + } + } + + pub fn pull_status( + self: Arc, + peer: UUID, + ) -> impl futures::future::Future + Send + 'static { + async move { + let resp = self + .rpc_client + .call(peer, Message::PullStatus, PING_TIMEOUT) + .await; + if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { + let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; + } + } + } + + pub async fn pull_config(self: Arc, peer: UUID) { + let resp = self + .rpc_client + .call(peer, Message::PullConfig, PING_TIMEOUT) + .await; + if let Ok(Message::AdvertiseConfig(config)) = resp { + let _: Result<_, _> = self.handle_advertise_config(&config).await; + } + } +} diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs new file mode 100644 index 00000000..83fd0aac --- /dev/null +++ b/src/rpc/mod.rs @@ -0,0 +1,4 @@ +pub mod membership; +pub mod rpc_client; +pub mod rpc_server; +pub mod tls_util; diff --git a/src/rpc/rpc_client.rs b/src/rpc/rpc_client.rs new file mode 100644 index 00000000..027a3cde --- /dev/null +++ b/src/rpc/rpc_client.rs @@ -0,0 +1,360 @@ +use std::borrow::Borrow; +use std::marker::PhantomData; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwapOption; +use bytes::IntoBuf; +use err_derive::Error; +use futures::future::Future; +use futures::stream::futures_unordered::FuturesUnordered; +use futures::stream::StreamExt; +use futures_util::future::FutureExt; +use hyper::client::{Client, HttpConnector}; +use hyper::{Body, Method, Request}; +use tokio::sync::{watch, Semaphore}; + +use crate::background::BackgroundRunner; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::membership::Status; +use crate::rpc::rpc_server::RpcMessage; +use crate::rpc::tls_util; + +use crate::config::TlsConfig; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +#[derive(Debug, Error)] +pub enum RPCError { + #[error(display = "Node is down: {:?}.", _0)] + NodeDown(UUID), + #[error(display = "Timeout: {}", _0)] + Timeout(#[error(source)] tokio::time::Elapsed), + #[error(display = "HTTP error: {}", _0)] + HTTP(#[error(source)] http::Error), + #[error(display = "Hyper error: {}", _0)] + Hyper(#[error(source)] hyper::Error), + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), + #[error(display = "Too many errors: {:?}", _0)] + TooManyErrors(Vec), +} + +#[derive(Copy, Clone)] +pub struct RequestStrategy { + pub rs_timeout: Duration, + pub rs_quorum: usize, + pub rs_interrupt_after_quorum: bool, +} + +impl RequestStrategy { + pub fn with_quorum(quorum: usize) -> Self { + RequestStrategy { + rs_timeout: DEFAULT_TIMEOUT, + rs_quorum: quorum, + rs_interrupt_after_quorum: false, + } + } + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.rs_timeout = timeout; + self + } + pub fn interrupt_after_quorum(mut self, interrupt: bool) -> Self { + self.rs_interrupt_after_quorum = interrupt; + self + } +} + +pub type LocalHandlerFn = + Box) -> Pin> + Send>> + Send + Sync>; + +pub struct RpcClient { + status: watch::Receiver>, + background: Arc, + + local_handler: ArcSwapOption<(UUID, LocalHandlerFn)>, + + pub rpc_addr_client: RpcAddrClient, +} + +impl RpcClient { + pub fn new( + rac: RpcAddrClient, + 
background: Arc, + status: watch::Receiver>, + ) -> Arc { + Arc::new(Self { + rpc_addr_client: rac, + background, + status, + local_handler: ArcSwapOption::new(None), + }) + } + + pub fn set_local_handler(&self, my_id: UUID, handler: F) + where + F: Fn(Arc) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let handler_arc = Arc::new(handler); + let handler: LocalHandlerFn = Box::new(move |msg| { + let handler_arc2 = handler_arc.clone(); + Box::pin(async move { handler_arc2(msg).await }) + }); + self.local_handler.swap(Some(Arc::new((my_id, handler)))); + } + + pub fn by_addr(&self) -> &RpcAddrClient { + &self.rpc_addr_client + } + + pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result { + self.call_arc(to, Arc::new(msg), timeout).await + } + + pub async fn call_arc(&self, to: UUID, msg: Arc, timeout: Duration) -> Result { + if let Some(lh) = self.local_handler.load_full() { + let (my_id, local_handler) = lh.as_ref(); + if to.borrow() == my_id { + return local_handler(msg).await; + } + } + let status = self.status.borrow().clone(); + let node_status = match status.nodes.get(&to) { + Some(node_status) => { + if node_status.is_up() { + node_status + } else { + return Err(Error::from(RPCError::NodeDown(to))); + } + } + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) + } + }; + match self + .rpc_addr_client + .call(&node_status.addr, msg, timeout) + .await + { + Err(rpc_error) => { + node_status.num_failures.fetch_add(1, Ordering::SeqCst); + // TODO: Save failure info somewhere + Err(Error::from(rpc_error)) + } + Ok(x) => x, + } + } + + pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec> { + let msg = Arc::new(msg); + let mut resp_stream = to + .iter() + .map(|to| self.call_arc(*to, msg.clone(), timeout)) + .collect::>(); + + let mut results = vec![]; + while let Some(resp) = resp_stream.next().await { + results.push(resp); + } + results + } + + pub async fn try_call_many( + self: &Arc, + to: &[UUID], + msg: M, + strategy: RequestStrategy, + ) -> Result, Error> { + let timeout = strategy.rs_timeout; + + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + async move { self2.call_arc(to, msg, timeout).await } + }) + .collect::>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= strategy.rs_quorum { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= strategy.rs_quorum { + // Continue requests in background. 
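// [Editor's note, not part of the patch: this branch is only reached once
// `results` already holds a quorum of replies, so the caller can be answered
// now; the code below merely decides the fate of the still-pending requests.
// A hypothetical call site, assuming `rpc_client`, `nodes` and `msg` are in
// scope, might look like:
//
//     let strategy = RequestStrategy::with_quorum(2)
//         .with_timeout(Duration::from_secs(10))
//         .interrupt_after_quorum(false);
//     let replies = rpc_client.try_call_many(&nodes[..], msg, strategy).await?;
// ]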
+ // Continue the remaining requests immediately using tokio::spawn + // but enqueue a task in the background runner + // to ensure that the process won't exit until the requests are done + // (if we had just enqueued the resp_stream.collect directly in the background runner, + // the requests might have been put on hold in the background runner's queue, + // in which case they might timeout or otherwise fail) + if !strategy.rs_interrupt_after_quorum { + let wait_finished_fut = tokio::spawn(async move { + resp_stream.collect::>().await; + Ok(()) + }); + self.background.spawn(wait_finished_fut.map(|x| { + x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e)))) + })); + } + + Ok(results) + } else { + let errors = errors.iter().map(|e| format!("{}", e)).collect::>(); + Err(Error::from(RPCError::TooManyErrors(errors))) + } + } +} + +pub struct RpcAddrClient { + phantom: PhantomData, + + pub http_client: Arc, + pub path: String, +} + +impl RpcAddrClient { + pub fn new(http_client: Arc, path: String) -> Self { + Self { + phantom: PhantomData::default(), + http_client: http_client, + path, + } + } + + pub async fn call( + &self, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result, RPCError> + where + MB: Borrow, + { + self.http_client + .call(&self.path, to_addr, msg, timeout) + .await + } +} + +pub struct RpcHttpClient { + request_limiter: Semaphore, + method: ClientMethod, +} + +enum ClientMethod { + HTTP(Client), + HTTPS(Client, hyper::Body>), +} + +impl RpcHttpClient { + pub fn new( + max_concurrent_requests: usize, + tls_config: &Option, + ) -> Result { + let method = if let Some(cf) = tls_config { + let ca_certs = tls_util::load_certs(&cf.ca_cert)?; + let node_certs = tls_util::load_certs(&cf.node_cert)?; + let node_key = tls_util::load_private_key(&cf.node_key)?; + + let mut config = rustls::ClientConfig::new(); + + for crt in ca_certs.iter() { + config.root_store.add(crt)?; + } + + config.set_single_client_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; + + let connector = + tls_util::HttpsConnectorFixedDnsname::::new(config, "garage"); + + ClientMethod::HTTPS(Client::builder().build(connector)) + } else { + ClientMethod::HTTP(Client::new()) + }; + Ok(RpcHttpClient { + method, + request_limiter: Semaphore::new(max_concurrent_requests), + }) + } + + async fn call( + &self, + path: &str, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result, RPCError> + where + MB: Borrow, + M: RpcMessage, + { + let uri = match self.method { + ClientMethod::HTTP(_) => format!("http://{}/{}", to_addr, path), + ClientMethod::HTTPS(_) => format!("https://{}/{}", to_addr, path), + }; + + let req = Request::builder() + .method(Method::POST) + .uri(uri) + .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; + + let resp_fut = match &self.method { + ClientMethod::HTTP(client) => client.request(req).fuse(), + ClientMethod::HTTPS(client) => client.request(req).fuse(), + }; + + let slot = self.request_limiter.acquire().await; + let resp = tokio::time::timeout(timeout, resp_fut) + .await + .map_err(|e| { + debug!( + "RPC timeout to {}: {}", + to_addr, + debug_serialize(msg.borrow()) + ); + e + })? + .map_err(|e| { + warn!( + "RPC HTTP client error when connecting to {}: {}", + to_addr, e + ); + e + })?; + drop(slot); + + let status = resp.status(); + let body = hyper::body::to_bytes(resp.into_body()).await?; + match rmp_serde::decode::from_read::<_, Result>(body.into_buf())? 
{ + Err(e) => Ok(Err(Error::RemoteError(e, status))), + Ok(x) => Ok(Ok(x)), + } + } +} diff --git a/src/rpc/rpc_server.rs b/src/rpc/rpc_server.rs new file mode 100644 index 00000000..4ee53909 --- /dev/null +++ b/src/rpc/rpc_server.rs @@ -0,0 +1,219 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Instant; + +use bytes::IntoBuf; +use futures::future::Future; +use futures_util::future::*; +use futures_util::stream::*; +use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use serde::{Deserialize, Serialize}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::server::TlsStream; +use tokio_rustls::TlsAcceptor; + +use crate::config::TlsConfig; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::tls_util; + +pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {} + +type ResponseFuture = Pin, Error>> + Send>>; +type Handler = Box, SocketAddr) -> ResponseFuture + Send + Sync>; + +pub struct RpcServer { + pub bind_addr: SocketAddr, + pub tls_config: Option, + + handlers: HashMap, +} + +async fn handle_func( + handler: Arc, + req: Request, + sockaddr: SocketAddr, + name: Arc, +) -> Result, Error> +where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + let begin_time = Instant::now(); + let whole_body = hyper::body::to_bytes(req.into_body()).await?; + let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?; + match handler(msg, sockaddr).await { + Ok(resp) => { + let resp_bytes = rmp_to_vec_all_named::>(&Ok(resp))?; + let rpc_duration = (Instant::now() - begin_time).as_millis(); + if rpc_duration > 100 { + debug!("RPC {} ok, took long: {} ms", name, rpc_duration,); + } + Ok(Response::new(Body::from(resp_bytes))) + } + Err(e) => { + let err_str = format!("{}", e); + let rep_bytes = rmp_to_vec_all_named::>(&Err(err_str))?; + let mut err_response = Response::new(Body::from(rep_bytes)); + *err_response.status_mut() = e.http_status_code(); + warn!( + "RPC error ({}): {} ({} ms)", + name, + e, + (Instant::now() - begin_time).as_millis(), + ); + Ok(err_response) + } + } +} + +impl RpcServer { + pub fn new(bind_addr: SocketAddr, tls_config: Option) -> Self { + Self { + bind_addr, + tls_config, + handlers: HashMap::new(), + } + } + + pub fn add_handler(&mut self, name: String, handler: F) + where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let name2 = Arc::new(name.clone()); + let handler_arc = Arc::new(handler); + let handler = Box::new(move |req: Request, sockaddr: SocketAddr| { + let handler2 = handler_arc.clone(); + let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr, name2.clone())); + b + }); + self.handlers.insert(name, handler); + } + + async fn handler( + self: Arc, + req: Request, + addr: SocketAddr, + ) -> Result, Error> { + if req.method() != &Method::POST { + let mut bad_request = Response::default(); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } + + let path = &req.uri().path()[1..]; + let handler = match self.handlers.get(path) { + Some(h) => h, + None => { + let mut not_found = Response::default(); + *not_found.status_mut() = StatusCode::NOT_FOUND; + return Ok(not_found); + } + }; + + let resp_waiter = tokio::spawn(handler(req, addr)); + match resp_waiter.await { 
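// [Editor's note, not part of the patch: the handler runs inside
// tokio::spawn, so the three arms below map its JoinHandle onto HTTP
// statuses -- a join error (the handler panicked or was cancelled) becomes
// 500 Internal Server Error, an application-level Err becomes 400 with the
// error text as body, and a successful reply is forwarded unchanged.]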
+			Err(err) => {
+				warn!("Handler await error: {}", err);
+				let mut ise = Response::default();
+				*ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
+				Ok(ise)
+			}
+			Ok(Err(err)) => {
+				let mut bad_request = Response::new(Body::from(format!("{}", err)));
+				*bad_request.status_mut() = StatusCode::BAD_REQUEST;
+				Ok(bad_request)
+			}
+			Ok(Ok(resp)) => Ok(resp),
+		}
+	}
+
+	pub async fn run(
+		self: Arc<Self>,
+		shutdown_signal: impl Future<Output = ()>,
+	) -> Result<(), Error> {
+		if let Some(tls_config) = self.tls_config.as_ref() {
+			let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?;
+			let node_certs = tls_util::load_certs(&tls_config.node_cert)?;
+			let node_key = tls_util::load_private_key(&tls_config.node_key)?;
+
+			let mut ca_store = rustls::RootCertStore::empty();
+			for crt in ca_certs.iter() {
+				ca_store.add(crt)?;
+			}
+
+			let mut config =
+				rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store));
+			config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?;
+			let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config)));
+
+			let mut listener = TcpListener::bind(&self.bind_addr).await?;
+			let incoming = listener.incoming().filter_map(|socket| async {
+				match socket {
+					Ok(stream) => match tls_acceptor.clone().accept(stream).await {
+						Ok(x) => Some(Ok::<_, hyper::Error>(x)),
+						Err(_e) => None,
+					},
+					Err(_) => None,
+				}
+			});
+			let incoming = hyper::server::accept::from_stream(incoming);
+
+			let self_arc = self.clone();
+			let service = make_service_fn(|conn: &TlsStream<TcpStream>| {
+				let client_addr = conn
+					.get_ref()
+					.0
+					.peer_addr()
+					.unwrap_or(([0, 0, 0, 0], 0).into());
+				let self_arc = self_arc.clone();
+				async move {
+					Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+						self_arc.clone().handler(req, client_addr).map_err(|e| {
+							warn!("RPC handler error: {}", e);
+							e
+						})
+					}))
+				}
+			});
+
+			let server = Server::builder(incoming).serve(service);
+
+			let graceful = server.with_graceful_shutdown(shutdown_signal);
+			info!("RPC server listening on http://{}", self.bind_addr);
+
+			graceful.await?;
+		} else {
+			let self_arc = self.clone();
+			let service = make_service_fn(move |conn: &AddrStream| {
+				let client_addr = conn.remote_addr();
+				let self_arc = self_arc.clone();
+				async move {
+					Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+						self_arc.clone().handler(req, client_addr).map_err(|e| {
+							warn!("RPC handler error: {}", e);
+							e
+						})
+					}))
+				}
+			});
+
+			let server = Server::bind(&self.bind_addr).serve(service);
+
+			let graceful = server.with_graceful_shutdown(shutdown_signal);
+			info!("RPC server listening on http://{}", self.bind_addr);
+
+			graceful.await?;
+		}
+
+		Ok(())
+	}
+}
diff --git a/src/rpc/tls_util.rs b/src/rpc/tls_util.rs
new file mode 100644
index 00000000..52c52110
--- /dev/null
+++ b/src/rpc/tls_util.rs
@@ -0,0 +1,139 @@
+use core::future::Future;
+use core::task::{Context, Poll};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::{fs, io};
+
+use futures_util::future::*;
+use hyper::client::connect::Connection;
+use hyper::client::HttpConnector;
+use hyper::service::Service;
+use hyper::Uri;
+use hyper_rustls::MaybeHttpsStream;
+use rustls::internal::pemfile;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_rustls::TlsConnector;
+use webpki::DNSNameRef;
+
+use crate::error::Error;
+
+pub fn load_certs(filename: &str) -> Result<Vec<rustls::Certificate>, Error> {
+	let certfile = fs::File::open(&filename)?;
+	let mut reader = io::BufReader::new(certfile);
+
+	let certs = pemfile::certs(&mut reader).map_err(|_| {
+		Error::Message(format!(
+			"Could not decode 
certificates from file: {}", + filename + )) + })?; + + if certs.is_empty() { + return Err(Error::Message(format!( + "Invalid certificate file: {}", + filename + ))); + } + Ok(certs) +} + +pub fn load_private_key(filename: &str) -> Result { + let keyfile = fs::File::open(&filename)?; + let mut reader = io::BufReader::new(keyfile); + + let keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| { + Error::Message(format!( + "Could not decode private key from file: {}", + filename + )) + })?; + + if keys.len() != 1 { + return Err(Error::Message(format!( + "Invalid private key file: {} ({} private keys)", + filename, + keys.len() + ))); + } + Ok(keys[0].clone()) +} + +// ---- AWFUL COPYPASTA FROM HYPER-RUSTLS connector.rs +// ---- ALWAYS USE `garage` AS HOSTNAME FOR TLS VERIFICATION + +#[derive(Clone)] +pub struct HttpsConnectorFixedDnsname { + http: T, + tls_config: Arc, + fixed_dnsname: &'static str, +} + +type BoxError = Box; + +impl HttpsConnectorFixedDnsname { + pub fn new(mut tls_config: rustls::ClientConfig, fixed_dnsname: &'static str) -> Self { + let mut http = HttpConnector::new(); + http.enforce_http(false); + tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Self { + http, + tls_config: Arc::new(tls_config), + fixed_dnsname, + } + } +} + +impl Service for HttpsConnectorFixedDnsname +where + T: Service, + T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + T::Future: Send + 'static, + T::Error: Into, +{ + type Response = MaybeHttpsStream; + type Error = BoxError; + + #[allow(clippy::type_complexity)] + type Future = + Pin, BoxError>> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.http.poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => Poll::Pending, + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let is_https = dst.scheme_str() == Some("https"); + + if !is_https { + let connecting_future = self.http.call(dst); + + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; + + Ok(MaybeHttpsStream::Http(tcp)) + }; + f.boxed() + } else { + let cfg = self.tls_config.clone(); + let connecting_future = self.http.call(dst); + + let dnsname = + DNSNameRef::try_from_ascii_str(self.fixed_dnsname).expect("Invalid fixed dnsname"); + + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; + let connector = TlsConnector::from(cfg); + let tls = connector + .connect(dnsname, tcp) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(MaybeHttpsStream::Https(tls)) + }; + f.boxed() + } + } +} diff --git a/src/rpc_client.rs b/src/rpc_client.rs deleted file mode 100644 index ba036c60..00000000 --- a/src/rpc_client.rs +++ /dev/null @@ -1,358 +0,0 @@ -use std::borrow::Borrow; -use std::marker::PhantomData; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::Duration; - -use arc_swap::ArcSwapOption; -use bytes::IntoBuf; -use err_derive::Error; -use futures::future::Future; -use futures::stream::futures_unordered::FuturesUnordered; -use futures::stream::StreamExt; -use futures_util::future::FutureExt; -use hyper::client::{Client, HttpConnector}; -use hyper::{Body, Method, Request}; -use tokio::sync::{watch, Semaphore}; - -use crate::background::BackgroundRunner; -use crate::data::*; -use crate::error::Error; -use crate::membership::Status; -use crate::rpc_server::RpcMessage; -use 
crate::server::TlsConfig; -use crate::tls_util; - -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); - -#[derive(Debug, Error)] -pub enum RPCError { - #[error(display = "Node is down: {:?}.", _0)] - NodeDown(UUID), - #[error(display = "Timeout: {}", _0)] - Timeout(#[error(source)] tokio::time::Elapsed), - #[error(display = "HTTP error: {}", _0)] - HTTP(#[error(source)] http::Error), - #[error(display = "Hyper error: {}", _0)] - Hyper(#[error(source)] hyper::Error), - #[error(display = "Messagepack encode error: {}", _0)] - RMPEncode(#[error(source)] rmp_serde::encode::Error), - #[error(display = "Messagepack decode error: {}", _0)] - RMPDecode(#[error(source)] rmp_serde::decode::Error), - #[error(display = "Too many errors: {:?}", _0)] - TooManyErrors(Vec), -} - -#[derive(Copy, Clone)] -pub struct RequestStrategy { - pub rs_timeout: Duration, - pub rs_quorum: usize, - pub rs_interrupt_after_quorum: bool, -} - -impl RequestStrategy { - pub fn with_quorum(quorum: usize) -> Self { - RequestStrategy { - rs_timeout: DEFAULT_TIMEOUT, - rs_quorum: quorum, - rs_interrupt_after_quorum: false, - } - } - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.rs_timeout = timeout; - self - } - pub fn interrupt_after_quorum(mut self, interrupt: bool) -> Self { - self.rs_interrupt_after_quorum = interrupt; - self - } -} - -pub type LocalHandlerFn = - Box) -> Pin> + Send>> + Send + Sync>; - -pub struct RpcClient { - status: watch::Receiver>, - background: Arc, - - local_handler: ArcSwapOption<(UUID, LocalHandlerFn)>, - - pub rpc_addr_client: RpcAddrClient, -} - -impl RpcClient { - pub fn new( - rac: RpcAddrClient, - background: Arc, - status: watch::Receiver>, - ) -> Arc { - Arc::new(Self { - rpc_addr_client: rac, - background, - status, - local_handler: ArcSwapOption::new(None), - }) - } - - pub fn set_local_handler(&self, my_id: UUID, handler: F) - where - F: Fn(Arc) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, - { - let handler_arc = Arc::new(handler); - let handler: LocalHandlerFn = Box::new(move |msg| { - let handler_arc2 = handler_arc.clone(); - Box::pin(async move { handler_arc2(msg).await }) - }); - self.local_handler.swap(Some(Arc::new((my_id, handler)))); - } - - pub fn by_addr(&self) -> &RpcAddrClient { - &self.rpc_addr_client - } - - pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result { - self.call_arc(to, Arc::new(msg), timeout).await - } - - pub async fn call_arc(&self, to: UUID, msg: Arc, timeout: Duration) -> Result { - if let Some(lh) = self.local_handler.load_full() { - let (my_id, local_handler) = lh.as_ref(); - if to.borrow() == my_id { - return local_handler(msg).await; - } - } - let status = self.status.borrow().clone(); - let node_status = match status.nodes.get(&to) { - Some(node_status) => { - if node_status.is_up() { - node_status - } else { - return Err(Error::from(RPCError::NodeDown(to))); - } - } - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) - } - }; - match self - .rpc_addr_client - .call(&node_status.addr, msg, timeout) - .await - { - Err(rpc_error) => { - node_status.num_failures.fetch_add(1, Ordering::SeqCst); - // TODO: Save failure info somewhere - Err(Error::from(rpc_error)) - } - Ok(x) => x, - } - } - - pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec> { - let msg = Arc::new(msg); - let mut resp_stream = to - .iter() - .map(|to| self.call_arc(*to, msg.clone(), timeout)) - .collect::>(); - - let mut results = vec![]; - while let 
Some(resp) = resp_stream.next().await { - results.push(resp); - } - results - } - - pub async fn try_call_many( - self: &Arc, - to: &[UUID], - msg: M, - strategy: RequestStrategy, - ) -> Result, Error> { - let timeout = strategy.rs_timeout; - - let msg = Arc::new(msg); - let mut resp_stream = to - .to_vec() - .into_iter() - .map(|to| { - let self2 = self.clone(); - let msg = msg.clone(); - async move { self2.call_arc(to, msg, timeout).await } - }) - .collect::>(); - - let mut results = vec![]; - let mut errors = vec![]; - - while let Some(resp) = resp_stream.next().await { - match resp { - Ok(msg) => { - results.push(msg); - if results.len() >= strategy.rs_quorum { - break; - } - } - Err(e) => { - errors.push(e); - } - } - } - - if results.len() >= strategy.rs_quorum { - // Continue requests in background. - // Continue the remaining requests immediately using tokio::spawn - // but enqueue a task in the background runner - // to ensure that the process won't exit until the requests are done - // (if we had just enqueued the resp_stream.collect directly in the background runner, - // the requests might have been put on hold in the background runner's queue, - // in which case they might timeout or otherwise fail) - if !strategy.rs_interrupt_after_quorum { - let wait_finished_fut = tokio::spawn(async move { - resp_stream.collect::>().await; - Ok(()) - }); - self.background.spawn(wait_finished_fut.map(|x| { - x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e)))) - })); - } - - Ok(results) - } else { - let errors = errors.iter().map(|e| format!("{}", e)).collect::>(); - Err(Error::from(RPCError::TooManyErrors(errors))) - } - } -} - -pub struct RpcAddrClient { - phantom: PhantomData, - - pub http_client: Arc, - pub path: String, -} - -impl RpcAddrClient { - pub fn new(http_client: Arc, path: String) -> Self { - Self { - phantom: PhantomData::default(), - http_client: http_client, - path, - } - } - - pub async fn call( - &self, - to_addr: &SocketAddr, - msg: MB, - timeout: Duration, - ) -> Result, RPCError> - where - MB: Borrow, - { - self.http_client - .call(&self.path, to_addr, msg, timeout) - .await - } -} - -pub struct RpcHttpClient { - request_limiter: Semaphore, - method: ClientMethod, -} - -enum ClientMethod { - HTTP(Client), - HTTPS(Client, hyper::Body>), -} - -impl RpcHttpClient { - pub fn new( - max_concurrent_requests: usize, - tls_config: &Option, - ) -> Result { - let method = if let Some(cf) = tls_config { - let ca_certs = tls_util::load_certs(&cf.ca_cert)?; - let node_certs = tls_util::load_certs(&cf.node_cert)?; - let node_key = tls_util::load_private_key(&cf.node_key)?; - - let mut config = rustls::ClientConfig::new(); - - for crt in ca_certs.iter() { - config.root_store.add(crt)?; - } - - config.set_single_client_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; - - let connector = - tls_util::HttpsConnectorFixedDnsname::::new(config, "garage"); - - ClientMethod::HTTPS(Client::builder().build(connector)) - } else { - ClientMethod::HTTP(Client::new()) - }; - Ok(RpcHttpClient { - method, - request_limiter: Semaphore::new(max_concurrent_requests), - }) - } - - async fn call( - &self, - path: &str, - to_addr: &SocketAddr, - msg: MB, - timeout: Duration, - ) -> Result, RPCError> - where - MB: Borrow, - M: RpcMessage, - { - let uri = match self.method { - ClientMethod::HTTP(_) => format!("http://{}/{}", to_addr, path), - ClientMethod::HTTPS(_) => format!("https://{}/{}", to_addr, path), - }; - - let req = Request::builder() - .method(Method::POST) - 
.uri(uri) - .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; - - let resp_fut = match &self.method { - ClientMethod::HTTP(client) => client.request(req).fuse(), - ClientMethod::HTTPS(client) => client.request(req).fuse(), - }; - - let slot = self.request_limiter.acquire().await; - let resp = tokio::time::timeout(timeout, resp_fut) - .await - .map_err(|e| { - debug!( - "RPC timeout to {}: {}", - to_addr, - debug_serialize(msg.borrow()) - ); - e - })? - .map_err(|e| { - warn!( - "RPC HTTP client error when connecting to {}: {}", - to_addr, e - ); - e - })?; - drop(slot); - - let status = resp.status(); - let body = hyper::body::to_bytes(resp.into_body()).await?; - match rmp_serde::decode::from_read::<_, Result>(body.into_buf())? { - Err(e) => Ok(Err(Error::RemoteError(e, status))), - Ok(x) => Ok(Ok(x)), - } - } -} diff --git a/src/rpc_server.rs b/src/rpc_server.rs deleted file mode 100644 index bcf7496f..00000000 --- a/src/rpc_server.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::time::Instant; - -use bytes::IntoBuf; -use futures::future::Future; -use futures_util::future::*; -use futures_util::stream::*; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use serde::{Deserialize, Serialize}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_rustls::server::TlsStream; -use tokio_rustls::TlsAcceptor; - -use crate::data::*; -use crate::error::Error; -use crate::server::TlsConfig; -use crate::tls_util; - -pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {} - -type ResponseFuture = Pin, Error>> + Send>>; -type Handler = Box, SocketAddr) -> ResponseFuture + Send + Sync>; - -pub struct RpcServer { - pub bind_addr: SocketAddr, - pub tls_config: Option, - - handlers: HashMap, -} - -async fn handle_func( - handler: Arc, - req: Request, - sockaddr: SocketAddr, - name: Arc, -) -> Result, Error> -where - M: RpcMessage + 'static, - F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, -{ - let begin_time = Instant::now(); - let whole_body = hyper::body::to_bytes(req.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?; - match handler(msg, sockaddr).await { - Ok(resp) => { - let resp_bytes = rmp_to_vec_all_named::>(&Ok(resp))?; - let rpc_duration = (Instant::now() - begin_time).as_millis(); - if rpc_duration > 100 { - debug!("RPC {} ok, took long: {} ms", name, rpc_duration,); - } - Ok(Response::new(Body::from(resp_bytes))) - } - Err(e) => { - let err_str = format!("{}", e); - let rep_bytes = rmp_to_vec_all_named::>(&Err(err_str))?; - let mut err_response = Response::new(Body::from(rep_bytes)); - *err_response.status_mut() = e.http_status_code(); - warn!( - "RPC error ({}): {} ({} ms)", - name, - e, - (Instant::now() - begin_time).as_millis(), - ); - Ok(err_response) - } - } -} - -impl RpcServer { - pub fn new(bind_addr: SocketAddr, tls_config: Option) -> Self { - Self { - bind_addr, - tls_config, - handlers: HashMap::new(), - } - } - - pub fn add_handler(&mut self, name: String, handler: F) - where - M: RpcMessage + 'static, - F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, - { - let name2 = Arc::new(name.clone()); - let handler_arc = Arc::new(handler); - let handler = Box::new(move |req: Request, sockaddr: SocketAddr| { - let handler2 = handler_arc.clone(); 
- let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr, name2.clone())); - b - }); - self.handlers.insert(name, handler); - } - - async fn handler( - self: Arc, - req: Request, - addr: SocketAddr, - ) -> Result, Error> { - if req.method() != &Method::POST { - let mut bad_request = Response::default(); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - - let path = &req.uri().path()[1..]; - let handler = match self.handlers.get(path) { - Some(h) => h, - None => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - return Ok(not_found); - } - }; - - let resp_waiter = tokio::spawn(handler(req, addr)); - match resp_waiter.await { - Err(err) => { - warn!("Handler await error: {}", err); - let mut ise = Response::default(); - *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Ok(ise) - } - Ok(Err(err)) => { - let mut bad_request = Response::new(Body::from(format!("{}", err))); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - Ok(bad_request) - } - Ok(Ok(resp)) => Ok(resp), - } - } - - pub async fn run( - self: Arc, - shutdown_signal: impl Future, - ) -> Result<(), Error> { - if let Some(tls_config) = self.tls_config.as_ref() { - let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; - let node_certs = tls_util::load_certs(&tls_config.node_cert)?; - let node_key = tls_util::load_private_key(&tls_config.node_key)?; - - let mut ca_store = rustls::RootCertStore::empty(); - for crt in ca_certs.iter() { - ca_store.add(crt)?; - } - - let mut config = - rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); - config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; - let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); - - let mut listener = TcpListener::bind(&self.bind_addr).await?; - let incoming = listener.incoming().filter_map(|socket| async { - match socket { - Ok(stream) => match tls_acceptor.clone().accept(stream).await { - Ok(x) => Some(Ok::<_, hyper::Error>(x)), - Err(_e) => None, - }, - Err(_) => None, - } - }); - let incoming = hyper::server::accept::from_stream(incoming); - - let self_arc = self.clone(); - let service = make_service_fn(|conn: &TlsStream| { - let client_addr = conn - .get_ref() - .0 - .peer_addr() - .unwrap_or(([0, 0, 0, 0], 0).into()); - let self_arc = self_arc.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request| { - self_arc.clone().handler(req, client_addr).map_err(|e| { - warn!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::builder(incoming).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("RPC server listening on http://{}", self.bind_addr); - - graceful.await?; - } else { - let self_arc = self.clone(); - let service = make_service_fn(move |conn: &AddrStream| { - let client_addr = conn.remote_addr(); - let self_arc = self_arc.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request| { - self_arc.clone().handler(req, client_addr).map_err(|e| { - warn!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::bind(&self.bind_addr).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("RPC server listening on http://{}", self.bind_addr); - - graceful.await?; - } - - Ok(()) - } -} diff --git a/src/server.rs b/src/server.rs index 3ea29105..de04615f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,79 +1,34 @@ -use std::io::{Read, Write}; -use 
std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use futures_util::future::*; -use serde::Deserialize; use tokio::sync::watch; use crate::background::*; -use crate::data::*; +use crate::config::*; use crate::error::Error; -use crate::membership::System; -use crate::rpc_server::RpcServer; -use crate::table::*; -use crate::table_fullcopy::*; -use crate::table_sharded::*; - -use crate::block::*; -use crate::block_ref_table::*; -use crate::bucket_table::*; -use crate::object_table::*; -use crate::version_table::*; - -use crate::admin_rpc::*; -use crate::api_server; - -#[derive(Deserialize, Debug, Clone)] -pub struct Config { - pub metadata_dir: PathBuf, - pub data_dir: PathBuf, - - pub api_bind_addr: SocketAddr, - pub rpc_bind_addr: SocketAddr, - pub bootstrap_peers: Vec, +use crate::rpc::membership::System; +use crate::rpc::rpc_client::RpcHttpClient; +use crate::rpc::rpc_server::RpcServer; - #[serde(default = "default_max_concurrent_rpc_requests")] - pub max_concurrent_rpc_requests: usize, - - #[serde(default = "default_block_size")] - pub block_size: usize, - - #[serde(default = "default_replication_factor")] - pub meta_replication_factor: usize, - - #[serde(default = "default_epidemic_factor")] - pub meta_epidemic_factor: usize, +use crate::table::table_fullcopy::*; +use crate::table::table_sharded::*; +use crate::table::*; - #[serde(default = "default_replication_factor")] - pub data_replication_factor: usize, +use crate::store::block::*; +use crate::store::block_ref_table::*; +use crate::store::bucket_table::*; +use crate::store::object_table::*; +use crate::store::version_table::*; - pub rpc_tls: Option, -} +use crate::api::api_server; -fn default_max_concurrent_rpc_requests() -> usize { - 12 -} -fn default_block_size() -> usize { - 1048576 -} -fn default_replication_factor() -> usize { - 3 -} -fn default_epidemic_factor() -> usize { - 3 -} - -#[derive(Deserialize, Debug, Clone)] -pub struct TlsConfig { - pub ca_cert: String, - pub node_cert: String, - pub node_key: String, -} +use crate::admin_rpc::*; pub struct Garage { + pub config: Config, + pub db: sled::Db, pub background: Arc, pub system: Arc, @@ -88,33 +43,46 @@ pub struct Garage { impl Garage { pub async fn new( config: Config, - id: UUID, db: sled::Db, background: Arc, rpc_server: &mut RpcServer, ) -> Arc { info!("Initialize membership management system..."); - let system = System::new(config.clone(), id, background.clone(), rpc_server); - - info!("Initialize block manager..."); - let block_manager = - BlockManager::new(&db, config.data_dir.clone(), system.clone(), rpc_server); + let rpc_http_client = Arc::new( + RpcHttpClient::new(config.max_concurrent_rpc_requests, &config.rpc_tls) + .expect("Could not create RPC client"), + ); + let system = System::new( + config.metadata_dir.clone(), + rpc_http_client, + background.clone(), + rpc_server, + ); let data_rep_param = TableShardedReplication { - replication_factor: system.config.data_replication_factor, - write_quorum: (system.config.data_replication_factor + 1) / 2, + replication_factor: config.data_replication_factor, + write_quorum: (config.data_replication_factor + 1) / 2, read_quorum: 1, }; let meta_rep_param = TableShardedReplication { - replication_factor: system.config.meta_replication_factor, - write_quorum: (system.config.meta_replication_factor + 1) / 2, - read_quorum: (system.config.meta_replication_factor + 1) / 2, + replication_factor: config.meta_replication_factor, + write_quorum: (config.meta_replication_factor + 1) / 2, + read_quorum: 
(config.meta_replication_factor + 1) / 2, }; let control_rep_param = TableFullReplication::new( - system.config.meta_epidemic_factor, - (system.config.meta_epidemic_factor + 1) / 2, + config.meta_epidemic_factor, + (config.meta_epidemic_factor + 1) / 2, + ); + + info!("Initialize block manager..."); + let block_manager = BlockManager::new( + &db, + config.data_dir.clone(), + data_rep_param.clone(), + system.clone(), + rpc_server, ); info!("Initialize block_ref_table..."); @@ -172,6 +140,7 @@ impl Garage { info!("Initialize Garage..."); let garage = Arc::new(Self { + config, db, system: system.clone(), block_manager, @@ -193,40 +162,6 @@ impl Garage { } } -fn read_config(config_file: PathBuf) -> Result { - let mut file = std::fs::OpenOptions::new() - .read(true) - .open(config_file.as_path())?; - - let mut config = String::new(); - file.read_to_string(&mut config)?; - - Ok(toml::from_str(&config)?) -} - -fn gen_node_id(metadata_dir: &PathBuf) -> Result { - let mut id_file = metadata_dir.clone(); - id_file.push("node_id"); - if id_file.as_path().exists() { - let mut f = std::fs::File::open(id_file.as_path())?; - let mut d = vec![]; - f.read_to_end(&mut d)?; - if d.len() != 32 { - return Err(Error::Message(format!("Corrupt node_id file"))); - } - - let mut id = [0u8; 32]; - id.copy_from_slice(&d[..]); - Ok(id.into()) - } else { - let id = gen_uuid(); - - let mut f = std::fs::File::create(id_file.as_path())?; - f.write_all(id.as_slice())?; - Ok(id) - } -} - async fn shutdown_signal(send_cancel: watch::Sender) -> Result<(), Error> { // Wait for the CTRL+C signal tokio::signal::ctrl_c() @@ -249,9 +184,6 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> { info!("Loading configuration..."); let config = read_config(config_file).expect("Unable to read config file"); - let id = gen_node_id(&config.metadata_dir).expect("Unable to read or generate node ID"); - info!("Node ID: {}", hex::encode(&id)); - info!("Opening database..."); let mut db_path = config.metadata_dir.clone(); db_path.push("db"); @@ -264,17 +196,21 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> { let (send_cancel, watch_cancel) = watch::channel(false); let background = BackgroundRunner::new(16, watch_cancel.clone()); - let garage = Garage::new(config, id, db, background.clone(), &mut rpc_server).await; + let garage = Garage::new(config, db, background.clone(), &mut rpc_server).await; info!("Initializing RPC and API servers..."); let run_rpc_server = Arc::new(rpc_server).run(wait_from(watch_cancel.clone())); let api_server = api_server::run_api_server(garage.clone(), wait_from(watch_cancel.clone())); futures::try_join!( - garage.system.clone().bootstrap().map(|rv| { - info!("Bootstrap done"); - Ok(rv) - }), + garage + .system + .clone() + .bootstrap(&garage.config.bootstrap_peers[..]) + .map(|rv| { + info!("Bootstrap done"); + Ok(rv) + }), run_rpc_server.map(|rv| { info!("RPC server exited"); rv diff --git a/src/store/block.rs b/src/store/block.rs new file mode 100644 index 00000000..e2ef32e0 --- /dev/null +++ b/src/store/block.rs @@ -0,0 +1,506 @@ +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwapOption; +use futures::future::*; +use futures::select; +use futures::stream::*; +use serde::{Deserialize, Serialize}; +use tokio::fs; +use tokio::prelude::*; +use tokio::sync::{watch, Mutex, Notify}; + +use crate::data; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::membership::System; +use crate::rpc::rpc_client::*; +use 
crate::rpc::rpc_server::*;
+
+use crate::table::table_sharded::TableShardedReplication;
+use crate::table::TableReplication;
+
+use crate::store::block_ref_table::*;
+
+use crate::server::Garage;
+
+pub const INLINE_THRESHOLD: usize = 3072;
+
+const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42);
+const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
+const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10);
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Message {
+    Ok,
+    GetBlock(Hash),
+    PutBlock(PutBlockMessage),
+    NeedBlockQuery(Hash),
+    NeedBlockReply(bool),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct PutBlockMessage {
+    pub hash: Hash,
+
+    #[serde(with = "serde_bytes")]
+    pub data: Vec<u8>,
+}
+
+impl RpcMessage for Message {}
+
+pub struct BlockManager {
+    pub replication: TableShardedReplication,
+    pub data_dir: PathBuf,
+    pub data_dir_lock: Mutex<()>,
+
+    pub rc: sled::Tree,
+
+    pub resync_queue: sled::Tree,
+    pub resync_notify: Notify,
+
+    pub system: Arc<System>,
+    rpc_client: Arc<RpcClient<Message>>,
+    pub garage: ArcSwapOption<Garage>,
+}
+
+impl BlockManager {
+    pub fn new(
+        db: &sled::Db,
+        data_dir: PathBuf,
+        replication: TableShardedReplication,
+        system: Arc<System>,
+        rpc_server: &mut RpcServer,
+    ) -> Arc<Self> {
+        let rc = db
+            .open_tree("block_local_rc")
+            .expect("Unable to open block_local_rc tree");
+        rc.set_merge_operator(rc_merge);
+
+        let resync_queue = db
+            .open_tree("block_local_resync_queue")
+            .expect("Unable to open block_local_resync_queue tree");
+
+        let rpc_path = "block_manager";
+        let rpc_client = system.rpc_client::<Message>(rpc_path);
+
+        let block_manager = Arc::new(Self {
+            replication,
+            data_dir,
+            data_dir_lock: Mutex::new(()),
+            rc,
+            resync_queue,
+            resync_notify: Notify::new(),
+            system,
+            rpc_client,
+            garage: ArcSwapOption::from(None),
+        });
+        block_manager
+            .clone()
+            .register_handler(rpc_server, rpc_path.into());
+        block_manager
+    }
+
+    fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+        let self2 = self.clone();
+        rpc_server.add_handler::<Message, _, _>(path, move |msg, _addr| {
+            let self2 = self2.clone();
+            async move { self2.handle(&msg).await }
+        });
+
+        let self2 = self.clone();
+        self.rpc_client
+            .set_local_handler(self.system.id, move |msg| {
+                let self2 = self2.clone();
+                async move { self2.handle(&msg).await }
+            });
+    }
+
+    async fn handle(self: Arc<Self>, msg: &Message) -> Result<Message, Error> {
+        match msg {
+            Message::PutBlock(m) => self.write_block(&m.hash, &m.data).await,
+            Message::GetBlock(h) => self.read_block(h).await,
+            Message::NeedBlockQuery(h) => self.need_block(h).await.map(Message::NeedBlockReply),
+            _ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
+        }
+    }
+
+    pub async fn spawn_background_worker(self: Arc<Self>) {
+        // Launch 2 simultaneous workers for background resync loop preprocessing
+        for i in 0..2usize {
+            let bm2 = self.clone();
+            let background = self.system.background.clone();
+            tokio::spawn(async move {
+                tokio::time::delay_for(Duration::from_secs(10)).await;
+                background
+                    .spawn_worker(format!("block resync worker {}", i), move |must_exit| {
+                        bm2.resync_loop(must_exit)
+                    })
+                    .await;
+            });
+        }
+    }
+
+    pub async fn write_block(&self, hash: &Hash, data: &[u8]) -> Result<Message, Error> {
+        let _lock = self.data_dir_lock.lock().await;
+
+        let mut path = self.block_dir(hash);
+        fs::create_dir_all(&path).await?;
+
+        path.push(hex::encode(hash));
+        if fs::metadata(&path).await.is_ok() {
+            return Ok(Message::Ok);
+        }
+
+        let mut f = fs::File::create(path).await?;
+        f.write_all(data).await?;
+        drop(f);
+
+        Ok(Message::Ok)
+    }
+
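+    // `read_block` answers a `GetBlock` RPC with a `Message::PutBlock` carrying
+    // the block contents: the same variant serves as both the write request and
+    // the read reply (see `rpc_get_block` below, which pattern-matches on
+    // `PutBlock` to extract the data).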
+    pub async fn read_block(&self, hash: &Hash) -> Result<Message, Error> {
+        let path = self.block_path(hash);
+
+        let mut f = match fs::File::open(&path).await {
+            Ok(f) => f,
+            Err(e) => {
+                // Block not found on disk, although we may be responsible for it:
+                // queue a resync so it can be fetched again if needed
+                self.put_to_resync(hash, 0)?;
+                return Err(Into::into(e));
+            }
+        };
+        let mut data = vec![];
+        f.read_to_end(&mut data).await?;
+        drop(f);
+
+        if data::hash(&data[..]) != *hash {
+            let _lock = self.data_dir_lock.lock().await;
+            warn!("Block {:?} is corrupted. Deleting and resyncing.", hash);
+            fs::remove_file(path).await?;
+            self.put_to_resync(&hash, 0)?;
+            return Err(Error::CorruptData(*hash));
+        }
+
+        Ok(Message::PutBlock(PutBlockMessage { hash: *hash, data }))
+    }
+
+    pub async fn need_block(&self, hash: &Hash) -> Result<bool, Error> {
+        let needed = self
+            .rc
+            .get(hash.as_ref())?
+            .map(|x| u64_from_bytes(x.as_ref()) > 0)
+            .unwrap_or(false);
+        if needed {
+            let path = self.block_path(hash);
+            let exists = fs::metadata(&path).await.is_ok();
+            Ok(!exists)
+        } else {
+            Ok(false)
+        }
+    }
+
+    fn block_dir(&self, hash: &Hash) -> PathBuf {
+        let mut path = self.data_dir.clone();
+        path.push(hex::encode(&hash.as_slice()[0..1]));
+        path.push(hex::encode(&hash.as_slice()[1..2]));
+        path
+    }
+    fn block_path(&self, hash: &Hash) -> PathBuf {
+        let mut path = self.block_dir(hash);
+        path.push(hex::encode(hash.as_ref()));
+        path
+    }
+
+    pub fn block_incref(&self, hash: &Hash) -> Result<(), Error> {
+        let old_rc = self.rc.get(&hash)?;
+        self.rc.merge(&hash, vec![1])?;
+        if old_rc.map(|x| u64_from_bytes(&x[..]) == 0).unwrap_or(true) {
+            self.put_to_resync(&hash, BLOCK_RW_TIMEOUT.as_millis() as u64)?;
+        }
+        Ok(())
+    }
+
+    pub fn block_decref(&self, hash: &Hash) -> Result<(), Error> {
+        let new_rc = self.rc.merge(&hash, vec![0])?;
+        if new_rc.map(|x| u64_from_bytes(&x[..]) == 0).unwrap_or(true) {
+            self.put_to_resync(&hash, 0)?;
+        }
+        Ok(())
+    }
+
+    fn put_to_resync(&self, hash: &Hash, delay_millis: u64) -> Result<(), Error> {
+        let when = now_msec() + delay_millis;
+        trace!("Put resync_queue: {} {:?}", when, hash);
+        let mut key = u64::to_be_bytes(when).to_vec();
+        key.extend(hash.as_ref());
+        self.resync_queue.insert(key, hash.as_ref())?;
+        self.resync_notify.notify();
+        Ok(())
+    }
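+
+    // The resync queue is keyed by the 8-byte big-endian scheduled time followed
+    // by the block hash, so sled's lexicographic key order is also chronological
+    // order and `pop_min()` in the loop below always yields the next due entry.
+    // A minimal sketch of why the byte order matters (illustrative only, not part
+    // of the original patch):
+    //
+    //     assert!(u64::to_be_bytes(1) < u64::to_be_bytes(256)); // BE preserves order
+    //     assert!(u64::to_le_bytes(1) > u64::to_le_bytes(256)); // LE would not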
+
+    async fn resync_loop(
+        self: Arc<Self>,
+        mut must_exit: watch::Receiver<bool>,
+    ) -> Result<(), Error> {
+        let mut n_failures = 0usize;
+        while !*must_exit.borrow() {
+            if let Some((time_bytes, hash_bytes)) = self.resync_queue.pop_min()? {
+                let time_msec = u64_from_bytes(&time_bytes[0..8]);
+                let now = now_msec();
+                if now >= time_msec {
+                    let mut hash = [0u8; 32];
+                    hash.copy_from_slice(hash_bytes.as_ref());
+                    let hash = Hash::from(hash);
+
+                    if let Err(e) = self.resync_iter(&hash).await {
+                        warn!("Failed to resync block {:?}, retrying later: {}", hash, e);
+                        self.put_to_resync(&hash, RESYNC_RETRY_TIMEOUT.as_millis() as u64)?;
+                        n_failures += 1;
+                        if n_failures >= 10 {
+                            warn!("Too many resync failures, throttling.");
+                            tokio::time::delay_for(Duration::from_secs(1)).await;
+                        }
+                    } else {
+                        n_failures = 0;
+                    }
+                } else {
+                    self.resync_queue.insert(time_bytes, hash_bytes)?;
+                    let delay = tokio::time::delay_for(Duration::from_millis(time_msec - now));
+                    select! {
+                        _ = delay.fuse() => (),
+                        _ = self.resync_notify.notified().fuse() => (),
+                        _ = must_exit.recv().fuse() => (),
+                    }
+                }
+            } else {
+                select! {
+                    _ = self.resync_notify.notified().fuse() => (),
+                    _ = must_exit.recv().fuse() => (),
+                }
+            }
+        }
+        Ok(())
+    }
+
+    async fn resync_iter(&self, hash: &Hash) -> Result<(), Error> {
+        let path = self.block_path(hash);
+
+        let exists = fs::metadata(&path).await.is_ok();
+        let needed = self
+            .rc
+            .get(hash.as_ref())?
+            .map(|x| u64_from_bytes(x.as_ref()) > 0)
+            .unwrap_or(false);
+
+        if exists != needed {
+            info!(
+                "Resync block {:?}: exists {}, needed {}",
+                hash, exists, needed
+            );
+        }
+
+        if exists && !needed {
+            let garage = self.garage.load_full().unwrap();
+            let active_refs = garage
+                .block_ref_table
+                .get_range(&hash, None, Some(()), 1)
+                .await?;
+            let needed_by_others = !active_refs.is_empty();
+            if needed_by_others {
+                let ring = self.system.ring.borrow().clone();
+                let who = self.replication.replication_nodes(&hash, &ring);
+                let msg = Arc::new(Message::NeedBlockQuery(*hash));
+                let who_needs_fut = who.iter().map(|to| {
+                    self.rpc_client
+                        .call_arc(*to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
+                });
+                let who_needs = join_all(who_needs_fut).await;
+
+                let mut need_nodes = vec![];
+                for (node, needed) in who.into_iter().zip(who_needs.iter()) {
+                    match needed {
+                        Ok(Message::NeedBlockReply(needed)) => {
+                            if *needed {
+                                need_nodes.push(node);
+                            }
+                        }
+                        Err(e) => {
+                            return Err(Error::Message(format!(
+                                "Should delete block, but unable to confirm that all other nodes that need it have it: {}",
+                                e
+                            )));
+                        }
+                        Ok(_) => {
+                            return Err(Error::Message(format!(
+                                "Unexpected response to NeedBlockQuery RPC"
+                            )));
+                        }
+                    }
+                }
+
+                if need_nodes.len() > 0 {
+                    let put_block_message = self.read_block(hash).await?;
+                    self.rpc_client
+                        .try_call_many(
+                            &need_nodes[..],
+                            put_block_message,
+                            RequestStrategy::with_quorum(need_nodes.len())
+                                .with_timeout(BLOCK_RW_TIMEOUT),
+                        )
+                        .await?;
+                }
+            }
+            fs::remove_file(path).await?;
+            self.resync_queue.remove(&hash)?;
+        }
+
+        if needed && !exists {
+            // TODO find a way to not do this if they are sending it to us
+            // Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay
+            // between the RC being incremented and this part being called.
+            let block_data = self.rpc_get_block(&hash).await?;
+            self.write_block(hash, &block_data[..]).await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn rpc_get_block(&self, hash: &Hash) -> Result<Vec<u8>, Error> {
+        let who = self.replication.read_nodes(&hash, &self.system);
+        let resps = self
+            .rpc_client
+            .try_call_many(
+                &who[..],
+                Message::GetBlock(*hash),
+                RequestStrategy::with_quorum(1)
+                    .with_timeout(BLOCK_RW_TIMEOUT)
+                    .interrupt_after_quorum(true),
+            )
+            .await?;
+
+        for resp in resps {
+            if let Message::PutBlock(msg) = resp {
+                return Ok(msg.data);
+            }
+        }
+        Err(Error::Message(format!(
+            "Unable to read block {:?}: no valid blocks returned",
+            hash
+        )))
+    }
+
+    pub async fn rpc_put_block(&self, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
+        let who = self.replication.write_nodes(&hash, &self.system);
+        self.rpc_client
+            .try_call_many(
+                &who[..],
+                Message::PutBlock(PutBlockMessage { hash, data }),
+                RequestStrategy::with_quorum(self.replication.write_quorum())
+                    .with_timeout(BLOCK_RW_TIMEOUT),
+            )
+            .await?;
+        Ok(())
+    }
+
+    pub async fn repair_data_store(&self, must_exit: &watch::Receiver<bool>) -> Result<(), Error> {
+        // 1.
Repair blocks from RC table + let garage = self.garage.load_full().unwrap(); + let mut last_hash = None; + let mut i = 0usize; + for entry in garage.block_ref_table.store.iter() { + let (_k, v_bytes) = entry?; + let block_ref = rmp_serde::decode::from_read_ref::<_, BlockRef>(v_bytes.as_ref())?; + if Some(&block_ref.block) == last_hash.as_ref() { + continue; + } + if !block_ref.deleted { + last_hash = Some(block_ref.block); + self.put_to_resync(&block_ref.block, 0)?; + } + i += 1; + if i & 0xFF == 0 && *must_exit.borrow() { + return Ok(()); + } + } + + // 2. Repair blocks actually on disk + let mut ls_data_dir = fs::read_dir(&self.data_dir).await?; + while let Some(data_dir_ent) = ls_data_dir.next().await { + let data_dir_ent = data_dir_ent?; + let dir_name = data_dir_ent.file_name(); + let dir_name = match dir_name.into_string() { + Ok(x) => x, + Err(_) => continue, + }; + if dir_name.len() != 2 || hex::decode(&dir_name).is_err() { + continue; + } + + let mut ls_data_dir_2 = match fs::read_dir(data_dir_ent.path()).await { + Err(e) => { + warn!( + "Warning: could not list dir {:?}: {}", + data_dir_ent.path().to_str(), + e + ); + continue; + } + Ok(x) => x, + }; + while let Some(file) = ls_data_dir_2.next().await { + let file = file?; + let file_name = file.file_name(); + let file_name = match file_name.into_string() { + Ok(x) => x, + Err(_) => continue, + }; + if file_name.len() != 64 { + continue; + } + let hash_bytes = match hex::decode(&file_name) { + Ok(h) => h, + Err(_) => continue, + }; + let mut hash = [0u8; 32]; + hash.copy_from_slice(&hash_bytes[..]); + self.put_to_resync(&hash.into(), 0)?; + + if *must_exit.borrow() { + return Ok(()); + } + } + } + Ok(()) + } +} + +fn u64_from_bytes(bytes: &[u8]) -> u64 { + assert!(bytes.len() == 8); + let mut x8 = [0u8; 8]; + x8.copy_from_slice(bytes); + u64::from_be_bytes(x8) +} + +fn rc_merge(_key: &[u8], old: Option<&[u8]>, new: &[u8]) -> Option> { + let old = old.map(u64_from_bytes).unwrap_or(0); + assert!(new.len() == 1); + let new = match new[0] { + 0 => { + if old > 0 { + old - 1 + } else { + 0 + } + } + 1 => old + 1, + _ => unreachable!(), + }; + if new == 0 { + None + } else { + Some(u64::to_be_bytes(new).to_vec()) + } +} diff --git a/src/store/block_ref_table.rs b/src/store/block_ref_table.rs new file mode 100644 index 00000000..c8a2a2a1 --- /dev/null +++ b/src/store/block_ref_table.rs @@ -0,0 +1,68 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::background::*; +use crate::data::*; +use crate::error::Error; + +use crate::table::*; + +use crate::store::block::*; + +#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] +pub struct BlockRef { + // Primary key + pub block: Hash, + + // Sort key + pub version: UUID, + + // Keep track of deleted status + pub deleted: bool, +} + +impl Entry for BlockRef { + fn partition_key(&self) -> &Hash { + &self.block + } + fn sort_key(&self) -> &UUID { + &self.version + } + + fn merge(&mut self, other: &Self) { + if other.deleted { + self.deleted = true; + } + } +} + +pub struct BlockRefTable { + pub background: Arc, + pub block_manager: Arc, +} + +#[async_trait] +impl TableSchema for BlockRefTable { + type P = Hash; + type S = UUID; + type E = BlockRef; + type Filter = (); + + async fn updated(&self, old: Option, new: Option) -> Result<(), Error> { + let block = &old.as_ref().or(new.as_ref()).unwrap().block; + let was_before = old.as_ref().map(|x| !x.deleted).unwrap_or(false); + let is_after = new.as_ref().map(|x| 
!x.deleted).unwrap_or(false);
+        if is_after && !was_before {
+            self.block_manager.block_incref(block)?;
+        }
+        if was_before && !is_after {
+            self.block_manager.block_decref(block)?;
+        }
+        Ok(())
+    }
+
+    fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool {
+        !entry.deleted
+    }
+}
diff --git a/src/store/bucket_table.rs b/src/store/bucket_table.rs
new file mode 100644
index 00000000..5604049c
--- /dev/null
+++ b/src/store/bucket_table.rs
@@ -0,0 +1,82 @@
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+use crate::error::Error;
+use crate::table::*;
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct Bucket {
+    // Primary key
+    pub name: String,
+
+    // Timestamp and deletion
+    // Upon version increment, all info is replaced
+    pub timestamp: u64,
+    pub deleted: bool,
+
+    // Authorized keys
+    pub authorized_keys: Vec<AllowedKey>,
+}
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct AllowedKey {
+    pub access_key_id: String,
+    pub timestamp: u64,
+    pub allowed_read: bool,
+    pub allowed_write: bool,
+}
+
+impl Entry<EmptyKey, String> for Bucket {
+    fn partition_key(&self) -> &EmptyKey {
+        &EmptyKey
+    }
+    fn sort_key(&self) -> &String {
+        &self.name
+    }
+
+    fn merge(&mut self, other: &Self) {
+        if other.timestamp > self.timestamp {
+            *self = other.clone();
+            return;
+        }
+        if self.timestamp > other.timestamp {
+            return;
+        }
+        for ak in other.authorized_keys.iter() {
+            match self
+                .authorized_keys
+                .binary_search_by(|our_ak| our_ak.access_key_id.cmp(&ak.access_key_id))
+            {
+                Ok(i) => {
+                    let our_ak = &mut self.authorized_keys[i];
+                    if ak.timestamp > our_ak.timestamp {
+                        our_ak.timestamp = ak.timestamp;
+                        our_ak.allowed_read = ak.allowed_read;
+                        our_ak.allowed_write = ak.allowed_write;
+                    }
+                }
+                Err(i) => {
+                    self.authorized_keys.insert(i, ak.clone());
+                }
+            }
+        }
+    }
+}
+
+pub struct BucketTable;
+
+#[async_trait]
+impl TableSchema for BucketTable {
+    type P = EmptyKey;
+    type S = String;
+    type E = Bucket;
+    type Filter = ();
+
+    async fn updated(&self, _old: Option<Self::E>, _new: Option<Self::E>) -> Result<(), Error> {
+        Ok(())
+    }
+
+    fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool {
+        !entry.deleted
+    }
+}
diff --git a/src/store/mod.rs b/src/store/mod.rs
new file mode 100644
index 00000000..afadc9bb
--- /dev/null
+++ b/src/store/mod.rs
@@ -0,0 +1,5 @@
+pub mod block;
+pub mod block_ref_table;
+pub mod bucket_table;
+pub mod object_table;
+pub mod version_table;
diff --git a/src/store/object_table.rs b/src/store/object_table.rs
new file mode 100644
index 00000000..97de0cdb
--- /dev/null
+++ b/src/store/object_table.rs
@@ -0,0 +1,134 @@
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+use crate::background::BackgroundRunner;
+use crate::data::*;
+use crate::error::Error;
+
+use crate::table::table_sharded::*;
+use crate::table::*;
+
+use crate::store::version_table::*;
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct Object {
+    // Primary key
+    pub bucket: String,
+
+    // Sort key
+    pub key: String,
+
+    // Data
+    pub versions: Vec<ObjectVersion>,
+}
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct ObjectVersion {
+    pub uuid: UUID,
+    pub timestamp: u64,
+
+    pub mime_type: String,
+    pub size: u64,
+    pub is_complete: bool,
+
+    pub data: ObjectVersionData,
+}
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub enum ObjectVersionData {
+    DeleteMarker,
+    Inline(#[serde(with = "serde_bytes")] Vec<u8>),
+    FirstBlock(Hash),
+}
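+
+// Versions are totally ordered by (timestamp, uuid); see cmp_key() below.
+// Object::merge() keeps `versions` sorted under that order, merges entries
+// that share a key, and prunes everything that precedes the most recent
+// complete version.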
+
+impl ObjectVersion {
+    fn cmp_key(&self) -> (u64, &UUID) {
+        (self.timestamp, &self.uuid)
+    }
+}
+
+impl Entry<String, String> for Object {
+    fn partition_key(&self) -> &String {
+        &self.bucket
+    }
+    fn sort_key(&self) -> &String {
+        &self.key
+    }
+
+    fn merge(&mut self, other: &Self) {
+        for other_v in other.versions.iter() {
+            match self
+                .versions
+                .binary_search_by(|v| v.cmp_key().cmp(&other_v.cmp_key()))
+            {
+                Ok(i) => {
+                    let v = &mut self.versions[i];
+                    if other_v.size > v.size {
+                        v.size = other_v.size;
+                    }
+                    if other_v.is_complete && !v.is_complete {
+                        v.is_complete = true;
+                    }
+                }
+                Err(i) => {
+                    self.versions.insert(i, other_v.clone());
+                }
+            }
+        }
+        let last_complete = self
+            .versions
+            .iter()
+            .enumerate()
+            .rev()
+            .filter(|(_, v)| v.is_complete)
+            .next()
+            .map(|(vi, _)| vi);
+
+        if let Some(last_vi) = last_complete {
+            self.versions = self.versions.drain(last_vi..).collect::<Vec<_>>();
+        }
+    }
+}
+
+pub struct ObjectTable {
+    pub background: Arc<BackgroundRunner>,
+    pub version_table: Arc<Table<VersionTable, TableShardedReplication>>,
+}
+
+#[async_trait]
+impl TableSchema for ObjectTable {
+    type P = String;
+    type S = String;
+    type E = Object;
+    type Filter = ();
+
+    async fn updated(&self, old: Option<Self::E>, new: Option<Self::E>) -> Result<(), Error> {
+        let version_table = self.version_table.clone();
+        if let (Some(old_v), Some(new_v)) = (old, new) {
+            // Propagate deletion of old versions
+            for v in old_v.versions.iter() {
+                if new_v
+                    .versions
+                    .binary_search_by(|nv| nv.cmp_key().cmp(&v.cmp_key()))
+                    .is_err()
+                {
+                    let deleted_version = Version {
+                        uuid: v.uuid,
+                        deleted: true,
+                        blocks: vec![],
+                        bucket: old_v.bucket.clone(),
+                        key: old_v.key.clone(),
+                    };
+                    version_table.insert(&deleted_version).await?;
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn matches_filter(_entry: &Self::E, _filter: &Self::Filter) -> bool {
+        // TODO
+        true
+    }
+}
diff --git a/src/store/version_table.rs b/src/store/version_table.rs
new file mode 100644
index 00000000..d25a56ca
--- /dev/null
+++ b/src/store/version_table.rs
@@ -0,0 +1,95 @@
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+use crate::background::BackgroundRunner;
+use crate::data::*;
+use crate::error::Error;
+
+use crate::table::table_sharded::*;
+use crate::table::*;
+
+use crate::store::block_ref_table::*;
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct Version {
+    // Primary key
+    pub uuid: UUID,
+
+    // Actual data: the blocks for this version
+    pub deleted: bool,
+    pub blocks: Vec<VersionBlock>,
+
+    // Back link to bucket+key so that we can figure out
+    // whether this was deleted later on
+    pub bucket: String,
+    pub key: String,
+}
+
+#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
+pub struct VersionBlock {
+    pub offset: u64,
+    pub hash: Hash,
+}
+
+impl Entry<Hash, EmptyKey> for Version {
+    fn partition_key(&self) -> &Hash {
+        &self.uuid
+    }
+    fn sort_key(&self) -> &EmptyKey {
+        &EmptyKey
+    }
+
+    fn merge(&mut self, other: &Self) {
+        if other.deleted {
+            self.deleted = true;
+            self.blocks.clear();
+        } else if !self.deleted {
+            for bi in other.blocks.iter() {
+                match self.blocks.binary_search_by(|x| x.offset.cmp(&bi.offset)) {
+                    Ok(_) => (),
+                    Err(pos) => {
+                        self.blocks.insert(pos, bi.clone());
+                    }
+                }
+            }
+        }
+    }
+}
+
+pub struct VersionTable {
+    pub background: Arc<BackgroundRunner>,
+    pub block_ref_table: Arc<Table<BlockRefTable, TableShardedReplication>>,
+}
+
+#[async_trait]
+impl TableSchema for VersionTable {
+    type P = Hash;
+    type S = EmptyKey;
+    type E = Version;
+    type Filter = ();
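+
+    // Deleting a version cascades: when `deleted` flips to true, `updated()`
+    // below emits tombstoned BlockRef entries for all of the version's blocks,
+    // which in turn drives `block_decref` in the block manager (see
+    // block_ref_table.rs above).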
+
+    async fn updated(&self, old: Option<Self::E>, new: Option<Self::E>) -> Result<(), Error> {
+        let block_ref_table = self.block_ref_table.clone();
+        if let (Some(old_v), Some(new_v)) = (old, new) {
+            // Propagate deletion of version blocks
+            if new_v.deleted && !old_v.deleted {
+                let deleted_block_refs = old_v
+                    .blocks
+                    .iter()
+                    .map(|vb| BlockRef {
+                        block: vb.hash,
+                        version: old_v.uuid,
+                        deleted: true,
+                    })
+                    .collect::<Vec<_>>();
+                block_ref_table.insert_many(&deleted_block_refs[..]).await?;
+            }
+        }
+        Ok(())
+    }
+
+    fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool {
+        !entry.deleted
+    }
+}
diff --git a/src/table.rs b/src/table.rs
deleted file mode 100644
index a3d02d0c..00000000
--- a/src/table.rs
+++ /dev/null
@@ -1,522 +0,0 @@
-use std::collections::{BTreeMap, HashMap};
-use std::sync::Arc;
-use std::time::Duration;
-
-use arc_swap::ArcSwapOption;
-use async_trait::async_trait;
-use futures::stream::*;
-use serde::{Deserialize, Serialize};
-use serde_bytes::ByteBuf;
-
-use crate::data::*;
-use crate::error::Error;
-use crate::membership::{Ring, System};
-use crate::rpc_client::*;
-use crate::rpc_server::*;
-use crate::table_sync::*;
-
-const TABLE_RPC_TIMEOUT: Duration = Duration::from_secs(10);
-
-pub struct Table<F: TableSchema, R: TableReplication> {
-    pub instance: F,
-    pub replication: R,
-
-    pub name: String,
-    pub rpc_client: Arc<RpcClient<TableRPC<F>>>,
-
-    pub system: Arc<System>,
-    pub store: sled::Tree,
-    pub syncer: ArcSwapOption<TableSyncer<F, R>>,
-}
-
-#[derive(Serialize, Deserialize)]
-pub enum TableRPC<F: TableSchema> {
-    Ok,
-
-    ReadEntry(F::P, F::S),
-    ReadEntryResponse(Option<ByteBuf>),
-
-    // Read range: read all keys in partition P, possibly starting at a certain sort key offset
-    ReadRange(F::P, Option<F::S>, Option<F::Filter>, usize),
-
-    Update(Vec<Arc<ByteBuf>>),
-
-    SyncRPC(SyncRPC),
-}
-
-impl<F: TableSchema> RpcMessage for TableRPC<F> {}
-
-pub trait PartitionKey {
-    fn hash(&self) -> Hash;
-}
-
-pub trait SortKey {
-    fn sort_key(&self) -> &[u8];
-}
-
-pub trait Entry<P: PartitionKey, S: SortKey>:
-    PartialEq + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync
-{
-    fn partition_key(&self) -> &P;
-    fn sort_key(&self) -> &S;
-
-    fn merge(&mut self, other: &Self);
-}
-
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
-pub struct EmptyKey;
-impl SortKey for EmptyKey {
-    fn sort_key(&self) -> &[u8] {
-        &[]
-    }
-}
-impl PartitionKey for EmptyKey {
-    fn hash(&self) -> Hash {
-        [0u8; 32].into()
-    }
-}
-
-impl<T: AsRef<str>> PartitionKey for T {
-    fn hash(&self) -> Hash {
-        hash(self.as_ref().as_bytes())
-    }
-}
-impl<T: AsRef<str>> SortKey for T {
-    fn sort_key(&self) -> &[u8] {
-        self.as_ref().as_bytes()
-    }
-}
-
-impl PartitionKey for Hash {
-    fn hash(&self) -> Hash {
-        self.clone()
-    }
-}
-impl SortKey for Hash {
-    fn sort_key(&self) -> &[u8] {
-        self.as_slice()
-    }
-}
-
-#[async_trait]
-pub trait TableSchema: Send + Sync {
-    type P: PartitionKey + Clone + PartialEq + Serialize + for<'de> Deserialize<'de> + Send + Sync;
-    type S: SortKey + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync;
-    type E: Entry<Self::P, Self::S>;
-    type Filter: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync;
-
-    async fn updated(&self, old: Option<Self::E>, new: Option<Self::E>) -> Result<(), Error>;
-    fn matches_filter(_entry: &Self::E, _filter: &Self::Filter) -> bool {
-        true
-    }
-}
-
-pub trait TableReplication: Send + Sync {
-    // See examples in table_sharded.rs and table_fullcopy.rs
-    // To understand various replication methods
-
-    // Which nodes to send reads from
-    fn read_nodes(&self, hash: &Hash, system: &System) -> Vec<UUID>;
-    fn read_quorum(&self) -> usize;
-
-    // Which nodes to send writes to
-    fn write_nodes(&self, hash: &Hash, system: &System) -> Vec<UUID>;
-    fn write_quorum(&self) -> usize;
-    fn max_write_errors(&self) -> usize;
-    fn epidemic_writes(&self) -> bool;
-
-    // Which are the nodes that do actually replicate the data
-    fn replication_nodes(&self,
hash: &Hash, ring: &Ring) -> Vec; - fn split_points(&self, ring: &Ring) -> Vec; -} - -impl Table -where - F: TableSchema + 'static, - R: TableReplication + 'static, -{ - // =============== PUBLIC INTERFACE FUNCTIONS (new, insert, get, etc) =============== - - pub async fn new( - instance: F, - replication: R, - system: Arc, - db: &sled::Db, - name: String, - rpc_server: &mut RpcServer, - ) -> Arc { - let store = db.open_tree(&name).expect("Unable to open DB tree"); - - let rpc_path = format!("table_{}", name); - let rpc_client = system.rpc_client::>(&rpc_path); - - let table = Arc::new(Self { - instance, - replication, - name, - rpc_client, - system, - store, - syncer: ArcSwapOption::from(None), - }); - table.clone().register_handler(rpc_server, rpc_path); - - let syncer = TableSyncer::launch(table.clone()).await; - table.syncer.swap(Some(syncer)); - - table - } - - pub async fn insert(&self, e: &F::E) -> Result<(), Error> { - let hash = e.partition_key().hash(); - let who = self.replication.write_nodes(&hash, &self.system); - //eprintln!("insert who: {:?}", who); - - let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?)); - let rpc = TableRPC::::Update(vec![e_enc]); - - self.rpc_client - .try_call_many( - &who[..], - rpc, - RequestStrategy::with_quorum(self.replication.write_quorum()) - .with_timeout(TABLE_RPC_TIMEOUT), - ) - .await?; - Ok(()) - } - - pub async fn insert_many(&self, entries: &[F::E]) -> Result<(), Error> { - let mut call_list = HashMap::new(); - - for entry in entries.iter() { - let hash = entry.partition_key().hash(); - let who = self.replication.write_nodes(&hash, &self.system); - let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(entry)?)); - for node in who { - if !call_list.contains_key(&node) { - call_list.insert(node, vec![]); - } - call_list.get_mut(&node).unwrap().push(e_enc.clone()); - } - } - - let call_futures = call_list.drain().map(|(node, entries)| async move { - let rpc = TableRPC::::Update(entries); - - let resp = self.rpc_client.call(node, rpc, TABLE_RPC_TIMEOUT).await?; - Ok::<_, Error>((node, resp)) - }); - let mut resps = call_futures.collect::>(); - let mut errors = vec![]; - - while let Some(resp) = resps.next().await { - if let Err(e) = resp { - errors.push(e); - } - } - if errors.len() > self.replication.max_write_errors() { - Err(Error::Message("Too many errors".into())) - } else { - Ok(()) - } - } - - pub async fn get( - self: &Arc, - partition_key: &F::P, - sort_key: &F::S, - ) -> Result, Error> { - let hash = partition_key.hash(); - let who = self.replication.read_nodes(&hash, &self.system); - //eprintln!("get who: {:?}", who); - - let rpc = TableRPC::::ReadEntry(partition_key.clone(), sort_key.clone()); - let resps = self - .rpc_client - .try_call_many( - &who[..], - rpc, - RequestStrategy::with_quorum(self.replication.read_quorum()) - .with_timeout(TABLE_RPC_TIMEOUT) - .interrupt_after_quorum(true), - ) - .await?; - - let mut ret = None; - let mut not_all_same = false; - for resp in resps { - if let TableRPC::ReadEntryResponse(value) = resp { - if let Some(v_bytes) = value { - let v = rmp_serde::decode::from_read_ref::<_, F::E>(v_bytes.as_slice())?; - ret = match ret { - None => Some(v), - Some(mut x) => { - if x != v { - not_all_same = true; - x.merge(&v); - } - Some(x) - } - } - } - } else { - return Err(Error::Message(format!("Invalid return value to read"))); - } - } - if let Some(ret_entry) = &ret { - if not_all_same { - let self2 = self.clone(); - let ent2 = ret_entry.clone(); - self.system - .background - 
.spawn_cancellable(async move { self2.repair_on_read(&who[..], ent2).await }); - } - } - Ok(ret) - } - - pub async fn get_range( - self: &Arc, - partition_key: &F::P, - begin_sort_key: Option, - filter: Option, - limit: usize, - ) -> Result, Error> { - let hash = partition_key.hash(); - let who = self.replication.read_nodes(&hash, &self.system); - - let rpc = TableRPC::::ReadRange(partition_key.clone(), begin_sort_key, filter, limit); - - let resps = self - .rpc_client - .try_call_many( - &who[..], - rpc, - RequestStrategy::with_quorum(self.replication.read_quorum()) - .with_timeout(TABLE_RPC_TIMEOUT) - .interrupt_after_quorum(true), - ) - .await?; - - let mut ret = BTreeMap::new(); - let mut to_repair = BTreeMap::new(); - for resp in resps { - if let TableRPC::Update(entries) = resp { - for entry_bytes in entries.iter() { - let entry = - rmp_serde::decode::from_read_ref::<_, F::E>(entry_bytes.as_slice())?; - let entry_key = self.tree_key(entry.partition_key(), entry.sort_key()); - match ret.remove(&entry_key) { - None => { - ret.insert(entry_key, Some(entry)); - } - Some(Some(mut prev)) => { - let must_repair = prev != entry; - prev.merge(&entry); - if must_repair { - to_repair.insert(entry_key.clone(), Some(prev.clone())); - } - ret.insert(entry_key, Some(prev)); - } - Some(None) => unreachable!(), - } - } - } - } - if !to_repair.is_empty() { - let self2 = self.clone(); - self.system.background.spawn_cancellable(async move { - for (_, v) in to_repair.iter_mut() { - self2.repair_on_read(&who[..], v.take().unwrap()).await?; - } - Ok(()) - }); - } - let ret_vec = ret - .iter_mut() - .take(limit) - .map(|(_k, v)| v.take().unwrap()) - .collect::>(); - Ok(ret_vec) - } - - // =============== UTILITY FUNCTION FOR CLIENT OPERATIONS =============== - - async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> { - let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?)); - self.rpc_client - .try_call_many( - &who[..], - TableRPC::::Update(vec![what_enc]), - RequestStrategy::with_quorum(who.len()).with_timeout(TABLE_RPC_TIMEOUT), - ) - .await?; - Ok(()) - } - - // =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ============== - - fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { - let self2 = self.clone(); - rpc_server.add_handler::, _, _>(path, move |msg, _addr| { - let self2 = self2.clone(); - async move { self2.handle(&msg).await } - }); - - let self2 = self.clone(); - self.rpc_client - .set_local_handler(self.system.id, move |msg| { - let self2 = self2.clone(); - async move { self2.handle(&msg).await } - }); - } - - async fn handle(self: &Arc, msg: &TableRPC) -> Result, Error> { - match msg { - TableRPC::ReadEntry(key, sort_key) => { - let value = self.handle_read_entry(key, sort_key)?; - Ok(TableRPC::ReadEntryResponse(value)) - } - TableRPC::ReadRange(key, begin_sort_key, filter, limit) => { - let values = self.handle_read_range(key, begin_sort_key, filter, *limit)?; - Ok(TableRPC::Update(values)) - } - TableRPC::Update(pairs) => { - self.handle_update(pairs).await?; - Ok(TableRPC::Ok) - } - TableRPC::SyncRPC(rpc) => { - let syncer = self.syncer.load_full().unwrap(); - let response = syncer - .handle_rpc(rpc, self.system.background.stop_signal.clone()) - .await?; - Ok(TableRPC::SyncRPC(response)) - } - _ => Err(Error::BadRequest(format!("Unexpected table RPC"))), - } - } - - fn handle_read_entry(&self, p: &F::P, s: &F::S) -> Result, Error> { - let tree_key = self.tree_key(p, s); - if let Some(bytes) = self.store.get(&tree_key)? 
{ - Ok(Some(ByteBuf::from(bytes.to_vec()))) - } else { - Ok(None) - } - } - - fn handle_read_range( - &self, - p: &F::P, - s: &Option, - filter: &Option, - limit: usize, - ) -> Result>, Error> { - let partition_hash = p.hash(); - let first_key = match s { - None => partition_hash.to_vec(), - Some(sk) => self.tree_key(p, sk), - }; - let mut ret = vec![]; - for item in self.store.range(first_key..) { - let (key, value) = item?; - if &key[..32] != partition_hash.as_slice() { - break; - } - let keep = match filter { - None => true, - Some(f) => { - let entry = rmp_serde::decode::from_read_ref::<_, F::E>(value.as_ref())?; - F::matches_filter(&entry, f) - } - }; - if keep { - ret.push(Arc::new(ByteBuf::from(value.as_ref()))); - } - if ret.len() >= limit { - break; - } - } - Ok(ret) - } - - pub async fn handle_update(self: &Arc, entries: &[Arc]) -> Result<(), Error> { - let syncer = self.syncer.load_full().unwrap(); - let mut epidemic_propagate = vec![]; - - for update_bytes in entries.iter() { - let update = rmp_serde::decode::from_read_ref::<_, F::E>(update_bytes.as_slice())?; - - let tree_key = self.tree_key(update.partition_key(), update.sort_key()); - - let (old_entry, new_entry) = self.store.transaction(|db| { - let (old_entry, new_entry) = match db.get(&tree_key)? { - Some(prev_bytes) => { - let old_entry = rmp_serde::decode::from_read_ref::<_, F::E>(&prev_bytes) - .map_err(Error::RMPDecode) - .map_err(sled::ConflictableTransactionError::Abort)?; - let mut new_entry = old_entry.clone(); - new_entry.merge(&update); - (Some(old_entry), new_entry) - } - None => (None, update.clone()), - }; - - let new_bytes = rmp_to_vec_all_named(&new_entry) - .map_err(Error::RMPEncode) - .map_err(sled::ConflictableTransactionError::Abort)?; - db.insert(tree_key.clone(), new_bytes)?; - Ok((old_entry, new_entry)) - })?; - - if old_entry.as_ref() != Some(&new_entry) { - if self.replication.epidemic_writes() { - epidemic_propagate.push(new_entry.clone()); - } - - self.instance.updated(old_entry, Some(new_entry)).await?; - self.system - .background - .spawn_cancellable(syncer.clone().invalidate(tree_key)); - } - } - - if epidemic_propagate.len() > 0 { - let self2 = self.clone(); - self.system - .background - .spawn_cancellable(async move { self2.insert_many(&epidemic_propagate[..]).await }); - } - - Ok(()) - } - - pub async fn delete_range(&self, begin: &Hash, end: &Hash) -> Result<(), Error> { - let syncer = self.syncer.load_full().unwrap(); - - debug!("({}) Deleting range {:?} - {:?}", self.name, begin, end); - let mut count = 0; - while let Some((key, _value)) = self.store.get_lt(end.as_slice())? { - if key.as_ref() < begin.as_slice() { - break; - } - if let Some(old_val) = self.store.remove(&key)? 
{ - let old_entry = rmp_serde::decode::from_read_ref::<_, F::E>(&old_val)?; - self.instance.updated(Some(old_entry), None).await?; - self.system - .background - .spawn_cancellable(syncer.clone().invalidate(key.to_vec())); - count += 1; - } - } - debug!("({}) {} entries deleted", self.name, count); - Ok(()) - } - - fn tree_key(&self, p: &F::P, s: &F::S) -> Vec { - let mut ret = p.hash().to_vec(); - ret.extend(s.sort_key()); - ret - } -} diff --git a/src/table/mod.rs b/src/table/mod.rs new file mode 100644 index 00000000..e03b8d0b --- /dev/null +++ b/src/table/mod.rs @@ -0,0 +1,6 @@ +pub mod table; +pub mod table_fullcopy; +pub mod table_sharded; +pub mod table_sync; + +pub use table::*; diff --git a/src/table/table.rs b/src/table/table.rs new file mode 100644 index 00000000..50e8739a --- /dev/null +++ b/src/table/table.rs @@ -0,0 +1,524 @@ +use std::collections::{BTreeMap, HashMap}; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use futures::stream::*; +use serde::{Deserialize, Serialize}; +use serde_bytes::ByteBuf; + +use crate::data::*; +use crate::error::Error; + +use crate::rpc::membership::{Ring, System}; +use crate::rpc::rpc_client::*; +use crate::rpc::rpc_server::*; + +use crate::table::table_sync::*; + +const TABLE_RPC_TIMEOUT: Duration = Duration::from_secs(10); + +pub struct Table { + pub instance: F, + pub replication: R, + + pub name: String, + pub rpc_client: Arc>>, + + pub system: Arc, + pub store: sled::Tree, + pub syncer: ArcSwapOption>, +} + +#[derive(Serialize, Deserialize)] +pub enum TableRPC { + Ok, + + ReadEntry(F::P, F::S), + ReadEntryResponse(Option), + + // Read range: read all keys in partition P, possibly starting at a certain sort key offset + ReadRange(F::P, Option, Option, usize), + + Update(Vec>), + + SyncRPC(SyncRPC), +} + +impl RpcMessage for TableRPC {} + +pub trait PartitionKey { + fn hash(&self) -> Hash; +} + +pub trait SortKey { + fn sort_key(&self) -> &[u8]; +} + +pub trait Entry: + PartialEq + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync +{ + fn partition_key(&self) -> &P; + fn sort_key(&self) -> &S; + + fn merge(&mut self, other: &Self); +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct EmptyKey; +impl SortKey for EmptyKey { + fn sort_key(&self) -> &[u8] { + &[] + } +} +impl PartitionKey for EmptyKey { + fn hash(&self) -> Hash { + [0u8; 32].into() + } +} + +impl> PartitionKey for T { + fn hash(&self) -> Hash { + hash(self.as_ref().as_bytes()) + } +} +impl> SortKey for T { + fn sort_key(&self) -> &[u8] { + self.as_ref().as_bytes() + } +} + +impl PartitionKey for Hash { + fn hash(&self) -> Hash { + self.clone() + } +} +impl SortKey for Hash { + fn sort_key(&self) -> &[u8] { + self.as_slice() + } +} + +#[async_trait] +pub trait TableSchema: Send + Sync { + type P: PartitionKey + Clone + PartialEq + Serialize + for<'de> Deserialize<'de> + Send + Sync; + type S: SortKey + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync; + type E: Entry; + type Filter: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync; + + async fn updated(&self, old: Option, new: Option) -> Result<(), Error>; + fn matches_filter(_entry: &Self::E, _filter: &Self::Filter) -> bool { + true + } +} + +pub trait TableReplication: Send + Sync { + // See examples in table_sharded.rs and table_fullcopy.rs + // To understand various replication methods + + // Which nodes to send reads from + fn read_nodes(&self, hash: &Hash, system: &System) -> Vec; + fn read_quorum(&self) -> 
usize; + + // Which nodes to send writes to + fn write_nodes(&self, hash: &Hash, system: &System) -> Vec; + fn write_quorum(&self) -> usize; + fn max_write_errors(&self) -> usize; + fn epidemic_writes(&self) -> bool; + + // Which are the nodes that do actually replicate the data + fn replication_nodes(&self, hash: &Hash, ring: &Ring) -> Vec; + fn split_points(&self, ring: &Ring) -> Vec; +} + +impl Table +where + F: TableSchema + 'static, + R: TableReplication + 'static, +{ + // =============== PUBLIC INTERFACE FUNCTIONS (new, insert, get, etc) =============== + + pub async fn new( + instance: F, + replication: R, + system: Arc, + db: &sled::Db, + name: String, + rpc_server: &mut RpcServer, + ) -> Arc { + let store = db.open_tree(&name).expect("Unable to open DB tree"); + + let rpc_path = format!("table_{}", name); + let rpc_client = system.rpc_client::>(&rpc_path); + + let table = Arc::new(Self { + instance, + replication, + name, + rpc_client, + system, + store, + syncer: ArcSwapOption::from(None), + }); + table.clone().register_handler(rpc_server, rpc_path); + + let syncer = TableSyncer::launch(table.clone()).await; + table.syncer.swap(Some(syncer)); + + table + } + + pub async fn insert(&self, e: &F::E) -> Result<(), Error> { + let hash = e.partition_key().hash(); + let who = self.replication.write_nodes(&hash, &self.system); + //eprintln!("insert who: {:?}", who); + + let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?)); + let rpc = TableRPC::::Update(vec![e_enc]); + + self.rpc_client + .try_call_many( + &who[..], + rpc, + RequestStrategy::with_quorum(self.replication.write_quorum()) + .with_timeout(TABLE_RPC_TIMEOUT), + ) + .await?; + Ok(()) + } + + pub async fn insert_many(&self, entries: &[F::E]) -> Result<(), Error> { + let mut call_list = HashMap::new(); + + for entry in entries.iter() { + let hash = entry.partition_key().hash(); + let who = self.replication.write_nodes(&hash, &self.system); + let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(entry)?)); + for node in who { + if !call_list.contains_key(&node) { + call_list.insert(node, vec![]); + } + call_list.get_mut(&node).unwrap().push(e_enc.clone()); + } + } + + let call_futures = call_list.drain().map(|(node, entries)| async move { + let rpc = TableRPC::::Update(entries); + + let resp = self.rpc_client.call(node, rpc, TABLE_RPC_TIMEOUT).await?; + Ok::<_, Error>((node, resp)) + }); + let mut resps = call_futures.collect::>(); + let mut errors = vec![]; + + while let Some(resp) = resps.next().await { + if let Err(e) = resp { + errors.push(e); + } + } + if errors.len() > self.replication.max_write_errors() { + Err(Error::Message("Too many errors".into())) + } else { + Ok(()) + } + } + + pub async fn get( + self: &Arc, + partition_key: &F::P, + sort_key: &F::S, + ) -> Result, Error> { + let hash = partition_key.hash(); + let who = self.replication.read_nodes(&hash, &self.system); + //eprintln!("get who: {:?}", who); + + let rpc = TableRPC::::ReadEntry(partition_key.clone(), sort_key.clone()); + let resps = self + .rpc_client + .try_call_many( + &who[..], + rpc, + RequestStrategy::with_quorum(self.replication.read_quorum()) + .with_timeout(TABLE_RPC_TIMEOUT) + .interrupt_after_quorum(true), + ) + .await?; + + let mut ret = None; + let mut not_all_same = false; + for resp in resps { + if let TableRPC::ReadEntryResponse(value) = resp { + if let Some(v_bytes) = value { + let v = rmp_serde::decode::from_read_ref::<_, F::E>(v_bytes.as_slice())?; + ret = match ret { + None => Some(v), + Some(mut x) => { + if x != v { + 
not_all_same = true; + x.merge(&v); + } + Some(x) + } + } + } + } else { + return Err(Error::Message(format!("Invalid return value to read"))); + } + } + if let Some(ret_entry) = &ret { + if not_all_same { + let self2 = self.clone(); + let ent2 = ret_entry.clone(); + self.system + .background + .spawn_cancellable(async move { self2.repair_on_read(&who[..], ent2).await }); + } + } + Ok(ret) + } + + pub async fn get_range( + self: &Arc, + partition_key: &F::P, + begin_sort_key: Option, + filter: Option, + limit: usize, + ) -> Result, Error> { + let hash = partition_key.hash(); + let who = self.replication.read_nodes(&hash, &self.system); + + let rpc = TableRPC::::ReadRange(partition_key.clone(), begin_sort_key, filter, limit); + + let resps = self + .rpc_client + .try_call_many( + &who[..], + rpc, + RequestStrategy::with_quorum(self.replication.read_quorum()) + .with_timeout(TABLE_RPC_TIMEOUT) + .interrupt_after_quorum(true), + ) + .await?; + + let mut ret = BTreeMap::new(); + let mut to_repair = BTreeMap::new(); + for resp in resps { + if let TableRPC::Update(entries) = resp { + for entry_bytes in entries.iter() { + let entry = + rmp_serde::decode::from_read_ref::<_, F::E>(entry_bytes.as_slice())?; + let entry_key = self.tree_key(entry.partition_key(), entry.sort_key()); + match ret.remove(&entry_key) { + None => { + ret.insert(entry_key, Some(entry)); + } + Some(Some(mut prev)) => { + let must_repair = prev != entry; + prev.merge(&entry); + if must_repair { + to_repair.insert(entry_key.clone(), Some(prev.clone())); + } + ret.insert(entry_key, Some(prev)); + } + Some(None) => unreachable!(), + } + } + } + } + if !to_repair.is_empty() { + let self2 = self.clone(); + self.system.background.spawn_cancellable(async move { + for (_, v) in to_repair.iter_mut() { + self2.repair_on_read(&who[..], v.take().unwrap()).await?; + } + Ok(()) + }); + } + let ret_vec = ret + .iter_mut() + .take(limit) + .map(|(_k, v)| v.take().unwrap()) + .collect::>(); + Ok(ret_vec) + } + + // =============== UTILITY FUNCTION FOR CLIENT OPERATIONS =============== + + async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> { + let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?)); + self.rpc_client + .try_call_many( + &who[..], + TableRPC::::Update(vec![what_enc]), + RequestStrategy::with_quorum(who.len()).with_timeout(TABLE_RPC_TIMEOUT), + ) + .await?; + Ok(()) + } + + // =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ============== + + fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { + let self2 = self.clone(); + rpc_server.add_handler::, _, _>(path, move |msg, _addr| { + let self2 = self2.clone(); + async move { self2.handle(&msg).await } + }); + + let self2 = self.clone(); + self.rpc_client + .set_local_handler(self.system.id, move |msg| { + let self2 = self2.clone(); + async move { self2.handle(&msg).await } + }); + } + + async fn handle(self: &Arc, msg: &TableRPC) -> Result, Error> { + match msg { + TableRPC::ReadEntry(key, sort_key) => { + let value = self.handle_read_entry(key, sort_key)?; + Ok(TableRPC::ReadEntryResponse(value)) + } + TableRPC::ReadRange(key, begin_sort_key, filter, limit) => { + let values = self.handle_read_range(key, begin_sort_key, filter, *limit)?; + Ok(TableRPC::Update(values)) + } + TableRPC::Update(pairs) => { + self.handle_update(pairs).await?; + Ok(TableRPC::Ok) + } + TableRPC::SyncRPC(rpc) => { + let syncer = self.syncer.load_full().unwrap(); + let response = syncer + .handle_rpc(rpc, 
self.system.background.stop_signal.clone()) + .await?; + Ok(TableRPC::SyncRPC(response)) + } + _ => Err(Error::BadRequest(format!("Unexpected table RPC"))), + } + } + + fn handle_read_entry(&self, p: &F::P, s: &F::S) -> Result, Error> { + let tree_key = self.tree_key(p, s); + if let Some(bytes) = self.store.get(&tree_key)? { + Ok(Some(ByteBuf::from(bytes.to_vec()))) + } else { + Ok(None) + } + } + + fn handle_read_range( + &self, + p: &F::P, + s: &Option, + filter: &Option, + limit: usize, + ) -> Result>, Error> { + let partition_hash = p.hash(); + let first_key = match s { + None => partition_hash.to_vec(), + Some(sk) => self.tree_key(p, sk), + }; + let mut ret = vec![]; + for item in self.store.range(first_key..) { + let (key, value) = item?; + if &key[..32] != partition_hash.as_slice() { + break; + } + let keep = match filter { + None => true, + Some(f) => { + let entry = rmp_serde::decode::from_read_ref::<_, F::E>(value.as_ref())?; + F::matches_filter(&entry, f) + } + }; + if keep { + ret.push(Arc::new(ByteBuf::from(value.as_ref()))); + } + if ret.len() >= limit { + break; + } + } + Ok(ret) + } + + pub async fn handle_update(self: &Arc, entries: &[Arc]) -> Result<(), Error> { + let syncer = self.syncer.load_full().unwrap(); + let mut epidemic_propagate = vec![]; + + for update_bytes in entries.iter() { + let update = rmp_serde::decode::from_read_ref::<_, F::E>(update_bytes.as_slice())?; + + let tree_key = self.tree_key(update.partition_key(), update.sort_key()); + + let (old_entry, new_entry) = self.store.transaction(|db| { + let (old_entry, new_entry) = match db.get(&tree_key)? { + Some(prev_bytes) => { + let old_entry = rmp_serde::decode::from_read_ref::<_, F::E>(&prev_bytes) + .map_err(Error::RMPDecode) + .map_err(sled::ConflictableTransactionError::Abort)?; + let mut new_entry = old_entry.clone(); + new_entry.merge(&update); + (Some(old_entry), new_entry) + } + None => (None, update.clone()), + }; + + let new_bytes = rmp_to_vec_all_named(&new_entry) + .map_err(Error::RMPEncode) + .map_err(sled::ConflictableTransactionError::Abort)?; + db.insert(tree_key.clone(), new_bytes)?; + Ok((old_entry, new_entry)) + })?; + + if old_entry.as_ref() != Some(&new_entry) { + if self.replication.epidemic_writes() { + epidemic_propagate.push(new_entry.clone()); + } + + self.instance.updated(old_entry, Some(new_entry)).await?; + self.system + .background + .spawn_cancellable(syncer.clone().invalidate(tree_key)); + } + } + + if epidemic_propagate.len() > 0 { + let self2 = self.clone(); + self.system + .background + .spawn_cancellable(async move { self2.insert_many(&epidemic_propagate[..]).await }); + } + + Ok(()) + } + + pub async fn delete_range(&self, begin: &Hash, end: &Hash) -> Result<(), Error> { + let syncer = self.syncer.load_full().unwrap(); + + debug!("({}) Deleting range {:?} - {:?}", self.name, begin, end); + let mut count = 0; + while let Some((key, _value)) = self.store.get_lt(end.as_slice())? { + if key.as_ref() < begin.as_slice() { + break; + } + if let Some(old_val) = self.store.remove(&key)? 
{ + let old_entry = rmp_serde::decode::from_read_ref::<_, F::E>(&old_val)?; + self.instance.updated(Some(old_entry), None).await?; + self.system + .background + .spawn_cancellable(syncer.clone().invalidate(key.to_vec())); + count += 1; + } + } + debug!("({}) {} entries deleted", self.name, count); + Ok(()) + } + + fn tree_key(&self, p: &F::P, s: &F::S) -> Vec { + let mut ret = p.hash().to_vec(); + ret.extend(s.sort_key()); + ret + } +} diff --git a/src/table/table_fullcopy.rs b/src/table/table_fullcopy.rs new file mode 100644 index 00000000..2cd2e464 --- /dev/null +++ b/src/table/table_fullcopy.rs @@ -0,0 +1,100 @@ +use arc_swap::ArcSwapOption; +use std::sync::Arc; + +use crate::data::*; +use crate::rpc::membership::{Ring, System}; +use crate::table::*; + +#[derive(Clone)] +pub struct TableFullReplication { + pub write_factor: usize, + pub write_quorum: usize, + + neighbors: ArcSwapOption, +} + +#[derive(Clone)] +struct Neighbors { + ring: Arc, + neighbors: Vec, +} + +impl TableFullReplication { + pub fn new(write_factor: usize, write_quorum: usize) -> Self { + TableFullReplication { + write_factor, + write_quorum, + neighbors: ArcSwapOption::from(None), + } + } + + fn get_neighbors(&self, system: &System) -> Vec { + let neighbors = self.neighbors.load_full(); + if let Some(n) = neighbors { + if Arc::ptr_eq(&n.ring, &system.ring.borrow()) { + return n.neighbors.clone(); + } + } + + // Recalculate neighbors + let ring = system.ring.borrow().clone(); + let my_id = system.id; + + let mut nodes = vec![]; + for (node, _) in ring.config.members.iter() { + let node_ranking = hash(&[node.as_slice(), my_id.as_slice()].concat()); + nodes.push((*node, node_ranking)); + } + nodes.sort_by(|(_, rank1), (_, rank2)| rank1.cmp(rank2)); + let mut neighbors = nodes + .drain(..) 
+ .map(|(node, _)| node) + .filter(|node| *node != my_id) + .take(self.write_factor) + .collect::>(); + neighbors.push(my_id); + self.neighbors.swap(Some(Arc::new(Neighbors { + ring, + neighbors: neighbors.clone(), + }))); + neighbors + } +} + +impl TableReplication for TableFullReplication { + // Full replication schema: all nodes store everything + // Writes are disseminated in an epidemic manner in the network + + // Advantage: do all reads locally, extremely fast + // Inconvenient: only suitable to reasonably small tables + + fn read_nodes(&self, _hash: &Hash, system: &System) -> Vec { + vec![system.id] + } + fn read_quorum(&self) -> usize { + 1 + } + + fn write_nodes(&self, _hash: &Hash, system: &System) -> Vec { + self.get_neighbors(system) + } + fn write_quorum(&self) -> usize { + self.write_quorum + } + fn max_write_errors(&self) -> usize { + self.write_factor - self.write_quorum + } + fn epidemic_writes(&self) -> bool { + true + } + + fn replication_nodes(&self, _hash: &Hash, ring: &Ring) -> Vec { + ring.config.members.keys().cloned().collect::>() + } + fn split_points(&self, _ring: &Ring) -> Vec { + let mut ret = vec![]; + ret.push([0u8; 32].into()); + ret.push([0xFFu8; 32].into()); + ret + } +} diff --git a/src/table/table_sharded.rs b/src/table/table_sharded.rs new file mode 100644 index 00000000..5190f5d4 --- /dev/null +++ b/src/table/table_sharded.rs @@ -0,0 +1,55 @@ +use crate::data::*; +use crate::rpc::membership::{Ring, System}; +use crate::table::*; + +#[derive(Clone)] +pub struct TableShardedReplication { + pub replication_factor: usize, + pub read_quorum: usize, + pub write_quorum: usize, +} + +impl TableReplication for TableShardedReplication { + // Sharded replication schema: + // - based on the ring of nodes, a certain set of neighbors + // store entries, given as a function of the position of the + // entry's hash in the ring + // - reads are done on all of the nodes that replicate the data + // - writes as well + + fn read_nodes(&self, hash: &Hash, system: &System) -> Vec { + let ring = system.ring.borrow().clone(); + ring.walk_ring(&hash, self.replication_factor) + } + fn read_quorum(&self) -> usize { + self.read_quorum + } + + fn write_nodes(&self, hash: &Hash, system: &System) -> Vec { + let ring = system.ring.borrow().clone(); + ring.walk_ring(&hash, self.replication_factor) + } + fn write_quorum(&self) -> usize { + self.write_quorum + } + fn max_write_errors(&self) -> usize { + self.replication_factor - self.write_quorum + } + fn epidemic_writes(&self) -> bool { + false + } + + fn replication_nodes(&self, hash: &Hash, ring: &Ring) -> Vec { + ring.walk_ring(&hash, self.replication_factor) + } + fn split_points(&self, ring: &Ring) -> Vec { + let mut ret = vec![]; + + ret.push([0u8; 32].into()); + for entry in ring.ring.iter() { + ret.push(entry.location); + } + ret.push([0xFFu8; 32].into()); + ret + } +} diff --git a/src/table/table_sync.rs b/src/table/table_sync.rs new file mode 100644 index 00000000..8f6582a7 --- /dev/null +++ b/src/table/table_sync.rs @@ -0,0 +1,791 @@ +use rand::Rng; +use std::collections::{BTreeMap, VecDeque}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use futures::future::BoxFuture; +use futures::{pin_mut, select}; +use futures_util::future::*; +use futures_util::stream::*; +use serde::{Deserialize, Serialize}; +use serde_bytes::ByteBuf; +use tokio::sync::Mutex; +use tokio::sync::{mpsc, watch}; + +use crate::data::*; +use crate::error::Error; +use crate::rpc::membership::Ring; +use crate::table::*; + +const MAX_DEPTH: 
usize = 16; +const SCAN_INTERVAL: Duration = Duration::from_secs(3600); +const CHECKSUM_CACHE_TIMEOUT: Duration = Duration::from_secs(1800); +const TABLE_SYNC_RPC_TIMEOUT: Duration = Duration::from_secs(30); + +pub struct TableSyncer { + table: Arc>, + todo: Mutex, + cache: Vec>>, +} + +#[derive(Serialize, Deserialize)] +pub enum SyncRPC { + GetRootChecksumRange(Hash, Hash), + RootChecksumRange(SyncRange), + Checksums(Vec, bool), + Difference(Vec, Vec>), +} + +pub struct SyncTodo { + todo: Vec, +} + +#[derive(Debug, Clone)] +struct TodoPartition { + begin: Hash, + end: Hash, + retain: bool, +} + +// A SyncRange defines a query on the dataset stored by a node, in the following way: +// - all items whose key are >= `begin` +// - stopping at the first item whose key hash has at least `level` leading zero bytes (excluded) +// - except if the first item of the range has such many leading zero bytes +// - and stopping at `end` (excluded) if such an item is not found +// The checksum itself does not store all of the items in the database, only the hashes of the "sub-ranges" +// i.e. of ranges of level `level-1` that cover the same range +// (ranges of level 0 do not exist and their hash is simply the hash of the first item >= begin) +// See RangeChecksum for the struct that stores this information. +#[derive(Hash, PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] +pub struct SyncRange { + begin: Vec, + end: Vec, + level: usize, +} + +impl std::cmp::PartialOrd for SyncRange { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl std::cmp::Ord for SyncRange { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.begin + .cmp(&other.begin) + .then(self.level.cmp(&other.level)) + .then(self.end.cmp(&other.end)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RangeChecksum { + bounds: SyncRange, + children: Vec<(SyncRange, Hash)>, + found_limit: Option>, + + #[serde(skip, default = "std::time::Instant::now")] + time: Instant, +} + +#[derive(Debug, Clone)] +pub struct RangeChecksumCache { + hash: Option, // None if no children + found_limit: Option>, + time: Instant, +} + +impl TableSyncer +where + F: TableSchema + 'static, + R: TableReplication + 'static, +{ + pub async fn launch(table: Arc>) -> Arc { + let todo = SyncTodo { todo: Vec::new() }; + let syncer = Arc::new(TableSyncer { + table: table.clone(), + todo: Mutex::new(todo), + cache: (0..MAX_DEPTH) + .map(|_| Mutex::new(BTreeMap::new())) + .collect::>(), + }); + + let (busy_tx, busy_rx) = mpsc::unbounded_channel(); + + let s1 = syncer.clone(); + table + .system + .background + .spawn_worker( + format!("table sync watcher for {}", table.name), + move |must_exit: watch::Receiver| s1.watcher_task(must_exit, busy_rx), + ) + .await; + + let s2 = syncer.clone(); + table + .system + .background + .spawn_worker( + format!("table syncer for {}", table.name), + move |must_exit: watch::Receiver| s2.syncer_task(must_exit, busy_tx), + ) + .await; + + let s3 = syncer.clone(); + tokio::spawn(async move { + tokio::time::delay_for(Duration::from_secs(20)).await; + s3.add_full_scan().await; + }); + + syncer + } + + async fn watcher_task( + self: Arc, + mut must_exit: watch::Receiver, + mut busy_rx: mpsc::UnboundedReceiver, + ) -> Result<(), Error> { + let mut prev_ring: Arc = self.table.system.ring.borrow().clone(); + let mut ring_recv: watch::Receiver> = self.table.system.ring.clone(); + let mut nothing_to_do_since = Some(Instant::now()); + + while !*must_exit.borrow() { + let s_ring_recv = 
+
+impl<F, R> TableSyncer<F, R>
+where
+	F: TableSchema + 'static,
+	R: TableReplication + 'static,
+{
+	pub async fn launch(table: Arc<Table<F, R>>) -> Arc<Self> {
+		let todo = SyncTodo { todo: Vec::new() };
+		let syncer = Arc::new(TableSyncer {
+			table: table.clone(),
+			todo: Mutex::new(todo),
+			cache: (0..MAX_DEPTH)
+				.map(|_| Mutex::new(BTreeMap::new()))
+				.collect::<Vec<_>>(),
+		});
+
+		let (busy_tx, busy_rx) = mpsc::unbounded_channel();
+
+		let s1 = syncer.clone();
+		table
+			.system
+			.background
+			.spawn_worker(
+				format!("table sync watcher for {}", table.name),
+				move |must_exit: watch::Receiver<bool>| s1.watcher_task(must_exit, busy_rx),
+			)
+			.await;
+
+		let s2 = syncer.clone();
+		table
+			.system
+			.background
+			.spawn_worker(
+				format!("table syncer for {}", table.name),
+				move |must_exit: watch::Receiver<bool>| s2.syncer_task(must_exit, busy_tx),
+			)
+			.await;
+
+		let s3 = syncer.clone();
+		tokio::spawn(async move {
+			tokio::time::delay_for(Duration::from_secs(20)).await;
+			s3.add_full_scan().await;
+		});
+
+		syncer
+	}
+
+	async fn watcher_task(
+		self: Arc<Self>,
+		mut must_exit: watch::Receiver<bool>,
+		mut busy_rx: mpsc::UnboundedReceiver<bool>,
+	) -> Result<(), Error> {
+		let mut prev_ring: Arc<Ring> = self.table.system.ring.borrow().clone();
+		let mut ring_recv: watch::Receiver<Arc<Ring>> = self.table.system.ring.clone();
+		let mut nothing_to_do_since = Some(Instant::now());
+
+		while !*must_exit.borrow() {
+			let s_ring_recv = ring_recv.recv().fuse();
+			let s_busy = busy_rx.recv().fuse();
+			let s_must_exit = must_exit.recv().fuse();
+			let s_timeout = tokio::time::delay_for(Duration::from_secs(1)).fuse();
+			pin_mut!(s_ring_recv, s_busy, s_must_exit, s_timeout);
+
+			select! {
+				new_ring_r = s_ring_recv => {
+					if let Some(new_ring) = new_ring_r {
+						debug!("({}) Adding ring difference to syncer todo list", self.table.name);
+						self.todo.lock().await.add_ring_difference(&self.table, &prev_ring, &new_ring);
+						prev_ring = new_ring;
+					}
+				}
+				busy_opt = s_busy => {
+					if let Some(busy) = busy_opt {
+						if busy {
+							nothing_to_do_since = None;
+						} else {
+							if nothing_to_do_since.is_none() {
+								nothing_to_do_since = Some(Instant::now());
+							}
+						}
+					}
+				}
+				must_exit_v = s_must_exit => {
+					if must_exit_v.unwrap_or(false) {
+						break;
+					}
+				}
+				_ = s_timeout => {
+					if nothing_to_do_since.map(|t| Instant::now() - t >= SCAN_INTERVAL).unwrap_or(false) {
+						nothing_to_do_since = None;
+						debug!("({}) Adding full scan to syncer todo list", self.table.name);
+						self.add_full_scan().await;
+					}
+				}
+			}
+		}
+		Ok(())
+	}
+
+	pub async fn add_full_scan(&self) {
+		self.todo.lock().await.add_full_scan(&self.table);
+	}
+
+	async fn syncer_task(
+		self: Arc<Self>,
+		mut must_exit: watch::Receiver<bool>,
+		busy_tx: mpsc::UnboundedSender<bool>,
+	) -> Result<(), Error> {
+		while !*must_exit.borrow() {
+			if let Some(partition) = self.todo.lock().await.pop_task() {
+				busy_tx.send(true)?;
+				let res = self
+					.clone()
+					.sync_partition(&partition, &mut must_exit)
+					.await;
+				if let Err(e) = res {
+					warn!(
+						"({}) Error while syncing {:?}: {}",
+						self.table.name, partition, e
+					);
+				}
+			} else {
+				busy_tx.send(false)?;
+				tokio::time::delay_for(Duration::from_secs(1)).await;
+			}
+		}
+		Ok(())
+	}
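Both tasks above follow the same background-worker pattern: loop until the `must_exit` watch channel flips to true, doing one unit of work per iteration or sleeping briefly when idle. A stripped-down sketch of that pattern in isolation (illustration only, using the tokio 0.2 API that this patch builds on: `delay_for`, and `watch::Sender::broadcast` for the shutdown signal; the work itself is replaced by a sleep):

use std::time::Duration;
use tokio::sync::watch;

async fn worker(must_exit: watch::Receiver<bool>) {
    while !*must_exit.borrow() {
        // do one unit of work here, or sleep if there is nothing to do
        tokio::time::delay_for(Duration::from_secs(1)).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = watch::channel(false);
    let handle = tokio::spawn(worker(rx));
    tokio::time::delay_for(Duration::from_secs(3)).await;
    tx.broadcast(true).unwrap(); // ask the worker to exit
    handle.await.unwrap(); // graceful shutdown: the loop observes the flag
}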
+
+	async fn sync_partition(
+		self: Arc<Self>,
+		partition: &TodoPartition,
+		must_exit: &mut watch::Receiver<bool>,
+	) -> Result<(), Error> {
+		let my_id = self.table.system.id;
+		let nodes = self
+			.table
+			.replication
+			.write_nodes(&partition.begin, &self.table.system)
+			.into_iter()
+			.filter(|node| *node != my_id)
+			.collect::<Vec<_>>();
+
+		debug!(
+			"({}) Preparing to sync {:?} with {:?}...",
+			self.table.name, partition, nodes
+		);
+		let root_cks = self
+			.root_checksum(&partition.begin, &partition.end, must_exit)
+			.await?;
+
+		let mut sync_futures = nodes
+			.iter()
+			.map(|node| {
+				self.clone().do_sync_with(
+					partition.clone(),
+					root_cks.clone(),
+					*node,
+					partition.retain,
+					must_exit.clone(),
+				)
+			})
+			.collect::<FuturesUnordered<_>>();
+
+		let mut n_errors = 0;
+		while let Some(r) = sync_futures.next().await {
+			if let Err(e) = r {
+				n_errors += 1;
+				warn!("({}) Sync error: {}", self.table.name, e);
+			}
+		}
+		if n_errors > self.table.replication.max_write_errors() {
+			return Err(Error::Message(format!(
+				"Sync failed with too many nodes (should have been: {:?}).",
+				nodes
+			)));
+		}
+
+		if !partition.retain {
+			self.table
+				.delete_range(&partition.begin, &partition.end)
+				.await?;
+		}
+
+		Ok(())
+	}
+
+	async fn root_checksum(
+		self: &Arc<Self>,
+		begin: &Hash,
+		end: &Hash,
+		must_exit: &mut watch::Receiver<bool>,
+	) -> Result<RangeChecksum, Error> {
+		for i in 1..MAX_DEPTH {
+			let rc = self
+				.range_checksum(
+					&SyncRange {
+						begin: begin.to_vec(),
+						end: end.to_vec(),
+						level: i,
+					},
+					must_exit,
+				)
+				.await?;
+			if rc.found_limit.is_none() {
+				return Ok(rc);
+			}
+		}
+		Err(Error::Message(format!(
+			"Unable to compute root checksum (this should never happen)"
+		)))
+	}
+
+	async fn range_checksum(
+		self: &Arc<Self>,
+		range: &SyncRange,
+		must_exit: &mut watch::Receiver<bool>,
+	) -> Result<RangeChecksum, Error> {
+		assert!(range.level != 0);
+
+		if range.level == 1 {
+			let mut children = vec![];
+			for item in self
+				.table
+				.store
+				.range(range.begin.clone()..range.end.clone())
+			{
+				let (key, value) = item?;
+				let key_hash = hash(&key[..]);
+				if children.len() > 0
+					&& key_hash.as_slice()[0..range.level]
+						.iter()
+						.all(|x| *x == 0u8)
+				{
+					return Ok(RangeChecksum {
+						bounds: range.clone(),
+						children,
+						found_limit: Some(key.to_vec()),
+						time: Instant::now(),
+					});
+				}
+				let item_range = SyncRange {
+					begin: key.to_vec(),
+					end: vec![],
+					level: 0,
+				};
+				children.push((item_range, hash(&value[..])));
+			}
+			Ok(RangeChecksum {
+				bounds: range.clone(),
+				children,
+				found_limit: None,
+				time: Instant::now(),
+			})
+		} else {
+			let mut children = vec![];
+			let mut sub_range = SyncRange {
+				begin: range.begin.clone(),
+				end: range.end.clone(),
+				level: range.level - 1,
+			};
+			let mut time = Instant::now();
+			while !*must_exit.borrow() {
+				let sub_ck = self
+					.range_checksum_cached_hash(&sub_range, must_exit)
+					.await?;
+
+				if let Some(hash) = sub_ck.hash {
+					children.push((sub_range.clone(), hash));
+					if sub_ck.time < time {
+						time = sub_ck.time;
+					}
+				}
+
+				if sub_ck.found_limit.is_none() || sub_ck.hash.is_none() {
+					return Ok(RangeChecksum {
+						bounds: range.clone(),
+						children,
+						found_limit: None,
+						time,
+					});
+				}
+				let found_limit = sub_ck.found_limit.unwrap();
+
+				let actual_limit_hash = hash(&found_limit[..]);
+				if actual_limit_hash.as_slice()[0..range.level]
+					.iter()
+					.all(|x| *x == 0u8)
+				{
+					return Ok(RangeChecksum {
+						bounds: range.clone(),
+						children,
+						found_limit: Some(found_limit.clone()),
+						time,
+					});
+				}
+
+				sub_range.begin = found_limit;
+			}
+			Err(Error::Message(format!("Exiting.")))
+		}
+	}
+
+	fn range_checksum_cached_hash<'a>(
+		self: &'a Arc<Self>,
+		range: &'a SyncRange,
+		must_exit: &'a mut watch::Receiver<bool>,
+	) -> BoxFuture<'a, Result<RangeChecksumCache, Error>> {
+		async move {
+			let mut cache = self.cache[range.level].lock().await;
+			if let Some(v) = cache.get(&range) {
+				if Instant::now() - v.time < CHECKSUM_CACHE_TIMEOUT {
+					return Ok(v.clone());
+				}
+			}
+			cache.remove(&range);
+			drop(cache);
+
+			let v = self.range_checksum(&range, must_exit).await?;
+			trace!(
+				"({}) New checksum calculated for {}-{}/{}, {} children",
+				self.table.name,
+				hex::encode(&range.begin)
+					.chars()
+					.take(16)
+					.collect::<String>(),
+				hex::encode(&range.end).chars().take(16).collect::<String>(),
+				range.level,
+				v.children.len()
+			);
+
+			let hash = if v.children.len() > 0 {
+				Some(hash(&rmp_to_vec_all_named(&v)?[..]))
+			} else {
+				None
+			};
+			let cache_entry = RangeChecksumCache {
+				hash,
+				found_limit: v.found_limit,
+				time: v.time,
+			};
+
+			let mut cache = self.cache[range.level].lock().await;
+			cache.insert(range.clone(), cache_entry.clone());
+			Ok(cache_entry)
+		}
+		.boxed()
+	}
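A parent checksum is thus nothing more than a hash over the serialized list of (sub-range, hash) children, computed lazily and cached for CHECKSUM_CACHE_TIMEOUT. As a toy model of why a single differing leaf propagates all the way to the root checksum (illustration only; std's DefaultHasher stands in for the project's `hash` over `rmp_to_vec_all_named` output):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Toy parent checksum: a hash of the ordered child hashes.
fn parent(children: &[u64]) -> u64 {
    let mut h = DefaultHasher::new();
    children.hash(&mut h);
    h.finish()
}

fn main() {
    let a = parent(&[1, 2, 3]);
    let b = parent(&[1, 2, 4]); // one child differs...
    assert_ne!(a, b); // ...so the parents differ, and the sync descends into that range
}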
+
+	async fn do_sync_with(
+		self: Arc<Self>,
+		partition: TodoPartition,
+		root_ck: RangeChecksum,
+		who: UUID,
+		retain: bool,
+		mut must_exit: watch::Receiver<bool>,
+	) -> Result<(), Error> {
+		let mut todo = VecDeque::new();
+
+		// If their root checksum has a higher level than ours, use it as the reference
+		let root_cks_resp = self
+			.table
+			.rpc_client
+			.call(
+				who,
+				TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange(
+					partition.begin.clone(),
+					partition.end.clone(),
+				)),
+				TABLE_SYNC_RPC_TIMEOUT,
+			)
+			.await?;
+		if let TableRPC::<F>::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp {
+			if range.level > root_ck.bounds.level {
+				let their_root_range_ck = self.range_checksum(&range, &mut must_exit).await?;
+				todo.push_back(their_root_range_ck);
+			} else {
+				todo.push_back(root_ck);
+			}
+		} else {
+			return Err(Error::BadRequest(format!(
+				"Invalid response to GetRootChecksumRange RPC: {}",
+				debug_serialize(root_cks_resp)
+			)));
+		}
+
+		while !todo.is_empty() && !*must_exit.borrow() {
+			let total_children = todo.iter().map(|x| x.children.len()).fold(0, |x, y| x + y);
+			trace!(
+				"({}) Sync with {:?}: {} ({}) remaining",
+				self.table.name,
+				who,
+				todo.len(),
+				total_children
+			);
+
+			let step_size = std::cmp::min(16, todo.len());
+			let step = todo.drain(..step_size).collect::<Vec<_>>();
+
+			let rpc_resp = self
+				.table
+				.rpc_client
+				.call(
+					who,
+					TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)),
+					TABLE_SYNC_RPC_TIMEOUT,
+				)
+				.await?;
+			if let TableRPC::<F>::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) =
+				rpc_resp
+			{
+				if diff_ranges.len() > 0 || diff_items.len() > 0 {
+					info!(
+						"({}) Sync with {:?}: difference {} ranges, {} items",
+						self.table.name,
+						who,
+						diff_ranges.len(),
+						diff_items.len()
+					);
+				}
+				let mut items_to_send = vec![];
+				for differing in diff_ranges.drain(..) {
+					if differing.level == 0 {
+						items_to_send.push(differing.begin);
+					} else {
+						let checksum = self.range_checksum(&differing, &mut must_exit).await?;
+						todo.push_back(checksum);
+					}
+				}
+				if retain && diff_items.len() > 0 {
+					self.table.handle_update(&diff_items[..]).await?;
+				}
+				if items_to_send.len() > 0 {
+					self.send_items(who, items_to_send).await?;
+				}
+			} else {
+				return Err(Error::BadRequest(format!(
+					"Unexpected response to sync RPC checksums: {}",
+					debug_serialize(&rpc_resp)
+				)));
+			}
+		}
+		Ok(())
+	}
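Each round of the loop above ships at most 16 pending checksums per RPC, draining them from the front of the queue while newly discovered differing sub-ranges are pushed onto the back. The queue discipline in isolation (illustration only, with integers in place of RangeChecksum values):

use std::collections::VecDeque;

fn main() {
    let mut todo: VecDeque<u32> = (0..40).collect();
    while !todo.is_empty() {
        // Take up to 16 queued items per RPC round; drain removes them
        // from the front without reallocating the queue.
        let step_size = std::cmp::min(16, todo.len());
        let step = todo.drain(..step_size).collect::<Vec<_>>();
        println!("sending batch of {}", step.len());
        // in the real code, differing sub-ranges found in the reply
        // would be push_back()'d here for a later round
    }
}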
+
+	async fn send_items(&self, who: UUID, item_list: Vec<Vec<u8>>) -> Result<(), Error> {
+		info!(
+			"({}) Sending {} items to {:?}",
+			self.table.name,
+			item_list.len(),
+			who
+		);
+
+		let mut values = vec![];
+		for item in item_list.iter() {
+			if let Some(v) = self.table.store.get(&item[..])? {
+				values.push(Arc::new(ByteBuf::from(v.as_ref())));
+			}
+		}
+		let rpc_resp = self
+			.table
+			.rpc_client
+			.call(who, TableRPC::<F>::Update(values), TABLE_SYNC_RPC_TIMEOUT)
+			.await?;
+		if let TableRPC::<F>::Ok = rpc_resp {
+			Ok(())
+		} else {
+			Err(Error::Message(format!(
+				"Unexpected response to RPC Update: {}",
+				debug_serialize(&rpc_resp)
+			)))
+		}
+	}
+
+	pub async fn handle_rpc(
+		self: &Arc<Self>,
+		message: &SyncRPC,
+		mut must_exit: watch::Receiver<bool>,
+	) -> Result<SyncRPC, Error> {
+		match message {
+			SyncRPC::GetRootChecksumRange(begin, end) => {
+				let root_cks = self.root_checksum(&begin, &end, &mut must_exit).await?;
+				Ok(SyncRPC::RootChecksumRange(root_cks.bounds))
+			}
+			SyncRPC::Checksums(checksums, retain) => {
+				self.handle_checksums_rpc(&checksums[..], *retain, &mut must_exit)
+					.await
+			}
+			_ => Err(Error::Message(format!("Unexpected sync RPC"))),
+		}
+	}
+
+	async fn handle_checksums_rpc(
+		self: &Arc<Self>,
+		checksums: &[RangeChecksum],
+		retain: bool,
+		must_exit: &mut watch::Receiver<bool>,
+	) -> Result<SyncRPC, Error> {
+		let mut ret_ranges = vec![];
+		let mut ret_items = vec![];
+
+		for their_ckr in checksums.iter() {
+			let our_ckr = self.range_checksum(&their_ckr.bounds, must_exit).await?;
+			for (their_range, their_hash) in their_ckr.children.iter() {
+				let differs = match our_ckr
+					.children
+					.binary_search_by(|(our_range, _)| our_range.cmp(&their_range))
+				{
+					Err(_) => {
+						if their_range.level >= 1 {
+							let cached_hash = self
+								.range_checksum_cached_hash(&their_range, must_exit)
+								.await?;
+							cached_hash.hash.map(|h| h != *their_hash).unwrap_or(true)
+						} else {
+							true
+						}
+					}
+					Ok(i) => our_ckr.children[i].1 != *their_hash,
+				};
+				if differs {
+					ret_ranges.push(their_range.clone());
+					if retain && their_range.level == 0 {
+						if let Some(item_bytes) =
+							self.table.store.get(their_range.begin.as_slice())?
+						{
+							ret_items.push(Arc::new(ByteBuf::from(item_bytes.to_vec())));
+						}
+					}
+				}
+			}
+			for (our_range, _hash) in our_ckr.children.iter() {
+				if let Some(their_found_limit) = &their_ckr.found_limit {
+					if our_range.begin.as_slice() > their_found_limit.as_slice() {
+						break;
+					}
+				}
+
+				let not_present = our_ckr
+					.children
+					.binary_search_by(|(their_range, _)| their_range.cmp(&our_range))
+					.is_err();
+				if not_present {
+					if our_range.level > 0 {
+						ret_ranges.push(our_range.clone());
+					}
+					if retain && our_range.level == 0 {
+						if let Some(item_bytes) =
+							self.table.store.get(our_range.begin.as_slice())?
+						{
+							ret_items.push(Arc::new(ByteBuf::from(item_bytes.to_vec())));
+						}
+					}
+				}
+			}
+		}
+		let n_checksums = checksums
+			.iter()
+			.map(|x| x.children.len())
+			.fold(0, |x, y| x + y);
+		if ret_ranges.len() > 0 || ret_items.len() > 0 {
+			trace!(
+				"({}) Checksum comparison RPC: {} different + {} items for {} received",
+				self.table.name,
+				ret_ranges.len(),
+				ret_items.len(),
+				n_checksums
+			);
+		}
+		Ok(SyncRPC::Difference(ret_ranges, ret_items))
+	}
+
+	pub async fn invalidate(self: Arc<Self>, item_key: Vec<u8>) -> Result<(), Error> {
+		for i in 1..MAX_DEPTH {
+			let needle = SyncRange {
+				begin: item_key.to_vec(),
+				end: vec![],
+				level: i,
+			};
+			let mut cache = self.cache[i].lock().await;
+			if let Some(cache_entry) = cache.range(..=needle).rev().next() {
+				if cache_entry.0.begin <= item_key && cache_entry.0.end > item_key {
+					let index = cache_entry.0.clone();
+					drop(cache_entry);
+					cache.remove(&index);
+				}
+			}
+		}
+		Ok(())
+	}
+}
+
+impl SyncTodo {
+	fn add_full_scan<F: TableSchema, R: TableReplication>(&mut self, table: &Table<F, R>) {
+		let my_id = table.system.id;
+
+		self.todo.clear();
+
+		let ring = table.system.ring.borrow().clone();
+		let split_points = table.replication.split_points(&ring);
+
+		for i in 0..split_points.len() - 1 {
+			let begin = split_points[i];
+			let end = split_points[i + 1];
+			let nodes = table.replication.replication_nodes(&begin, &ring);
+
+			let retain = nodes.contains(&my_id);
+			if !retain {
+				// Check if we have some data to send, otherwise skip
+				if table.store.range(begin..end).next().is_none() {
+					continue;
+				}
+			}
+
+			self.todo.push(TodoPartition { begin, end, retain });
+		}
+	}
+
+	fn add_ring_difference<F: TableSchema, R: TableReplication>(
+		&mut self,
+		table: &Table<F, R>,
+		old_ring: &Ring,
+		new_ring: &Ring,
+	) {
+		let my_id = table.system.id;
+
+		// If it is us who are entering or leaving the system,
+		// initiate a full sync instead of incremental sync
+		if old_ring.config.members.contains_key(&my_id)
+			!= new_ring.config.members.contains_key(&my_id)
+		{
+			self.add_full_scan(table);
+			return;
+		}
+
+		let mut all_points = None
+			.into_iter()
+			.chain(table.replication.split_points(old_ring).drain(..))
+			.chain(table.replication.split_points(new_ring).drain(..))
+			.chain(self.todo.iter().map(|x| x.begin))
+			.chain(self.todo.iter().map(|x| x.end))
+			.collect::<Vec<_>>();
+		all_points.sort();
+		all_points.dedup();
+
+		let mut old_todo = std::mem::replace(&mut self.todo, vec![]);
+		old_todo.sort_by(|x, y| x.begin.cmp(&y.begin));
+		let mut new_todo = vec![];
+
+		for i in 0..all_points.len() - 1 {
+			let begin = all_points[i];
+			let end = all_points[i + 1];
+			let was_ours = table
+				.replication
+				.replication_nodes(&begin, &old_ring)
+				.contains(&my_id);
+			let is_ours = table
+				.replication
+				.replication_nodes(&begin, &new_ring)
+				.contains(&my_id);
+
+			let was_todo = match old_todo.binary_search_by(|x| x.begin.cmp(&begin)) {
+				Ok(_) => true,
+				Err(j) => {
+					(j > 0 && old_todo[j - 1].begin < end && begin < old_todo[j - 1].end)
+						|| (j < old_todo.len()
+							&& old_todo[j].begin < end && begin < old_todo[j].end)
+				}
+			};
+			if was_todo || (is_ours && !was_ours) || (was_ours && !is_ours) {
+				new_todo.push(TodoPartition {
+					begin,
+					end,
+					retain: is_ours,
+				});
+			}
+		}
+
+		self.todo = new_todo;
+	}
+
+	fn pop_task(&mut self) -> Option<TodoPartition> {
+		if self.todo.is_empty() {
+			return None;
+		}
+
+		let i = rand::thread_rng().gen_range::<usize, _, _>(0, self.todo.len());
+		if i == self.todo.len() - 1 {
+			self.todo.pop()
+		} else {
+			let replacement = self.todo.pop().unwrap();
+			let ret = std::mem::replace(&mut self.todo[i], replacement);
+			Some(ret)
+		}
+	}
+}
diff --git 
a/src/table_fullcopy.rs b/src/table_fullcopy.rs deleted file mode 100644 index 2fcf56db..00000000 --- a/src/table_fullcopy.rs +++ /dev/null @@ -1,100 +0,0 @@ -use arc_swap::ArcSwapOption; -use std::sync::Arc; - -use crate::data::*; -use crate::membership::{Ring, System}; -use crate::table::*; - -#[derive(Clone)] -pub struct TableFullReplication { - pub write_factor: usize, - pub write_quorum: usize, - - neighbors: ArcSwapOption, -} - -#[derive(Clone)] -struct Neighbors { - ring: Arc, - neighbors: Vec, -} - -impl TableFullReplication { - pub fn new(write_factor: usize, write_quorum: usize) -> Self { - TableFullReplication { - write_factor, - write_quorum, - neighbors: ArcSwapOption::from(None), - } - } - - fn get_neighbors(&self, system: &System) -> Vec { - let neighbors = self.neighbors.load_full(); - if let Some(n) = neighbors { - if Arc::ptr_eq(&n.ring, &system.ring.borrow()) { - return n.neighbors.clone(); - } - } - - // Recalculate neighbors - let ring = system.ring.borrow().clone(); - let my_id = system.id; - - let mut nodes = vec![]; - for (node, _) in ring.config.members.iter() { - let node_ranking = hash(&[node.as_slice(), my_id.as_slice()].concat()); - nodes.push((*node, node_ranking)); - } - nodes.sort_by(|(_, rank1), (_, rank2)| rank1.cmp(rank2)); - let mut neighbors = nodes - .drain(..) - .map(|(node, _)| node) - .filter(|node| *node != my_id) - .take(self.write_factor) - .collect::>(); - neighbors.push(my_id); - self.neighbors.swap(Some(Arc::new(Neighbors { - ring, - neighbors: neighbors.clone(), - }))); - neighbors - } -} - -impl TableReplication for TableFullReplication { - // Full replication schema: all nodes store everything - // Writes are disseminated in an epidemic manner in the network - - // Advantage: do all reads locally, extremely fast - // Inconvenient: only suitable to reasonably small tables - - fn read_nodes(&self, _hash: &Hash, system: &System) -> Vec { - vec![system.id] - } - fn read_quorum(&self) -> usize { - 1 - } - - fn write_nodes(&self, _hash: &Hash, system: &System) -> Vec { - self.get_neighbors(system) - } - fn write_quorum(&self) -> usize { - self.write_quorum - } - fn max_write_errors(&self) -> usize { - self.write_factor - self.write_quorum - } - fn epidemic_writes(&self) -> bool { - true - } - - fn replication_nodes(&self, _hash: &Hash, ring: &Ring) -> Vec { - ring.config.members.keys().cloned().collect::>() - } - fn split_points(&self, _ring: &Ring) -> Vec { - let mut ret = vec![]; - ret.push([0u8; 32].into()); - ret.push([0xFFu8; 32].into()); - ret - } -} diff --git a/src/table_sharded.rs b/src/table_sharded.rs deleted file mode 100644 index c17ea0d4..00000000 --- a/src/table_sharded.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::data::*; -use crate::membership::{Ring, System}; -use crate::table::*; - -#[derive(Clone)] -pub struct TableShardedReplication { - pub replication_factor: usize, - pub read_quorum: usize, - pub write_quorum: usize, -} - -impl TableReplication for TableShardedReplication { - // Sharded replication schema: - // - based on the ring of nodes, a certain set of neighbors - // store entries, given as a function of the position of the - // entry's hash in the ring - // - reads are done on all of the nodes that replicate the data - // - writes as well - - fn read_nodes(&self, hash: &Hash, system: &System) -> Vec { - let ring = system.ring.borrow().clone(); - ring.walk_ring(&hash, self.replication_factor) - } - fn read_quorum(&self) -> usize { - self.read_quorum - } - - fn write_nodes(&self, hash: &Hash, system: &System) -> Vec { 
- let ring = system.ring.borrow().clone(); - ring.walk_ring(&hash, self.replication_factor) - } - fn write_quorum(&self) -> usize { - self.write_quorum - } - fn max_write_errors(&self) -> usize { - self.replication_factor - self.write_quorum - } - fn epidemic_writes(&self) -> bool { - false - } - - fn replication_nodes(&self, hash: &Hash, ring: &Ring) -> Vec { - ring.walk_ring(&hash, self.replication_factor) - } - fn split_points(&self, ring: &Ring) -> Vec { - let mut ret = vec![]; - - ret.push([0u8; 32].into()); - for entry in ring.ring.iter() { - ret.push(entry.location); - } - ret.push([0xFFu8; 32].into()); - ret - } -} diff --git a/src/table_sync.rs b/src/table_sync.rs deleted file mode 100644 index 60d5c4df..00000000 --- a/src/table_sync.rs +++ /dev/null @@ -1,791 +0,0 @@ -use rand::Rng; -use std::collections::{BTreeMap, VecDeque}; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use futures::future::BoxFuture; -use futures::{pin_mut, select}; -use futures_util::future::*; -use futures_util::stream::*; -use serde::{Deserialize, Serialize}; -use serde_bytes::ByteBuf; -use tokio::sync::Mutex; -use tokio::sync::{mpsc, watch}; - -use crate::data::*; -use crate::error::Error; -use crate::membership::Ring; -use crate::table::*; - -const MAX_DEPTH: usize = 16; -const SCAN_INTERVAL: Duration = Duration::from_secs(3600); -const CHECKSUM_CACHE_TIMEOUT: Duration = Duration::from_secs(1800); -const TABLE_SYNC_RPC_TIMEOUT: Duration = Duration::from_secs(30); - -pub struct TableSyncer { - table: Arc>, - todo: Mutex, - cache: Vec>>, -} - -#[derive(Serialize, Deserialize)] -pub enum SyncRPC { - GetRootChecksumRange(Hash, Hash), - RootChecksumRange(SyncRange), - Checksums(Vec, bool), - Difference(Vec, Vec>), -} - -pub struct SyncTodo { - todo: Vec, -} - -#[derive(Debug, Clone)] -struct TodoPartition { - begin: Hash, - end: Hash, - retain: bool, -} - -// A SyncRange defines a query on the dataset stored by a node, in the following way: -// - all items whose key are >= `begin` -// - stopping at the first item whose key hash has at least `level` leading zero bytes (excluded) -// - except if the first item of the range has such many leading zero bytes -// - and stopping at `end` (excluded) if such an item is not found -// The checksum itself does not store all of the items in the database, only the hashes of the "sub-ranges" -// i.e. of ranges of level `level-1` that cover the same range -// (ranges of level 0 do not exist and their hash is simply the hash of the first item >= begin) -// See RangeChecksum for the struct that stores this information. 
-#[derive(Hash, PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] -pub struct SyncRange { - begin: Vec, - end: Vec, - level: usize, -} - -impl std::cmp::PartialOrd for SyncRange { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl std::cmp::Ord for SyncRange { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.begin - .cmp(&other.begin) - .then(self.level.cmp(&other.level)) - .then(self.end.cmp(&other.end)) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RangeChecksum { - bounds: SyncRange, - children: Vec<(SyncRange, Hash)>, - found_limit: Option>, - - #[serde(skip, default = "std::time::Instant::now")] - time: Instant, -} - -#[derive(Debug, Clone)] -pub struct RangeChecksumCache { - hash: Option, // None if no children - found_limit: Option>, - time: Instant, -} - -impl TableSyncer -where - F: TableSchema + 'static, - R: TableReplication + 'static, -{ - pub async fn launch(table: Arc>) -> Arc { - let todo = SyncTodo { todo: Vec::new() }; - let syncer = Arc::new(TableSyncer { - table: table.clone(), - todo: Mutex::new(todo), - cache: (0..MAX_DEPTH) - .map(|_| Mutex::new(BTreeMap::new())) - .collect::>(), - }); - - let (busy_tx, busy_rx) = mpsc::unbounded_channel(); - - let s1 = syncer.clone(); - table - .system - .background - .spawn_worker( - format!("table sync watcher for {}", table.name), - move |must_exit: watch::Receiver| s1.watcher_task(must_exit, busy_rx), - ) - .await; - - let s2 = syncer.clone(); - table - .system - .background - .spawn_worker( - format!("table syncer for {}", table.name), - move |must_exit: watch::Receiver| s2.syncer_task(must_exit, busy_tx), - ) - .await; - - let s3 = syncer.clone(); - tokio::spawn(async move { - tokio::time::delay_for(Duration::from_secs(20)).await; - s3.add_full_scan().await; - }); - - syncer - } - - async fn watcher_task( - self: Arc, - mut must_exit: watch::Receiver, - mut busy_rx: mpsc::UnboundedReceiver, - ) -> Result<(), Error> { - let mut prev_ring: Arc = self.table.system.ring.borrow().clone(); - let mut ring_recv: watch::Receiver> = self.table.system.ring.clone(); - let mut nothing_to_do_since = Some(Instant::now()); - - while !*must_exit.borrow() { - let s_ring_recv = ring_recv.recv().fuse(); - let s_busy = busy_rx.recv().fuse(); - let s_must_exit = must_exit.recv().fuse(); - let s_timeout = tokio::time::delay_for(Duration::from_secs(1)).fuse(); - pin_mut!(s_ring_recv, s_busy, s_must_exit, s_timeout); - - select! 
{ - new_ring_r = s_ring_recv => { - if let Some(new_ring) = new_ring_r { - debug!("({}) Adding ring difference to syncer todo list", self.table.name); - self.todo.lock().await.add_ring_difference(&self.table, &prev_ring, &new_ring); - prev_ring = new_ring; - } - } - busy_opt = s_busy => { - if let Some(busy) = busy_opt { - if busy { - nothing_to_do_since = None; - } else { - if nothing_to_do_since.is_none() { - nothing_to_do_since = Some(Instant::now()); - } - } - } - } - must_exit_v = s_must_exit => { - if must_exit_v.unwrap_or(false) { - break; - } - } - _ = s_timeout => { - if nothing_to_do_since.map(|t| Instant::now() - t >= SCAN_INTERVAL).unwrap_or(false) { - nothing_to_do_since = None; - debug!("({}) Adding full scan to syncer todo list", self.table.name); - self.add_full_scan().await; - } - } - } - } - Ok(()) - } - - pub async fn add_full_scan(&self) { - self.todo.lock().await.add_full_scan(&self.table); - } - - async fn syncer_task( - self: Arc, - mut must_exit: watch::Receiver, - busy_tx: mpsc::UnboundedSender, - ) -> Result<(), Error> { - while !*must_exit.borrow() { - if let Some(partition) = self.todo.lock().await.pop_task() { - busy_tx.send(true)?; - let res = self - .clone() - .sync_partition(&partition, &mut must_exit) - .await; - if let Err(e) = res { - warn!( - "({}) Error while syncing {:?}: {}", - self.table.name, partition, e - ); - } - } else { - busy_tx.send(false)?; - tokio::time::delay_for(Duration::from_secs(1)).await; - } - } - Ok(()) - } - - async fn sync_partition( - self: Arc, - partition: &TodoPartition, - must_exit: &mut watch::Receiver, - ) -> Result<(), Error> { - let my_id = self.table.system.id; - let nodes = self - .table - .replication - .write_nodes(&partition.begin, &self.table.system) - .into_iter() - .filter(|node| *node != my_id) - .collect::>(); - - debug!( - "({}) Preparing to sync {:?} with {:?}...", - self.table.name, partition, nodes - ); - let root_cks = self - .root_checksum(&partition.begin, &partition.end, must_exit) - .await?; - - let mut sync_futures = nodes - .iter() - .map(|node| { - self.clone().do_sync_with( - partition.clone(), - root_cks.clone(), - *node, - partition.retain, - must_exit.clone(), - ) - }) - .collect::>(); - - let mut n_errors = 0; - while let Some(r) = sync_futures.next().await { - if let Err(e) = r { - n_errors += 1; - warn!("({}) Sync error: {}", self.table.name, e); - } - } - if n_errors > self.table.replication.max_write_errors() { - return Err(Error::Message(format!( - "Sync failed with too many nodes (should have been: {:?}).", - nodes - ))); - } - - if !partition.retain { - self.table - .delete_range(&partition.begin, &partition.end) - .await?; - } - - Ok(()) - } - - async fn root_checksum( - self: &Arc, - begin: &Hash, - end: &Hash, - must_exit: &mut watch::Receiver, - ) -> Result { - for i in 1..MAX_DEPTH { - let rc = self - .range_checksum( - &SyncRange { - begin: begin.to_vec(), - end: end.to_vec(), - level: i, - }, - must_exit, - ) - .await?; - if rc.found_limit.is_none() { - return Ok(rc); - } - } - Err(Error::Message(format!( - "Unable to compute root checksum (this should never happen)" - ))) - } - - async fn range_checksum( - self: &Arc, - range: &SyncRange, - must_exit: &mut watch::Receiver, - ) -> Result { - assert!(range.level != 0); - - if range.level == 1 { - let mut children = vec![]; - for item in self - .table - .store - .range(range.begin.clone()..range.end.clone()) - { - let (key, value) = item?; - let key_hash = hash(&key[..]); - if children.len() > 0 - && 
key_hash.as_slice()[0..range.level] - .iter() - .all(|x| *x == 0u8) - { - return Ok(RangeChecksum { - bounds: range.clone(), - children, - found_limit: Some(key.to_vec()), - time: Instant::now(), - }); - } - let item_range = SyncRange { - begin: key.to_vec(), - end: vec![], - level: 0, - }; - children.push((item_range, hash(&value[..]))); - } - Ok(RangeChecksum { - bounds: range.clone(), - children, - found_limit: None, - time: Instant::now(), - }) - } else { - let mut children = vec![]; - let mut sub_range = SyncRange { - begin: range.begin.clone(), - end: range.end.clone(), - level: range.level - 1, - }; - let mut time = Instant::now(); - while !*must_exit.borrow() { - let sub_ck = self - .range_checksum_cached_hash(&sub_range, must_exit) - .await?; - - if let Some(hash) = sub_ck.hash { - children.push((sub_range.clone(), hash)); - if sub_ck.time < time { - time = sub_ck.time; - } - } - - if sub_ck.found_limit.is_none() || sub_ck.hash.is_none() { - return Ok(RangeChecksum { - bounds: range.clone(), - children, - found_limit: None, - time, - }); - } - let found_limit = sub_ck.found_limit.unwrap(); - - let actual_limit_hash = hash(&found_limit[..]); - if actual_limit_hash.as_slice()[0..range.level] - .iter() - .all(|x| *x == 0u8) - { - return Ok(RangeChecksum { - bounds: range.clone(), - children, - found_limit: Some(found_limit.clone()), - time, - }); - } - - sub_range.begin = found_limit; - } - Err(Error::Message(format!("Exiting."))) - } - } - - fn range_checksum_cached_hash<'a>( - self: &'a Arc, - range: &'a SyncRange, - must_exit: &'a mut watch::Receiver, - ) -> BoxFuture<'a, Result> { - async move { - let mut cache = self.cache[range.level].lock().await; - if let Some(v) = cache.get(&range) { - if Instant::now() - v.time < CHECKSUM_CACHE_TIMEOUT { - return Ok(v.clone()); - } - } - cache.remove(&range); - drop(cache); - - let v = self.range_checksum(&range, must_exit).await?; - trace!( - "({}) New checksum calculated for {}-{}/{}, {} children", - self.table.name, - hex::encode(&range.begin) - .chars() - .take(16) - .collect::(), - hex::encode(&range.end).chars().take(16).collect::(), - range.level, - v.children.len() - ); - - let hash = if v.children.len() > 0 { - Some(hash(&rmp_to_vec_all_named(&v)?[..])) - } else { - None - }; - let cache_entry = RangeChecksumCache { - hash, - found_limit: v.found_limit, - time: v.time, - }; - - let mut cache = self.cache[range.level].lock().await; - cache.insert(range.clone(), cache_entry.clone()); - Ok(cache_entry) - } - .boxed() - } - - async fn do_sync_with( - self: Arc, - partition: TodoPartition, - root_ck: RangeChecksum, - who: UUID, - retain: bool, - mut must_exit: watch::Receiver, - ) -> Result<(), Error> { - let mut todo = VecDeque::new(); - - // If their root checksum has level > than us, use that as a reference - let root_cks_resp = self - .table - .rpc_client - .call( - who, - TableRPC::::SyncRPC(SyncRPC::GetRootChecksumRange( - partition.begin.clone(), - partition.end.clone(), - )), - TABLE_SYNC_RPC_TIMEOUT, - ) - .await?; - if let TableRPC::::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp { - if range.level > root_ck.bounds.level { - let their_root_range_ck = self.range_checksum(&range, &mut must_exit).await?; - todo.push_back(their_root_range_ck); - } else { - todo.push_back(root_ck); - } - } else { - return Err(Error::BadRequest(format!( - "Invalid respone to GetRootChecksumRange RPC: {}", - debug_serialize(root_cks_resp) - ))); - } - - while !todo.is_empty() && !*must_exit.borrow() { - let total_children = 
todo.iter().map(|x| x.children.len()).fold(0, |x, y| x + y); - trace!( - "({}) Sync with {:?}: {} ({}) remaining", - self.table.name, - who, - todo.len(), - total_children - ); - - let step_size = std::cmp::min(16, todo.len()); - let step = todo.drain(..step_size).collect::>(); - - let rpc_resp = self - .table - .rpc_client - .call( - who, - TableRPC::::SyncRPC(SyncRPC::Checksums(step, retain)), - TABLE_SYNC_RPC_TIMEOUT, - ) - .await?; - if let TableRPC::::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) = - rpc_resp - { - if diff_ranges.len() > 0 || diff_items.len() > 0 { - info!( - "({}) Sync with {:?}: difference {} ranges, {} items", - self.table.name, - who, - diff_ranges.len(), - diff_items.len() - ); - } - let mut items_to_send = vec![]; - for differing in diff_ranges.drain(..) { - if differing.level == 0 { - items_to_send.push(differing.begin); - } else { - let checksum = self.range_checksum(&differing, &mut must_exit).await?; - todo.push_back(checksum); - } - } - if retain && diff_items.len() > 0 { - self.table.handle_update(&diff_items[..]).await?; - } - if items_to_send.len() > 0 { - self.send_items(who, items_to_send).await?; - } - } else { - return Err(Error::BadRequest(format!( - "Unexpected response to sync RPC checksums: {}", - debug_serialize(&rpc_resp) - ))); - } - } - Ok(()) - } - - async fn send_items(&self, who: UUID, item_list: Vec>) -> Result<(), Error> { - info!( - "({}) Sending {} items to {:?}", - self.table.name, - item_list.len(), - who - ); - - let mut values = vec![]; - for item in item_list.iter() { - if let Some(v) = self.table.store.get(&item[..])? { - values.push(Arc::new(ByteBuf::from(v.as_ref()))); - } - } - let rpc_resp = self - .table - .rpc_client - .call(who, TableRPC::::Update(values), TABLE_SYNC_RPC_TIMEOUT) - .await?; - if let TableRPC::::Ok = rpc_resp { - Ok(()) - } else { - Err(Error::Message(format!( - "Unexpected response to RPC Update: {}", - debug_serialize(&rpc_resp) - ))) - } - } - - pub async fn handle_rpc( - self: &Arc, - message: &SyncRPC, - mut must_exit: watch::Receiver, - ) -> Result { - match message { - SyncRPC::GetRootChecksumRange(begin, end) => { - let root_cks = self.root_checksum(&begin, &end, &mut must_exit).await?; - Ok(SyncRPC::RootChecksumRange(root_cks.bounds)) - } - SyncRPC::Checksums(checksums, retain) => { - self.handle_checksums_rpc(&checksums[..], *retain, &mut must_exit) - .await - } - _ => Err(Error::Message(format!("Unexpected sync RPC"))), - } - } - - async fn handle_checksums_rpc( - self: &Arc, - checksums: &[RangeChecksum], - retain: bool, - must_exit: &mut watch::Receiver, - ) -> Result { - let mut ret_ranges = vec![]; - let mut ret_items = vec![]; - - for their_ckr in checksums.iter() { - let our_ckr = self.range_checksum(&their_ckr.bounds, must_exit).await?; - for (their_range, their_hash) in their_ckr.children.iter() { - let differs = match our_ckr - .children - .binary_search_by(|(our_range, _)| our_range.cmp(&their_range)) - { - Err(_) => { - if their_range.level >= 1 { - let cached_hash = self - .range_checksum_cached_hash(&their_range, must_exit) - .await?; - cached_hash.hash.map(|h| h != *their_hash).unwrap_or(true) - } else { - true - } - } - Ok(i) => our_ckr.children[i].1 != *their_hash, - }; - if differs { - ret_ranges.push(their_range.clone()); - if retain && their_range.level == 0 { - if let Some(item_bytes) = - self.table.store.get(their_range.begin.as_slice())? 
- { - ret_items.push(Arc::new(ByteBuf::from(item_bytes.to_vec()))); - } - } - } - } - for (our_range, _hash) in our_ckr.children.iter() { - if let Some(their_found_limit) = &their_ckr.found_limit { - if our_range.begin.as_slice() > their_found_limit.as_slice() { - break; - } - } - - let not_present = our_ckr - .children - .binary_search_by(|(their_range, _)| their_range.cmp(&our_range)) - .is_err(); - if not_present { - if our_range.level > 0 { - ret_ranges.push(our_range.clone()); - } - if retain && our_range.level == 0 { - if let Some(item_bytes) = - self.table.store.get(our_range.begin.as_slice())? - { - ret_items.push(Arc::new(ByteBuf::from(item_bytes.to_vec()))); - } - } - } - } - } - let n_checksums = checksums - .iter() - .map(|x| x.children.len()) - .fold(0, |x, y| x + y); - if ret_ranges.len() > 0 || ret_items.len() > 0 { - trace!( - "({}) Checksum comparison RPC: {} different + {} items for {} received", - self.table.name, - ret_ranges.len(), - ret_items.len(), - n_checksums - ); - } - Ok(SyncRPC::Difference(ret_ranges, ret_items)) - } - - pub async fn invalidate(self: Arc, item_key: Vec) -> Result<(), Error> { - for i in 1..MAX_DEPTH { - let needle = SyncRange { - begin: item_key.to_vec(), - end: vec![], - level: i, - }; - let mut cache = self.cache[i].lock().await; - if let Some(cache_entry) = cache.range(..=needle).rev().next() { - if cache_entry.0.begin <= item_key && cache_entry.0.end > item_key { - let index = cache_entry.0.clone(); - drop(cache_entry); - cache.remove(&index); - } - } - } - Ok(()) - } -} - -impl SyncTodo { - fn add_full_scan(&mut self, table: &Table) { - let my_id = table.system.id; - - self.todo.clear(); - - let ring = table.system.ring.borrow().clone(); - let split_points = table.replication.split_points(&ring); - - for i in 0..split_points.len() - 1 { - let begin = split_points[i]; - let end = split_points[i + 1]; - let nodes = table.replication.replication_nodes(&begin, &ring); - - let retain = nodes.contains(&my_id); - if !retain { - // Check if we have some data to send, otherwise skip - if table.store.range(begin..end).next().is_none() { - continue; - } - } - - self.todo.push(TodoPartition { begin, end, retain }); - } - } - - fn add_ring_difference( - &mut self, - table: &Table, - old_ring: &Ring, - new_ring: &Ring, - ) { - let my_id = table.system.id; - - // If it is us who are entering or leaving the system, - // initiate a full sync instead of incremental sync - if old_ring.config.members.contains_key(&my_id) - != new_ring.config.members.contains_key(&my_id) - { - self.add_full_scan(table); - return; - } - - let mut all_points = None - .into_iter() - .chain(table.replication.split_points(old_ring).drain(..)) - .chain(table.replication.split_points(new_ring).drain(..)) - .chain(self.todo.iter().map(|x| x.begin)) - .chain(self.todo.iter().map(|x| x.end)) - .collect::>(); - all_points.sort(); - all_points.dedup(); - - let mut old_todo = std::mem::replace(&mut self.todo, vec![]); - old_todo.sort_by(|x, y| x.begin.cmp(&y.begin)); - let mut new_todo = vec![]; - - for i in 0..all_points.len() - 1 { - let begin = all_points[i]; - let end = all_points[i + 1]; - let was_ours = table - .replication - .replication_nodes(&begin, &old_ring) - .contains(&my_id); - let is_ours = table - .replication - .replication_nodes(&begin, &new_ring) - .contains(&my_id); - - let was_todo = match old_todo.binary_search_by(|x| x.begin.cmp(&begin)) { - Ok(_) => true, - Err(j) => { - (j > 0 && old_todo[j - 1].begin < end && begin < old_todo[j - 1].end) - || (j < old_todo.len() - 
&& old_todo[j].begin < end && begin < old_todo[j].end) - } - }; - if was_todo || (is_ours && !was_ours) || (was_ours && !is_ours) { - new_todo.push(TodoPartition { - begin, - end, - retain: is_ours, - }); - } - } - - self.todo = new_todo; - } - - fn pop_task(&mut self) -> Option { - if self.todo.is_empty() { - return None; - } - - let i = rand::thread_rng().gen_range::(0, self.todo.len()); - if i == self.todo.len() - 1 { - self.todo.pop() - } else { - let replacement = self.todo.pop().unwrap(); - let ret = std::mem::replace(&mut self.todo[i], replacement); - Some(ret) - } - } -} diff --git a/src/tls_util.rs b/src/tls_util.rs deleted file mode 100644 index 52c52110..00000000 --- a/src/tls_util.rs +++ /dev/null @@ -1,139 +0,0 @@ -use core::future::Future; -use core::task::{Context, Poll}; -use std::pin::Pin; -use std::sync::Arc; -use std::{fs, io}; - -use futures_util::future::*; -use hyper::client::connect::Connection; -use hyper::client::HttpConnector; -use hyper::service::Service; -use hyper::Uri; -use hyper_rustls::MaybeHttpsStream; -use rustls::internal::pemfile; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::TlsConnector; -use webpki::DNSNameRef; - -use crate::error::Error; - -pub fn load_certs(filename: &str) -> Result, Error> { - let certfile = fs::File::open(&filename)?; - let mut reader = io::BufReader::new(certfile); - - let certs = pemfile::certs(&mut reader).map_err(|_| { - Error::Message(format!( - "Could not deecode certificates from file: {}", - filename - )) - })?; - - if certs.is_empty() { - return Err(Error::Message(format!( - "Invalid certificate file: {}", - filename - ))); - } - Ok(certs) -} - -pub fn load_private_key(filename: &str) -> Result { - let keyfile = fs::File::open(&filename)?; - let mut reader = io::BufReader::new(keyfile); - - let keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| { - Error::Message(format!( - "Could not decode private key from file: {}", - filename - )) - })?; - - if keys.len() != 1 { - return Err(Error::Message(format!( - "Invalid private key file: {} ({} private keys)", - filename, - keys.len() - ))); - } - Ok(keys[0].clone()) -} - -// ---- AWFUL COPYPASTA FROM HYPER-RUSTLS connector.rs -// ---- ALWAYS USE `garage` AS HOSTNAME FOR TLS VERIFICATION - -#[derive(Clone)] -pub struct HttpsConnectorFixedDnsname { - http: T, - tls_config: Arc, - fixed_dnsname: &'static str, -} - -type BoxError = Box; - -impl HttpsConnectorFixedDnsname { - pub fn new(mut tls_config: rustls::ClientConfig, fixed_dnsname: &'static str) -> Self { - let mut http = HttpConnector::new(); - http.enforce_http(false); - tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - Self { - http, - tls_config: Arc::new(tls_config), - fixed_dnsname, - } - } -} - -impl Service for HttpsConnectorFixedDnsname -where - T: Service, - T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, - T::Future: Send + 'static, - T::Error: Into, -{ - type Response = MaybeHttpsStream; - type Error = BoxError; - - #[allow(clippy::type_complexity)] - type Future = - Pin, BoxError>> + Send>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - match self.http.poll_ready(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Pending => Poll::Pending, - } - } - - fn call(&mut self, dst: Uri) -> Self::Future { - let is_https = dst.scheme_str() == Some("https"); - - if !is_https { - let connecting_future = self.http.call(dst); - - let f = async move { - let tcp = 
connecting_future.await.map_err(Into::into)?; - - Ok(MaybeHttpsStream::Http(tcp)) - }; - f.boxed() - } else { - let cfg = self.tls_config.clone(); - let connecting_future = self.http.call(dst); - - let dnsname = - DNSNameRef::try_from_ascii_str(self.fixed_dnsname).expect("Invalid fixed dnsname"); - - let f = async move { - let tcp = connecting_future.await.map_err(Into::into)?; - let connector = TlsConnector::from(cfg); - let tls = connector - .connect(dnsname, tcp) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - Ok(MaybeHttpsStream::Https(tls)) - }; - f.boxed() - } - } -} diff --git a/src/version_table.rs b/src/version_table.rs deleted file mode 100644 index 74174dce..00000000 --- a/src/version_table.rs +++ /dev/null @@ -1,94 +0,0 @@ -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -use crate::background::BackgroundRunner; -use crate::data::*; -use crate::error::Error; -use crate::table::*; -use crate::table_sharded::*; - -use crate::block_ref_table::*; - -#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] -pub struct Version { - // Primary key - pub uuid: UUID, - - // Actual data: the blocks for this version - pub deleted: bool, - pub blocks: Vec, - - // Back link to bucket+key so that we can figure if - // this was deleted later on - pub bucket: String, - pub key: String, -} - -#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] -pub struct VersionBlock { - pub offset: u64, - pub hash: Hash, -} - -impl Entry for Version { - fn partition_key(&self) -> &Hash { - &self.uuid - } - fn sort_key(&self) -> &EmptyKey { - &EmptyKey - } - - fn merge(&mut self, other: &Self) { - if other.deleted { - self.deleted = true; - self.blocks.clear(); - } else if !self.deleted { - for bi in other.blocks.iter() { - match self.blocks.binary_search_by(|x| x.offset.cmp(&bi.offset)) { - Ok(_) => (), - Err(pos) => { - self.blocks.insert(pos, bi.clone()); - } - } - } - } - } -} - -pub struct VersionTable { - pub background: Arc, - pub block_ref_table: Arc>, -} - -#[async_trait] -impl TableSchema for VersionTable { - type P = Hash; - type S = EmptyKey; - type E = Version; - type Filter = (); - - async fn updated(&self, old: Option, new: Option) -> Result<(), Error> { - let block_ref_table = self.block_ref_table.clone(); - if let (Some(old_v), Some(new_v)) = (old, new) { - // Propagate deletion of version blocks - if new_v.deleted && !old_v.deleted { - let deleted_block_refs = old_v - .blocks - .iter() - .map(|vb| BlockRef { - block: vb.hash, - version: old_v.uuid, - deleted: true, - }) - .collect::>(); - block_ref_table.insert_many(&deleted_block_refs[..]).await?; - } - } - Ok(()) - } - - fn matches_filter(entry: &Self::E, _filter: &Self::Filter) -> bool { - !entry.deleted - } -} -- cgit v1.2.3