aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-04-18 19:21:34 +0200
committerAlex Auvolat <alex@adnab.me>2020-04-18 19:21:34 +0200
commitf41583e1b731574b4bb13a20d4b3fd9fe3a899f5 (patch)
treea2c1d32284fa0dc30fdf5408afad8255d50e51f6
parent3f40ef149f6dd4d61ceb326b5691e186aec178c3 (diff)
downloadgarage-f41583e1b731574b4bb13a20d4b3fd9fe3a899f5.tar.gz
garage-f41583e1b731574b4bb13a20d4b3fd9fe3a899f5.zip
Massive RPC refactoring
-rw-r--r--src/api_server.rs15
-rw-r--r--src/block.rs143
-rw-r--r--src/main.rs14
-rw-r--r--src/membership.rs84
-rw-r--r--src/proto.rs19
-rw-r--r--src/rpc_client.rs247
-rw-r--r--src/rpc_server.rs310
-rw-r--r--src/server.rs45
-rw-r--r--src/table.rs122
-rw-r--r--src/table_sync.rs15
10 files changed, 570 insertions, 444 deletions
diff --git a/src/api_server.rs b/src/api_server.rs
index c6d52d16..f213b4dd 100644
--- a/src/api_server.rs
+++ b/src/api_server.rs
@@ -9,7 +9,6 @@ use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
-use crate::block::*;
use crate::data::*;
use crate::error::Error;
use crate::http_util::*;
@@ -151,7 +150,9 @@ async fn handle_put(
let mut next_offset = first_block.len();
let mut put_curr_version_block =
put_block_meta(garage.clone(), &version, 0, first_block_hash.clone());
- let mut put_curr_block = rpc_put_block(&garage.system, first_block_hash, first_block);
+ let mut put_curr_block = garage
+ .block_manager
+ .rpc_put_block(first_block_hash, first_block);
loop {
let (_, _, next_block) =
@@ -165,7 +166,7 @@ async fn handle_put(
next_offset as u64,
block_hash.clone(),
);
- put_curr_block = rpc_put_block(&garage.system, block_hash, block);
+ put_curr_block = garage.block_manager.rpc_put_block(block_hash, block);
next_offset += block_len;
} else {
break;
@@ -300,7 +301,7 @@ async fn handle_get(
Ok(resp_builder.body(body)?)
}
ObjectVersionData::FirstBlock(first_block_hash) => {
- let read_first_block = rpc_get_block(&garage.system, &first_block_hash);
+ let read_first_block = garage.block_manager.rpc_get_block(&first_block_hash);
let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptySortKey);
let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?;
@@ -323,7 +324,11 @@ async fn handle_get(
if let Some(data) = data_opt {
Ok(Bytes::from(data))
} else {
- rpc_get_block(&garage.system, &hash).await.map(Bytes::from)
+ garage
+ .block_manager
+ .rpc_get_block(&hash)
+ .await
+ .map(Bytes::from)
}
}
})
diff --git a/src/block.rs b/src/block.rs
index 6add24b7..879cff2c 100644
--- a/src/block.rs
+++ b/src/block.rs
@@ -5,6 +5,7 @@ use std::time::Duration;
use arc_swap::ArcSwapOption;
use futures::future::*;
use futures::stream::*;
+use serde::{Deserialize, Serialize};
use tokio::fs;
use tokio::prelude::*;
use tokio::sync::{watch, Mutex};
@@ -15,22 +16,40 @@ use crate::error::Error;
use crate::membership::System;
use crate::proto::*;
use crate::rpc_client::*;
+use crate::rpc_server::*;
use crate::server::Garage;
const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10);
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Message {
+ Ok,
+ GetBlock(Hash),
+ PutBlock(PutBlockMessage),
+ NeedBlockQuery(Hash),
+ NeedBlockReply(bool),
+}
+
+impl RpcMessage for Message {}
+
pub struct BlockManager {
pub data_dir: PathBuf,
pub rc: sled::Tree,
pub resync_queue: sled::Tree,
pub lock: Mutex<()>,
pub system: Arc<System>,
+ rpc_client: Arc<RpcClient<Message>>,
pub garage: ArcSwapOption<Garage>,
}
impl BlockManager {
- pub fn new(db: &sled::Db, data_dir: PathBuf, system: Arc<System>) -> Arc<Self> {
+ pub fn new(
+ db: &sled::Db,
+ data_dir: PathBuf,
+ system: Arc<System>,
+ rpc_server: &mut RpcServer,
+ ) -> Arc<Self> {
let rc = db
.open_tree("block_local_rc")
.expect("Unable to open block_local_rc tree");
@@ -40,14 +59,38 @@ impl BlockManager {
.open_tree("block_local_resync_queue")
.expect("Unable to open block_local_resync_queue tree");
- Arc::new(Self {
+ let rpc_path = "block_manager";
+ let rpc_client = system.rpc_client::<Message>(rpc_path);
+
+ let block_manager = Arc::new(Self {
rc,
resync_queue,
data_dir,
lock: Mutex::new(()),
system,
+ rpc_client,
garage: ArcSwapOption::from(None),
- })
+ });
+ block_manager
+ .clone()
+ .register_handler(rpc_server, rpc_path.into());
+ block_manager
+ }
+
+ fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+ rpc_server.add_handler::<Message, _, _>(path, move |msg, _addr| {
+ let self2 = self.clone();
+ async move {
+ match msg {
+ Message::PutBlock(m) => self2.write_block(&m.hash, &m.data).await,
+ Message::GetBlock(h) => self2.read_block(&h).await,
+ Message::NeedBlockQuery(h) => {
+ self2.need_block(&h).await.map(Message::NeedBlockReply)
+ }
+ _ => Err(Error::Message(format!("Invalid RPC"))),
+ }
+ }
+ });
}
pub async fn spawn_background_worker(self: Arc<Self>) {
@@ -214,10 +257,11 @@ impl BlockManager {
if needed_by_others {
let ring = garage.system.ring.borrow().clone();
let who = ring.walk_ring(&hash, garage.system.config.data_replication_factor);
- let msg = Message::NeedBlockQuery(hash.clone());
- let who_needs_fut = who
- .iter()
- .map(|to| rpc_call(garage.system.clone(), to, &msg, NEED_BLOCK_QUERY_TIMEOUT));
+ let msg = Arc::new(Message::NeedBlockQuery(hash.clone()));
+ let who_needs_fut = who.iter().map(|to| {
+ self.rpc_client
+ .call(to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
+ });
let who_needs = join_all(who_needs_fut).await;
let mut need_nodes = vec![];
@@ -242,13 +286,10 @@ impl BlockManager {
if need_nodes.len() > 0 {
let put_block_message = self.read_block(hash).await?;
- let put_responses = rpc_call_many(
- garage.system.clone(),
- &need_nodes[..],
- put_block_message,
- BLOCK_RW_TIMEOUT,
- )
- .await;
+ let put_responses = self
+ .rpc_client
+ .call_many(&need_nodes[..], put_block_message, BLOCK_RW_TIMEOUT)
+ .await;
for resp in put_responses {
resp?;
}
@@ -262,12 +303,48 @@ impl BlockManager {
// TODO find a way to not do this if they are sending it to us
// Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay
// between the RC being incremented and this part being called.
- let block_data = rpc_get_block(&self.system, &hash).await?;
+ let block_data = self.rpc_get_block(&hash).await?;
self.write_block(hash, &block_data[..]).await?;
}
Ok(())
}
+
+ pub async fn rpc_get_block(&self, hash: &Hash) -> Result<Vec<u8>, Error> {
+ let ring = self.system.ring.borrow().clone();
+ let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
+ let msg = Arc::new(Message::GetBlock(hash.clone()));
+ let mut resp_stream = who
+ .iter()
+ .map(|to| self.rpc_client.call(to, msg.clone(), BLOCK_RW_TIMEOUT))
+ .collect::<FuturesUnordered<_>>();
+
+ while let Some(resp) = resp_stream.next().await {
+ if let Ok(Message::PutBlock(msg)) = resp {
+ if data::hash(&msg.data[..]) == *hash {
+ return Ok(msg.data);
+ }
+ }
+ }
+ Err(Error::Message(format!(
+ "Unable to read block {:?}: no valid blocks returned",
+ hash
+ )))
+ }
+
+ pub async fn rpc_put_block(&self, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
+ let ring = self.system.ring.borrow().clone();
+ let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
+ self.rpc_client
+ .try_call_many(
+ &who[..],
+ Message::PutBlock(PutBlockMessage { hash, data }),
+ (self.system.config.data_replication_factor + 1) / 2,
+ BLOCK_RW_TIMEOUT,
+ )
+ .await?;
+ Ok(())
+ }
}
fn u64_from_bytes(bytes: &[u8]) -> u64 {
@@ -297,39 +374,3 @@ fn rc_merge(_key: &[u8], old: Option<&[u8]>, new: &[u8]) -> Option<Vec<u8>> {
Some(u64::to_be_bytes(new).to_vec())
}
}
-
-pub async fn rpc_get_block(system: &Arc<System>, hash: &Hash) -> Result<Vec<u8>, Error> {
- let ring = system.ring.borrow().clone();
- let who = ring.walk_ring(&hash, system.config.data_replication_factor);
- let msg = Message::GetBlock(hash.clone());
- let mut resp_stream = who
- .iter()
- .map(|to| rpc_call(system.clone(), to, &msg, BLOCK_RW_TIMEOUT))
- .collect::<FuturesUnordered<_>>();
-
- while let Some(resp) = resp_stream.next().await {
- if let Ok(Message::PutBlock(msg)) = resp {
- if data::hash(&msg.data[..]) == *hash {
- return Ok(msg.data);
- }
- }
- }
- Err(Error::Message(format!(
- "Unable to read block {:?}: no valid blocks returned",
- hash
- )))
-}
-
-pub async fn rpc_put_block(system: &Arc<System>, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
- let ring = system.ring.borrow().clone();
- let who = ring.walk_ring(&hash, system.config.data_replication_factor);
- rpc_try_call_many(
- system.clone(),
- &who[..],
- Message::PutBlock(PutBlockMessage { hash, data }),
- (system.config.data_replication_factor + 1) / 2,
- BLOCK_RW_TIMEOUT,
- )
- .await?;
- Ok(())
-}
diff --git a/src/main.rs b/src/main.rs
index ebf97a29..84b8c2bc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -22,12 +22,14 @@ mod tls_util;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::path::PathBuf;
+use std::sync::Arc;
use structopt::StructOpt;
use data::*;
use error::Error;
+use membership::Message;
use proto::*;
-use rpc_client::RpcClient;
+use rpc_client::*;
use server::TlsConfig;
#[derive(StructOpt, Debug)]
@@ -113,7 +115,9 @@ async fn main() {
}
};
- let rpc_cli = RpcClient::new(&tls_config).expect("Could not create RPC client");
+ let rpc_http_cli =
+ Arc::new(RpcHttpClient::new(&tls_config).expect("Could not create RPC client"));
+ let rpc_cli = RpcAddrClient::new(rpc_http_cli, "_membership".into());
let resp = match opt.cmd {
Command::Server(server_opt) => {
@@ -137,7 +141,7 @@ async fn main() {
}
}
-async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Error> {
+async fn cmd_status(rpc_cli: RpcAddrClient<Message>, rpc_host: SocketAddr) -> Result<(), Error> {
let status = match rpc_cli
.call(&rpc_host, &Message::PullStatus, DEFAULT_TIMEOUT)
.await?
@@ -196,7 +200,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro
}
async fn cmd_configure(
- rpc_cli: RpcClient,
+ rpc_cli: RpcAddrClient<Message>,
rpc_host: SocketAddr,
args: ConfigureOpt,
) -> Result<(), Error> {
@@ -249,7 +253,7 @@ async fn cmd_configure(
}
async fn cmd_remove(
- rpc_cli: RpcClient,
+ rpc_cli: RpcAddrClient<Message>,
rpc_host: SocketAddr,
args: RemoveOpt,
) -> Result<(), Error> {
diff --git a/src/membership.rs b/src/membership.rs
index 6d758c59..499637fb 100644
--- a/src/membership.rs
+++ b/src/membership.rs
@@ -10,6 +10,7 @@ use std::time::Duration;
use futures::future::join_all;
use futures::select;
use futures_util::future::*;
+use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::prelude::*;
use tokio::sync::watch;
@@ -20,17 +21,31 @@ use crate::data::*;
use crate::error::Error;
use crate::proto::*;
use crate::rpc_client::*;
+use crate::rpc_server::*;
use crate::server::Config;
const PING_INTERVAL: Duration = Duration::from_secs(10);
const PING_TIMEOUT: Duration = Duration::from_secs(2);
const MAX_FAILED_PINGS: usize = 3;
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Message {
+ Ok,
+ Ping(PingMessage),
+ PullStatus,
+ PullConfig,
+ AdvertiseNodesUp(Vec<AdvertisedNode>),
+ AdvertiseConfig(NetworkConfig),
+}
+
+impl RpcMessage for Message {}
+
pub struct System {
pub config: Config,
pub id: UUID,
- pub rpc_client: RpcClient,
+ pub rpc_http_client: Arc<RpcHttpClient>,
+ rpc_client: Arc<RpcClient<Message>>,
pub status: watch::Receiver<Arc<Status>>,
pub ring: watch::Receiver<Arc<Ring>>,
@@ -199,7 +214,12 @@ fn read_network_config(metadata_dir: &PathBuf) -> Result<NetworkConfig, Error> {
}
impl System {
- pub fn new(config: Config, id: UUID, background: Arc<BackgroundRunner>) -> Self {
+ pub fn new(
+ config: Config,
+ id: UUID,
+ background: Arc<BackgroundRunner>,
+ rpc_server: &mut RpcServer,
+ ) -> Arc<Self> {
let net_config = match read_network_config(&config.metadata_dir) {
Ok(x) => x,
Err(e) => {
@@ -228,17 +248,54 @@ impl System {
ring.rebuild_ring();
let (update_ring, ring) = watch::channel(Arc::new(ring));
- let rpc_client = RpcClient::new(&config.rpc_tls).expect("Could not create RPC client");
+ let rpc_http_client =
+ Arc::new(RpcHttpClient::new(&config.rpc_tls).expect("Could not create RPC client"));
+
+ let rpc_path = "_membership";
+ let rpc_client = RpcClient::new(
+ RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.into()),
+ background.clone(),
+ status.clone(),
+ );
- System {
+ let sys = Arc::new(System {
config,
id,
+ rpc_http_client,
rpc_client,
status,
ring,
update_lock: Mutex::new((update_status, update_ring)),
background,
- }
+ });
+ sys.clone().register_handler(rpc_server, rpc_path.into());
+ sys
+ }
+
+ fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+ rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| {
+ let self2 = self.clone();
+ async move {
+ match msg {
+ Message::Ping(ping) => self2.handle_ping(&addr, &ping).await,
+
+ Message::PullStatus => self2.handle_pull_status(),
+ Message::PullConfig => self2.handle_pull_config(),
+ Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await,
+ Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await,
+
+ _ => Err(Error::Message(format!("Unexpected RPC message"))),
+ }
+ }
+ });
+ }
+
+ pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> {
+ RpcClient::new(
+ RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()),
+ self.background.clone(),
+ self.status.clone(),
+ )
}
async fn save_network_config(self: Arc<Self>) -> Result<(), Error> {
@@ -272,7 +329,7 @@ impl System {
.filter(|x| **x != self.id)
.cloned()
.collect::<Vec<_>>();
- rpc_call_many(self.clone(), &to[..], msg, timeout).await;
+ self.rpc_client.call_many(&to[..], msg, timeout).await;
}
pub async fn bootstrap(self: Arc<Self>) {
@@ -299,7 +356,10 @@ impl System {
(
id_option,
addr.clone(),
- sys.rpc_client.call(&addr, ping_msg_ref, PING_TIMEOUT).await,
+ sys.rpc_client
+ .by_addr()
+ .call(&addr, ping_msg_ref, PING_TIMEOUT)
+ .await,
)
}
}))
@@ -509,7 +569,10 @@ impl System {
peer: UUID,
) -> impl futures::future::Future<Output = ()> + Send + 'static {
async move {
- let resp = rpc_call(self.clone(), &peer, &Message::PullStatus, PING_TIMEOUT).await;
+ let resp = self
+ .rpc_client
+ .call(&peer, Message::PullStatus, PING_TIMEOUT)
+ .await;
if let Ok(Message::AdvertiseNodesUp(nodes)) = resp {
let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await;
}
@@ -517,7 +580,10 @@ impl System {
}
pub async fn pull_config(self: Arc<Self>, peer: UUID) {
- let resp = rpc_call(self.clone(), &peer, &Message::PullConfig, PING_TIMEOUT).await;
+ let resp = self
+ .rpc_client
+ .call(&peer, Message::PullConfig, PING_TIMEOUT)
+ .await;
if let Ok(Message::AdvertiseConfig(config)) = resp {
let _: Result<_, _> = self.handle_advertise_config(&config).await;
}
diff --git a/src/proto.rs b/src/proto.rs
index cf7ed1cc..d51aa36b 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -8,25 +8,6 @@ pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
pub const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42);
#[derive(Debug, Serialize, Deserialize)]
-pub enum Message {
- Ok,
- Error(String),
-
- Ping(PingMessage),
- PullStatus,
- PullConfig,
- AdvertiseNodesUp(Vec<AdvertisedNode>),
- AdvertiseConfig(NetworkConfig),
-
- GetBlock(Hash),
- PutBlock(PutBlockMessage),
- NeedBlockQuery(Hash),
- NeedBlockReply(bool),
-
- TableRPC(String, #[serde(with = "serde_bytes")] Vec<u8>),
-}
-
-#[derive(Debug, Serialize, Deserialize)]
pub struct PingMessage {
pub id: UUID,
pub rpc_port: u16,
diff --git a/src/rpc_client.rs b/src/rpc_client.rs
index f8da778c..6d26d86a 100644
--- a/src/rpc_client.rs
+++ b/src/rpc_client.rs
@@ -1,4 +1,5 @@
use std::borrow::Borrow;
+use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
@@ -9,110 +10,166 @@ use futures::stream::StreamExt;
use futures_util::future::FutureExt;
use hyper::client::{Client, HttpConnector};
use hyper::{Body, Method, Request, StatusCode};
+use tokio::sync::watch;
+use crate::background::*;
use crate::data::*;
use crate::error::Error;
-use crate::membership::System;
-use crate::proto::Message;
+use crate::membership::Status;
+use crate::rpc_server::RpcMessage;
use crate::server::*;
use crate::tls_util;
-pub async fn rpc_call_many(
- sys: Arc<System>,
- to: &[UUID],
- msg: Message,
- timeout: Duration,
-) -> Vec<Result<Message, Error>> {
- let msg = Arc::new(msg);
- let mut resp_stream = to
- .iter()
- .map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout))
- .collect::<FuturesUnordered<_>>();
-
- let mut results = vec![];
- while let Some(resp) = resp_stream.next().await {
- results.push(resp);
- }
- results
+pub struct RpcClient<M: RpcMessage> {
+ status: watch::Receiver<Arc<Status>>,
+ background: Arc<BackgroundRunner>,
+
+ pub rpc_addr_client: RpcAddrClient<M>,
}
-pub async fn rpc_try_call_many(
- sys: Arc<System>,
- to: &[UUID],
- msg: Message,
- stop_after: usize,
- timeout: Duration,
-) -> Result<Vec<Message>, Error> {
- let sys2 = sys.clone();
- let msg = Arc::new(msg);
- let mut resp_stream = to
- .to_vec()
- .into_iter()
- .map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout))
- .collect::<FuturesUnordered<_>>();
-
- let mut results = vec![];
- let mut errors = vec![];
-
- while let Some(resp) = resp_stream.next().await {
- match resp {
- Ok(msg) => {
- results.push(msg);
- if results.len() >= stop_after {
- break;
+impl<M: RpcMessage + 'static> RpcClient<M> {
+ pub fn new(
+ rac: RpcAddrClient<M>,
+ background: Arc<BackgroundRunner>,
+ status: watch::Receiver<Arc<Status>>,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ rpc_addr_client: rac,
+ background,
+ status,
+ })
+ }
+
+ pub fn by_addr(&self) -> &RpcAddrClient<M> {
+ &self.rpc_addr_client
+ }
+
+ pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>(
+ &self,
+ to: N,
+ msg: MB,
+ timeout: Duration,
+ ) -> Result<M, Error> {
+ let addr = {
+ let status = self.status.borrow().clone();
+ match status.nodes.get(to.borrow()) {
+ Some(status) => status.addr.clone(),
+ None => {
+ return Err(Error::Message(format!(
+ "Peer ID not found: {:?}",
+ to.borrow()
+ )))
}
}
- Err(e) => {
- errors.push(e);
- }
+ };
+ self.rpc_addr_client.call(&addr, msg, timeout).await
+ }
+
+ pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> {
+ let msg = Arc::new(msg);
+ let mut resp_stream = to
+ .iter()
+ .map(|to| self.call(to, msg.clone(), timeout))
+ .collect::<FuturesUnordered<_>>();
+
+ let mut results = vec![];
+ while let Some(resp) = resp_stream.next().await {
+ results.push(resp);
}
+ results
}
- if results.len() >= stop_after {
- // Continue requests in background
- // TODO: make this optionnal (only usefull for write requests)
- sys.background.spawn(async move {
- resp_stream.collect::<Vec<_>>().await;
- Ok(())
- });
-
- Ok(results)
- } else {
- let mut msg = "Too many failures:".to_string();
- for e in errors {
- msg += &format!("\n{}", e);
+ pub async fn try_call_many(
+ self: &Arc<Self>,
+ to: &[UUID],
+ msg: M,
+ stop_after: usize,
+ timeout: Duration,
+ ) -> Result<Vec<M>, Error> {
+ let msg = Arc::new(msg);
+ let mut resp_stream = to
+ .to_vec()
+ .into_iter()
+ .map(|to| {
+ let self2 = self.clone();
+ let msg = msg.clone();
+ async move { self2.call(to.clone(), msg, timeout).await }
+ })
+ .collect::<FuturesUnordered<_>>();
+
+ let mut results = vec![];
+ let mut errors = vec![];
+
+ while let Some(resp) = resp_stream.next().await {
+ match resp {
+ Ok(msg) => {
+ results.push(msg);
+ if results.len() >= stop_after {
+ break;
+ }
+ }
+ Err(e) => {
+ errors.push(e);
+ }
+ }
+ }
+
+ if results.len() >= stop_after {
+ // Continue requests in background
+ // TODO: make this optionnal (only usefull for write requests)
+ self.clone().background.spawn(async move {
+ resp_stream.collect::<Vec<_>>().await;
+ Ok(())
+ });
+
+ Ok(results)
+ } else {
+ let mut msg = "Too many failures:".to_string();
+ for e in errors {
+ msg += &format!("\n{}", e);
+ }
+ Err(Error::Message(msg))
}
- Err(Error::Message(msg))
}
}
-pub async fn rpc_call<M: Borrow<Message>, N: Borrow<UUID>>(
- sys: Arc<System>,
- to: N,
- msg: M,
- timeout: Duration,
-) -> Result<Message, Error> {
- let addr = {
- let status = sys.status.borrow().clone();
- match status.nodes.get(to.borrow()) {
- Some(status) => status.addr.clone(),
- None => {
- return Err(Error::Message(format!(
- "Peer ID not found: {:?}",
- to.borrow()
- )))
- }
+pub struct RpcAddrClient<M: RpcMessage> {
+ phantom: PhantomData<M>,
+
+ pub http_client: Arc<RpcHttpClient>,
+ pub path: String,
+}
+
+impl<M: RpcMessage> RpcAddrClient<M> {
+ pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self {
+ Self {
+ phantom: PhantomData::default(),
+ http_client: http_client,
+ path,
}
- };
- sys.rpc_client.call(&addr, msg, timeout).await
+ }
+
+ pub async fn call<MB>(
+ &self,
+ to_addr: &SocketAddr,
+ msg: MB,
+ timeout: Duration,
+ ) -> Result<M, Error>
+ where
+ MB: Borrow<M>,
+ {
+ self.http_client
+ .call(&self.path, to_addr, msg, timeout)
+ .await
+ }
}
-pub enum RpcClient {
+pub enum RpcHttpClient {
HTTP(Client<HttpConnector, hyper::Body>),
HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>),
}
-impl RpcClient {
+impl RpcHttpClient {
pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> {
if let Some(cf) = tls_config {
let ca_certs = tls_util::load_certs(&cf.ca_cert)?;
@@ -130,21 +187,26 @@ impl RpcClient {
let connector =
tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage");
- Ok(RpcClient::HTTPS(Client::builder().build(connector)))
+ Ok(RpcHttpClient::HTTPS(Client::builder().build(connector)))
} else {
- Ok(RpcClient::HTTP(Client::new()))
+ Ok(RpcHttpClient::HTTP(Client::new()))
}
}
- pub async fn call<M: Borrow<Message>>(
+ async fn call<M, MB>(
&self,
+ path: &str,
to_addr: &SocketAddr,
- msg: M,
+ msg: MB,
timeout: Duration,
- ) -> Result<Message, Error> {
+ ) -> Result<M, Error>
+ where
+ MB: Borrow<M>,
+ M: RpcMessage,
+ {
let uri = match self {
- RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr),
- RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr),
+ RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path),
+ RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path),
};
let req = Request::builder()
@@ -153,8 +215,8 @@ impl RpcClient {
.body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?;
let resp_fut = match self {
- RpcClient::HTTP(client) => client.request(req).fuse(),
- RpcClient::HTTPS(client) => client.request(req).fuse(),
+ RpcHttpClient::HTTP(client) => client.request(req).fuse(),
+ RpcHttpClient::HTTPS(client) => client.request(req).fuse(),
};
let resp = tokio::time::timeout(timeout, resp_fut)
.await?
@@ -168,11 +230,8 @@ impl RpcClient {
if resp.status() == StatusCode::OK {
let body = hyper::body::to_bytes(resp.into_body()).await?;
- let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?;
- match msg {
- Message::Error(e) => Err(Error::RPCError(e)),
- x => Ok(x),
- }
+ let msg = rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())?;
+ msg.map_err(Error::RPCError)
} else {
Err(Error::RPCError(format!("Status code {}", resp.status())))
}
diff --git a/src/rpc_server.rs b/src/rpc_server.rs
index 3410ab97..83f8ddc9 100644
--- a/src/rpc_server.rs
+++ b/src/rpc_server.rs
@@ -1,4 +1,6 @@
+use std::collections::HashMap;
use std::net::SocketAddr;
+use std::pin::Pin;
use std::sync::Arc;
use bytes::IntoBuf;
@@ -8,175 +10,197 @@ use futures_util::stream::*;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
+use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use crate::data::*;
use crate::error::Error;
-use crate::proto::Message;
-use crate::server::Garage;
+use crate::server::TlsConfig;
use crate::tls_util;
-fn err_to_msg(x: Result<Message, Error>) -> Message {
- match x {
- Err(e) => Message::Error(format!("{}", e)),
- Ok(msg) => msg,
- }
+pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {}
+
+type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + Send>>;
+type Handler = Box<dyn Fn(Request<Body>, SocketAddr) -> ResponseFuture + Send + Sync>;
+
+pub struct RpcServer {
+ pub bind_addr: SocketAddr,
+ pub tls_config: Option<TlsConfig>,
+
+ handlers: HashMap<String, Handler>,
}
-async fn handler(
- garage: Arc<Garage>,
+async fn handle_func<M, F, Fut>(
+ handler: Arc<F>,
req: Request<Body>,
- addr: SocketAddr,
-) -> Result<Response<Body>, Error> {
- if req.method() != &Method::POST {
- let mut bad_request = Response::default();
- *bad_request.status_mut() = StatusCode::BAD_REQUEST;
- return Ok(bad_request);
- }
-
+ sockaddr: SocketAddr,
+) -> Result<Response<Body>, Error>
+where
+ M: RpcMessage + 'static,
+ F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static,
+ Fut: Future<Output = Result<M, Error>> + Send + 'static,
+{
let whole_body = hyper::body::to_bytes(req.into_body()).await?;
- let msg = rmp_serde::decode::from_read::<_, Message>(whole_body.into_buf())?;
-
- // eprintln!(
- // "RPC from {}: {} ({} bytes)",
- // addr,
- // debug_serialize(&msg),
- // whole_body.len()
- // );
-
- let sys = garage.system.clone();
- let resp = err_to_msg(match msg {
- Message::Ping(ping) => sys.handle_ping(&addr, &ping).await,
-
- Message::PullStatus => sys.handle_pull_status(),
- Message::PullConfig => sys.handle_pull_config(),
- Message::AdvertiseNodesUp(adv) => sys.handle_advertise_nodes_up(&adv).await,
- Message::AdvertiseConfig(adv) => sys.handle_advertise_config(&adv).await,
-
- Message::PutBlock(m) => {
- // A RPC can be interrupted in the middle, however we don't want to write partial blocks,
- // which might happen if the write_block() future is cancelled in the middle.
- // To solve this, the write itself is in a spawned task that has its own separate lifetime,
- // and the request handler simply sits there waiting for the task to finish.
- // (if it's cancelled, that's not an issue)
- // (TODO FIXME except if garage happens to shut down at that point)
- let write_fut = async move { garage.block_manager.write_block(&m.hash, &m.data).await };
- tokio::spawn(write_fut).await?
+ let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?;
+ match handler(msg, sockaddr).await {
+ Ok(resp) => {
+ let resp_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Ok(resp))?;
+ Ok(Response::new(Body::from(resp_bytes)))
}
- Message::GetBlock(h) => garage.block_manager.read_block(&h).await,
- Message::NeedBlockQuery(h) => garage
- .block_manager
- .need_block(&h)
- .await
- .map(Message::NeedBlockReply),
-
- Message::TableRPC(table, msg) => {
- // Same trick for table RPCs than for PutBlock
- let op_fut = async move {
- if let Some(rpc_handler) = garage.table_rpc_handlers.get(&table) {
- rpc_handler
- .handle(&msg[..])
- .await
- .map(|rep| Message::TableRPC(table.to_string(), rep))
- } else {
- Ok(Message::Error(format!("Unknown table: {}", table)))
- }
- };
- tokio::spawn(op_fut).await?
+ Err(e) => {
+ let err_str = format!("{}", e);
+ let rep_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Err(err_str))?;
+ let mut err_response = Response::new(Body::from(rep_bytes));
+ *err_response.status_mut() = e.http_status_code();
+ Ok(err_response)
}
-
- _ => Ok(Message::Error(format!("Unexpected message: {:?}", msg))),
- });
-
- // eprintln!("reply to {}: {}", addr, debug_serialize(&resp));
-
- Ok(Response::new(Body::from(rmp_to_vec_all_named(&resp)?)))
+ }
}
-pub async fn run_rpc_server(
- garage: Arc<Garage>,
- shutdown_signal: impl Future<Output = ()>,
-) -> Result<(), Error> {
- let bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], garage.system.config.rpc_port).into();
+impl RpcServer {
+ pub fn new(bind_addr: SocketAddr, tls_config: Option<TlsConfig>) -> Self {
+ Self {
+ bind_addr,
+ tls_config,
+ handlers: HashMap::new(),
+ }
+ }
- if let Some(tls_config) = &garage.system.config.rpc_tls {
- let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?;
- let node_certs = tls_util::load_certs(&tls_config.node_cert)?;
- let node_key = tls_util::load_private_key(&tls_config.node_key)?;
+ pub fn add_handler<M, F, Fut>(&mut self, name: String, handler: F)
+ where
+ M: RpcMessage + 'static,
+ F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static,
+ Fut: Future<Output = Result<M, Error>> + Send + 'static,
+ {
+ let handler_arc = Arc::new(handler);
+ let handler = Box::new(move |req: Request<Body>, sockaddr: SocketAddr| {
+ let handler2 = handler_arc.clone();
+ let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr));
+ b
+ });
+ self.handlers.insert(name, handler);
+ }
- let mut ca_store = rustls::RootCertStore::empty();
- for crt in ca_certs.iter() {
- ca_store.add(crt)?;
+ async fn handler(
+ self: Arc<Self>,
+ req: Request<Body>,
+ addr: SocketAddr,
+ ) -> Result<Response<Body>, Error> {
+ if req.method() != &Method::POST {
+ let mut bad_request = Response::default();
+ *bad_request.status_mut() = StatusCode::BAD_REQUEST;
+ return Ok(bad_request);
}
- let mut config =
- rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store));
- config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?;
- let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config)));
-
- let mut listener = TcpListener::bind(&bind_addr).await?;
- let incoming = listener.incoming().filter_map(|socket| async {
- match socket {
- Ok(stream) => match tls_acceptor.clone().accept(stream).await {
- Ok(x) => Some(Ok::<_, hyper::Error>(x)),
- Err(e) => {
- eprintln!("RPC server TLS error: {}", e);
- None
- }
- },
- Err(_) => None,
+ let path = &req.uri().path()[1..];
+ let handler = match self.handlers.get(path) {
+ Some(h) => h,
+ None => {
+ let mut not_found = Response::default();
+ *not_found.status_mut() = StatusCode::NOT_FOUND;
+ return Ok(not_found);
}
- });
- let incoming = hyper::server::accept::from_stream(incoming);
-
- let service = make_service_fn(|conn: &TlsStream<TcpStream>| {
- let client_addr = conn
- .get_ref()
- .0
- .peer_addr()
- .unwrap_or(([0, 0, 0, 0], 0).into());
- let garage = garage.clone();
- async move {
- Ok::<_, Error>(service_fn(move |req: Request<Body>| {
- let garage = garage.clone();
- handler(garage, req, client_addr).map_err(|e| {
- eprintln!("RPC handler error: {}", e);
- e
- })
- }))
+ };
+
+ let resp_waiter = tokio::spawn(handler(req, addr));
+ match resp_waiter.await {
+ Err(_err) => {
+ let mut ise = Response::default();
+ *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
+ Ok(ise)
}
- });
+ Ok(Err(err)) => {
+ let mut bad_request = Response::new(Body::from(format!("{}", err)));
+ *bad_request.status_mut() = StatusCode::BAD_REQUEST;
+ Ok(bad_request)
+ }
+ Ok(Ok(resp)) => Ok(resp),
+ }
+ }
- let server = Server::builder(incoming).serve(service);
-
- let graceful = server.with_graceful_shutdown(shutdown_signal);
- println!("RPC server listening on http://{}", bind_addr);
-
- graceful.await?;
- } else {
- let service = make_service_fn(|conn: &AddrStream| {
- let client_addr = conn.remote_addr();
- let garage = garage.clone();
- async move {
- Ok::<_, Error>(service_fn(move |req: Request<Body>| {
- let garage = garage.clone();
- handler(garage, req, client_addr).map_err(|e| {
- eprintln!("RPC handler error: {}", e);
- e
- })
- }))
+ pub async fn run(
+ self: Arc<Self>,
+ shutdown_signal: impl Future<Output = ()>,
+ ) -> Result<(), Error> {
+ if let Some(tls_config) = self.tls_config.as_ref() {
+ let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?;
+ let node_certs = tls_util::load_certs(&tls_config.node_cert)?;
+ let node_key = tls_util::load_private_key(&tls_config.node_key)?;
+
+ let mut ca_store = rustls::RootCertStore::empty();
+ for crt in ca_certs.iter() {
+ ca_store.add(crt)?;
}
- });
- let server = Server::bind(&bind_addr).serve(service);
+ let mut config =
+ rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store));
+ config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?;
+ let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config)));
+
+ let mut listener = TcpListener::bind(&self.bind_addr).await?;
+ let incoming = listener.incoming().filter_map(|socket| async {
+ match socket {
+ Ok(stream) => match tls_acceptor.clone().accept(stream).await {
+ Ok(x) => Some(Ok::<_, hyper::Error>(x)),
+ Err(e) => {
+ eprintln!("RPC server TLS error: {}", e);
+ None
+ }
+ },
+ Err(_) => None,
+ }
+ });
+ let incoming = hyper::server::accept::from_stream(incoming);
+
+ let self_arc = self.clone();
+ let service = make_service_fn(|conn: &TlsStream<TcpStream>| {
+ let client_addr = conn
+ .get_ref()
+ .0
+ .peer_addr()
+ .unwrap_or(([0, 0, 0, 0], 0).into());
+ let self_arc = self_arc.clone();
+ async move {
+ Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+ self_arc.clone().handler(req, client_addr).map_err(|e| {
+ eprintln!("RPC handler error: {}", e);
+ e
+ })
+ }))
+ }
+ });
+
+ let server = Server::builder(incoming).serve(service);
+
+ let graceful = server.with_graceful_shutdown(shutdown_signal);
+ println!("RPC server listening on http://{}", self.bind_addr);
+
+ graceful.await?;
+ } else {
+ let self_arc = self.clone();
+ let service = make_service_fn(move |conn: &AddrStream| {
+ let client_addr = conn.remote_addr();
+ let self_arc = self_arc.clone();
+ async move {
+ Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+ self_arc.clone().handler(req, client_addr).map_err(|e| {
+ eprintln!("RPC handler error: {}", e);
+ e
+ })
+ }))
+ }
+ });
- let graceful = server.with_graceful_shutdown(shutdown_signal);
- println!("RPC server listening on http://{}", bind_addr);
+ let server = Server::bind(&self.bind_addr).serve(service);
- graceful.await?;
- }
+ let graceful = server.with_graceful_shutdown(shutdown_signal);
+ println!("RPC server listening on http://{}", self.bind_addr);
- Ok(())
+ graceful.await?;
+ }
+
+ Ok(())
+ }
}
diff --git a/src/server.rs b/src/server.rs
index 591a7bf9..57faea21 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,4 +1,3 @@
-use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::path::PathBuf;
@@ -15,7 +14,7 @@ use crate::data::*;
use crate::error::Error;
use crate::membership::System;
use crate::proto::*;
-use crate::rpc_server;
+use crate::rpc_server::RpcServer;
use crate::table::*;
#[derive(Deserialize, Debug, Clone)]
@@ -53,8 +52,6 @@ pub struct Garage {
pub system: Arc<System>,
pub block_manager: Arc<BlockManager>,
- pub table_rpc_handlers: HashMap<String, Box<dyn TableRpcHandler + Sync + Send>>,
-
pub object_table: Arc<Table<ObjectTable>>,
pub version_table: Arc<Table<VersionTable>>,
pub block_ref_table: Arc<Table<BlockRefTable>>,
@@ -66,12 +63,14 @@ impl Garage {
id: UUID,
db: sled::Db,
background: Arc<BackgroundRunner>,
+ rpc_server: &mut RpcServer,
) -> Arc<Self> {
println!("Initialize membership management system...");
- let system = Arc::new(System::new(config.clone(), id, background.clone()));
+ let system = System::new(config.clone(), id, background.clone(), rpc_server);
println!("Initialize block manager...");
- let block_manager = BlockManager::new(&db, config.data_dir.clone(), system.clone());
+ let block_manager =
+ BlockManager::new(&db, config.data_dir.clone(), system.clone(), rpc_server);
let data_rep_param = TableReplicationParams {
replication_factor: system.config.data_replication_factor,
@@ -97,6 +96,7 @@ impl Garage {
&db,
"block_ref".to_string(),
data_rep_param.clone(),
+ rpc_server,
)
.await;
@@ -110,6 +110,7 @@ impl Garage {
&db,
"version".to_string(),
meta_rep_param.clone(),
+ rpc_server,
)
.await;
@@ -123,35 +124,20 @@ impl Garage {
&db,
"object".to_string(),
meta_rep_param.clone(),
+ rpc_server,
)
.await;
println!("Initialize Garage...");
- let mut garage = Self {
+ let garage = Arc::new(Self {
db,
system: system.clone(),
block_manager,
background,
- table_rpc_handlers: HashMap::new(),
object_table,
version_table,
block_ref_table,
- };
-
- garage.table_rpc_handlers.insert(
- garage.object_table.name.clone(),
- garage.object_table.clone().rpc_handler(),
- );
- garage.table_rpc_handlers.insert(
- garage.version_table.name.clone(),
- garage.version_table.clone().rpc_handler(),
- );
- garage.table_rpc_handlers.insert(
- garage.block_ref_table.name.clone(),
- garage.block_ref_table.clone().rpc_handler(),
- );
-
- let garage = Arc::new(garage);
+ });
println!("Start block manager background thread...");
garage.block_manager.garage.swap(Some(garage.clone()));
@@ -232,20 +218,23 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> {
db_path.push("db");
let db = sled::open(db_path).expect("Unable to open DB");
- let (send_cancel, watch_cancel) = watch::channel(false);
+ println!("Initialize RPC server...");
+ let rpc_bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], config.rpc_port).into();
+ let mut rpc_server = RpcServer::new(rpc_bind_addr, config.rpc_tls.clone());
println!("Initializing background runner...");
+ let (send_cancel, watch_cancel) = watch::channel(false);
let background = BackgroundRunner::new(8, watch_cancel.clone());
- let garage = Garage::new(config, id, db, background.clone()).await;
+ let garage = Garage::new(config, id, db, background.clone(), &mut rpc_server).await;
println!("Initializing RPC and API servers...");
- let rpc_server = rpc_server::run_rpc_server(garage.clone(), wait_from(watch_cancel.clone()));
+ let run_rpc_server = Arc::new(rpc_server).run(wait_from(watch_cancel.clone()));
let api_server = api_server::run_api_server(garage.clone(), wait_from(watch_cancel.clone()));
futures::try_join!(
garage.system.clone().bootstrap().map(Ok),
- rpc_server,
+ run_rpc_server,
api_server,
background.run().map(Ok),
shutdown_signal(send_cancel),
diff --git a/src/table.rs b/src/table.rs
index 3ad08cff..f7354376 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -11,14 +11,15 @@ use serde_bytes::ByteBuf;
use crate::data::*;
use crate::error::Error;
use crate::membership::System;
-use crate::proto::*;
use crate::rpc_client::*;
+use crate::rpc_server::*;
use crate::table_sync::*;
pub struct Table<F: TableSchema> {
pub instance: F,
pub name: String,
+ pub rpc_client: Arc<RpcClient<TableRPC<F>>>,
pub system: Arc<System>,
pub store: sled::Tree,
@@ -35,24 +36,6 @@ pub struct TableReplicationParams {
pub timeout: Duration,
}
-#[async_trait]
-pub trait TableRpcHandler {
- async fn handle(&self, rpc: &[u8]) -> Result<Vec<u8>, Error>;
-}
-
-struct TableRpcHandlerAdapter<F: TableSchema> {
- table: Arc<Table<F>>,
-}
-
-#[async_trait]
-impl<F: TableSchema + 'static> TableRpcHandler for TableRpcHandlerAdapter<F> {
- async fn handle(&self, rpc: &[u8]) -> Result<Vec<u8>, Error> {
- let msg = rmp_serde::decode::from_read_ref::<_, TableRPC<F>>(rpc)?;
- let rep = self.table.handle(msg).await?;
- Ok(rmp_to_vec_all_named(&rep)?)
- }
-}
-
#[derive(Serialize, Deserialize)]
pub enum TableRPC<F: TableSchema> {
Ok,
@@ -67,6 +50,8 @@ pub enum TableRPC<F: TableSchema> {
SyncRPC(SyncRPC),
}
+impl<F: TableSchema> RpcMessage for TableRPC<F> {}
+
pub trait PartitionKey {
fn hash(&self) -> Hash;
}
@@ -136,18 +121,27 @@ impl<F: TableSchema + 'static> Table<F> {
db: &sled::Db,
name: String,
param: TableReplicationParams,
+ rpc_server: &mut RpcServer,
) -> Arc<Self> {
let store = db.open_tree(&name).expect("Unable to open DB tree");
+
+ let rpc_path = format!("table_{}", name);
+ let rpc_client = system.rpc_client::<TableRPC<F>>(&rpc_path);
+
let table = Arc::new(Self {
instance,
name,
+ rpc_client,
system,
store,
param,
syncer: ArcSwapOption::from(None),
});
+ table.clone().register_handler(rpc_server, rpc_path);
+
let syncer = TableSyncer::launch(table.clone()).await;
table.syncer.swap(Some(syncer));
+
table
}
@@ -158,9 +152,10 @@ impl<F: TableSchema + 'static> Table<F> {
//eprintln!("insert who: {:?}", who);
let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?));
- let rpc = &TableRPC::<F>::Update(vec![e_enc]);
+ let rpc = TableRPC::<F>::Update(vec![e_enc]);
- self.rpc_try_call_many(&who[..], &rpc, self.param.write_quorum)
+ self.rpc_client
+ .try_call_many(&who[..], rpc, self.param.write_quorum, self.param.timeout)
.await?;
Ok(())
}
@@ -183,10 +178,8 @@ impl<F: TableSchema + 'static> Table<F> {
let call_futures = call_list.drain().map(|(node, entries)| async move {
let rpc = TableRPC::<F>::Update(entries);
- let rpc_bytes = rmp_to_vec_all_named(&rpc)?;
- let message = Message::TableRPC(self.name.to_string(), rpc_bytes);
- let resp = rpc_call(self.system.clone(), &node, &message, self.param.timeout).await?;
+ let resp = self.rpc_client.call(&node, rpc, self.param.timeout).await?;
Ok::<_, Error>((node, resp))
});
let mut resps = call_futures.collect::<FuturesUnordered<_>>();
@@ -214,9 +207,10 @@ impl<F: TableSchema + 'static> Table<F> {
let who = ring.walk_ring(&hash, self.param.replication_factor);
//eprintln!("get who: {:?}", who);
- let rpc = &TableRPC::<F>::ReadEntry(partition_key.clone(), sort_key.clone());
+ let rpc = TableRPC::<F>::ReadEntry(partition_key.clone(), sort_key.clone());
let resps = self
- .rpc_try_call_many(&who[..], &rpc, self.param.read_quorum)
+ .rpc_client
+ .try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout)
.await?;
let mut ret = None;
@@ -264,9 +258,10 @@ impl<F: TableSchema + 'static> Table<F> {
let who = ring.walk_ring(&hash, self.param.replication_factor);
let rpc =
- &TableRPC::<F>::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit);
+ TableRPC::<F>::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit);
let resps = self
- .rpc_try_call_many(&who[..], &rpc, self.param.read_quorum)
+ .rpc_client
+ .try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout)
.await?;
let mut ret = BTreeMap::new();
@@ -315,71 +310,24 @@ impl<F: TableSchema + 'static> Table<F> {
async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> {
let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?));
- self.rpc_try_call_many(&who[..], &TableRPC::<F>::Update(vec![what_enc]), who.len())
+ self.rpc_client
+ .try_call_many(
+ &who[..],
+ TableRPC::<F>::Update(vec![what_enc]),
+ who.len(),
+ self.param.timeout,
+ )
.await?;
Ok(())
}
- async fn rpc_try_call_many(
- &self,
- who: &[UUID],
- rpc: &TableRPC<F>,
- quorum: usize,
- ) -> Result<Vec<TableRPC<F>>, Error> {
- //eprintln!("Table RPC to {:?}: {}", who, serde_json::to_string(&rpc)?);
-
- let rpc_bytes = rmp_to_vec_all_named(rpc)?;
- let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes);
-
- let resps = rpc_try_call_many(
- self.system.clone(),
- who,
- rpc_msg,
- quorum,
- self.param.timeout,
- )
- .await?;
-
- let mut resps_vals = vec![];
- for resp in resps {
- if let Message::TableRPC(tbl, rep_by) = &resp {
- if *tbl == self.name {
- resps_vals.push(rmp_serde::decode::from_read_ref(&rep_by)?);
- continue;
- }
- }
- return Err(Error::Message(format!(
- "Invalid reply to TableRPC: {:?}",
- resp
- )));
- }
- //eprintln!(
- // "Table RPC responses: {}",
- // serde_json::to_string(&resps_vals)?
- //);
- Ok(resps_vals)
- }
-
- pub async fn rpc_call(&self, who: &UUID, rpc: &TableRPC<F>) -> Result<TableRPC<F>, Error> {
- let rpc_bytes = rmp_to_vec_all_named(rpc)?;
- let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes);
-
- let resp = rpc_call(self.system.clone(), who, &rpc_msg, self.param.timeout).await?;
- if let Message::TableRPC(tbl, rep_by) = &resp {
- if *tbl == self.name {
- return Ok(rmp_serde::decode::from_read_ref(&rep_by)?);
- }
- }
- Err(Error::Message(format!(
- "Invalid reply to TableRPC: {:?}",
- resp
- )))
- }
-
// =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ==============
- pub fn rpc_handler(self: Arc<Self>) -> Box<dyn TableRpcHandler + Send + Sync> {
- Box::new(TableRpcHandlerAdapter::<F> { table: self })
+ fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+ rpc_server.add_handler::<TableRPC<F>, _, _>(path, move |msg, _addr| {
+ let self2 = self.clone();
+ async move { self2.handle(msg).await }
+ })
}
async fn handle(self: &Arc<Self>, msg: TableRPC<F>) -> Result<TableRPC<F>, Error> {
diff --git a/src/table_sync.rs b/src/table_sync.rs
index 024e239f..3ba2fc6a 100644
--- a/src/table_sync.rs
+++ b/src/table_sync.rs
@@ -360,12 +360,14 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
// If their root checksum has level > than us, use that as a reference
let root_cks_resp = self
.table
- .rpc_call(
+ .rpc_client
+ .call(
&who,
&TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange(
partition.begin.clone(),
partition.end.clone(),
)),
+ self.table.param.timeout,
)
.await?;
if let TableRPC::<F>::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp {
@@ -392,9 +394,11 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
let rpc_resp = self
.table
- .rpc_call(
+ .rpc_client
+ .call(
&who,
&TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)),
+ self.table.param.timeout,
)
.await?;
if let TableRPC::<F>::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) =
@@ -451,7 +455,12 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
}
let rpc_resp = self
.table
- .rpc_call(&who, &TableRPC::<F>::Update(values))
+ .rpc_client
+ .call(
+ &who,
+ &TableRPC::<F>::Update(values),
+ self.table.param.timeout,
+ )
.await?;
if let TableRPC::<F>::Ok = rpc_resp {
Ok(())