aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2022-07-21 19:05:51 +0200
committerAlex Auvolat <alex@adnab.me>2022-07-21 19:05:51 +0200
commit44bbc1c00c2532e08dff0d4a547b0a707e89f32d (patch)
treea6c021ae50370b3c065e3485ef1dd06052a962c9
parent26989bba1409bfc093e58ef98e75885b10ad7c1c (diff)
downloadnetapp-44bbc1c00c2532e08dff0d4a547b0a707e89f32d.tar.gz
netapp-44bbc1c00c2532e08dff0d4a547b0a707e89f32d.zip
Rename AutoSerialize into SimpleMessage and refactor a bit
-rw-r--r--Cargo.toml2
-rw-r--r--src/client.rs10
-rw-r--r--src/endpoint.rs22
-rw-r--r--src/message.rs118
-rw-r--r--src/netapp.rs6
-rw-r--r--src/peering/fullmesh.rs12
-rw-r--r--src/send.rs2
-rw-r--r--src/util.rs11
8 files changed, 114 insertions, 69 deletions
diff --git a/Cargo.toml b/Cargo.toml
index d8a4908..a19e11a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,7 @@ tokio = { version = "1.0", default-features = false, features = ["net", "rt", "r
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7"
-serde = { version = "1.0", default-features = false, features = ["derive"] }
+serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
rmp-serde = "0.14.3"
hex = "0.4.2"
diff --git a/src/client.rs b/src/client.rs
index 663a3e4..cf80746 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -134,15 +134,14 @@ impl ClientConn {
self.query_send.store(None);
}
- pub(crate) async fn call<T, B>(
+ pub(crate) async fn call<T>(
self: Arc<Self>,
- rq: B,
+ rq: T,
path: &str,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message,
- B: Borrow<T>,
{
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@@ -164,7 +163,8 @@ impl ClientConn {
};
// Encode request
- let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
+ let (rq, stream) = rq.into_parts();
+ let body = rmp_to_vec_all_named(&rq)?;
drop(rq);
let request = QueryMessage {
@@ -217,7 +217,7 @@ impl ClientConn {
let code = resp[0];
if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
- Ok(T::Response::deserialize_msg(ser_resp, stream).await)
+ Ok(T::Response::from_parts(ser_resp, stream))
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
diff --git a/src/endpoint.rs b/src/endpoint.rs
index e6b2236..3f292d9 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -19,7 +19,7 @@ pub trait EndpointHandler<M>: Send + Sync
where
M: Message,
{
- async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
+ async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
}
/// If one simply wants to use an endpoint in a client fashion,
@@ -28,7 +28,7 @@ where
/// it will panic if it is ever made to handle request.
#[async_trait]
impl<M: Message + 'static> EndpointHandler<M> for () {
- async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
+ async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
panic!("This endpoint should not have a local handler.");
}
}
@@ -81,19 +81,16 @@ where
/// Call this endpoint on a remote node (or on the local node,
/// for that matter)
- pub async fn call<B>(
+ pub async fn call(
&self,
target: &NodeID,
- req: B,
+ req: M,
prio: RequestPriority,
- ) -> Result<<M as Message>::Response, Error>
- where
- B: Borrow<M> + Send + Sync,
- {
+ ) -> Result<<M as Message>::Response, Error> {
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
- Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
+ Some(h) => Ok(h.handle(req, self.netapp.id).await),
}
} else {
let conn = self
@@ -152,10 +149,11 @@ where
None => Err(Error::NoHandler),
Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?;
- let req = M::deserialize_msg(req, stream).await;
- let res = h.handle(&req, from).await;
+ let req = M::from_parts(req, stream);
+ let res = h.handle(req, from).await;
+ let (res, res_stream) = res.into_parts();
let res_bytes = rmp_to_vec_all_named(&res)?;
- Ok(res_bytes)
+ Ok((res_bytes, res_stream))
}
}
}
diff --git a/src/message.rs b/src/message.rs
index 6d50254..f92eb8c 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -1,7 +1,9 @@
use async_trait::async_trait;
-use futures::stream::{Stream, StreamExt};
+use bytes::Bytes;
use serde::{Deserialize, Serialize};
+use futures::stream::{Stream, StreamExt};
+
use crate::error::*;
use crate::util::*;
@@ -41,66 +43,112 @@ pub trait Message: SerializeMessage + Send + Sync {
}
/// A trait for de/serializing messages, with possible associated stream.
-#[async_trait]
pub trait SerializeMessage: Sized {
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>);
+ fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
+
+ fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
+}
+
+// ----
+
+impl<T, E> SerializeMessage for Result<T, E>
+where
+ T: SerializeMessage + Send,
+ E: Serialize + for<'de> Deserialize<'de> + Send,
+{
+ type SerializableSelf = Result<T::SerializableSelf, E>;
+
+ fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ match self {
+ Ok(ok) => {
+ let (msg, stream) = ok.into_parts();
+ (Ok(msg), stream)
+ }
+ Err(err) => (Err(err), None),
+ }
+ }
- // TODO should return Result
- async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
+ fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
+ match ser_self {
+ Ok(ok) => Ok(T::from_parts(ok, stream)),
+ Err(err) => Err(err),
+ }
+ }
}
-pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
+// ---
+
+pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
-#[async_trait]
impl<T> SerializeMessage for T
where
- T: AutoSerialize,
+ T: SimpleMessage,
{
type SerializableSelf = Self;
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
- (self.clone(), None)
+ fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ (self, None)
}
- async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
+ fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
// TODO verify no stream
ser_self
}
}
-impl AutoSerialize for () {}
+impl SimpleMessage for () {}
-#[async_trait]
-impl<T, E> SerializeMessage for Result<T, E>
+impl<T: SimpleMessage> SimpleMessage for std::sync::Arc<T> {}
+
+// ----
+
+#[derive(Clone)]
+pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
+ pub T,
+ pub Bytes,
+);
+
+impl<T> SerializeMessage for WithFixedBody<T>
where
- T: SerializeMessage + Send,
- E: SerializeMessage + Send,
+ T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
- type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
+ type SerializableSelf = T;
+
+ fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ let body = self.1;
+ (
+ self.0,
+ Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
+ )
+ }
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
- match self {
- Ok(ok) => {
- let (msg, stream) = ok.serialize_msg();
- (Ok(msg), stream)
- }
- Err(err) => {
- let (msg, stream) = err.serialize_msg();
- (Err(msg), stream)
- }
- }
+ fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
+ panic!("Cannot reconstruct a WithFixedBody type from parts");
}
+}
- async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
- match ser_self {
- Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
- Err(err) => Err(E::deserialize_msg(err, stream).await),
- }
+pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
+ pub T,
+ pub ByteStream,
+);
+
+impl<T> SerializeMessage for WithStreamingBody<T>
+where
+ T: Serialize + for<'de> Deserialize<'de> + Send,
+{
+ type SerializableSelf = T;
+
+ fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ (self.0, Some(self.1))
+ }
+
+ fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
+ WithStreamingBody(ser_self, stream)
}
}
-// ----
+// ---- ----
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
@@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> {
}
}
+// ---- ----
+
pub(crate) struct Framing {
direct: Vec<u8>,
stream: Option<ByteStream>,
diff --git a/src/netapp.rs b/src/netapp.rs
index dd22d90..8365de0 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -38,7 +38,7 @@ pub(crate) struct HelloMessage {
pub server_port: u16,
}
-impl AutoSerialize for HelloMessage {}
+impl SimpleMessage for HelloMessage {}
impl Message for HelloMessage {
type Response = ();
@@ -399,7 +399,7 @@ impl NetApp {
hello_endpoint
.call(
&conn.peer_id,
- &HelloMessage {
+ HelloMessage {
server_addr,
server_port,
},
@@ -434,7 +434,7 @@ impl NetApp {
#[async_trait]
impl EndpointHandler<HelloMessage> for NetApp {
- async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) {
+ async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&from) {
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 5b489ae..3eeebb3 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -40,7 +40,7 @@ impl Message for PingMessage {
type Response = PingMessage;
}
-impl AutoSerialize for PingMessage {}
+impl SimpleMessage for PingMessage {}
#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage {
@@ -51,7 +51,7 @@ impl Message for PeerListMessage {
type Response = PeerListMessage;
}
-impl AutoSerialize for PeerListMessage {}
+impl SimpleMessage for PeerListMessage {}
// -- Algorithm data structures --
@@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy {
ping_time
);
let ping_response = select! {
- r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r,
+ r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
_ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
};
@@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy {
let pex_message = PeerListMessage { list: peer_list };
match self
.peer_list_endpoint
- .call(id, &pex_message, PRIO_BACKGROUND)
+ .call(id, pex_message, PRIO_BACKGROUND)
.await
{
Err(e) => warn!("Error doing peer exchange: {}", e),
@@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy {
#[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
- async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage {
+ async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
let ping_resp = PingMessage {
id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash,
@@ -601,7 +601,7 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle(
self: &Arc<Self>,
- peer_list: &PeerListMessage,
+ peer_list: PeerListMessage,
_from: NodeID,
) -> PeerListMessage {
self.handle_peer_list(&peer_list.list[..]);
diff --git a/src/send.rs b/src/send.rs
index 660e85c..cc28d7c 100644
--- a/src/send.rs
+++ b/src/send.rs
@@ -3,8 +3,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use bytes::Bytes;
use async_trait::async_trait;
+use bytes::Bytes;
use log::trace;
use futures::AsyncWriteExt;
diff --git a/src/util.rs b/src/util.rs
index e81a89c..f860672 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -2,9 +2,9 @@ use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
+use bytes::Bytes;
use log::info;
use serde::Serialize;
-use bytes::Bytes;
use futures::Stream;
use tokio::sync::watch;
@@ -35,19 +35,16 @@ pub type Packet = Result<Bytes, u8>;
///
/// Field names and variant names are included in the serialization.
/// This is used internally by the netapp communication protocol.
-pub fn rmp_to_vec_all_named<T>(
- val: &T,
-) -> Result<(Vec<u8>, Option<ByteStream>), rmp_serde::encode::Error>
+pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
where
- T: SerializeMessage + ?Sized,
+ T: Serialize + ?Sized,
{
let mut wr = Vec::with_capacity(128);
let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
- let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?;
- Ok((wr, stream))
+ Ok(wr)
}
/// This async function returns only when a true signal was received