aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/client.rs5
-rw-r--r--src/endpoint.rs80
-rw-r--r--src/netapp.rs4
-rw-r--r--src/peering/fullmesh.rs8
-rw-r--r--src/proto.rs5
-rw-r--r--src/util.rs5
6 files changed, 67 insertions, 40 deletions
diff --git a/src/client.rs b/src/client.rs
index bce7aca..bc16fb1 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -227,9 +227,8 @@ impl ClientConn {
let code = resp[0];
if code == 0 {
- let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]);
- let res = T::Response::deserialize_msg(&mut deser, stream).await?;
- Ok(res)
+ let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
+ Ok(T::Response::deserialize_msg(ser_resp, stream).await)
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 81ed036..c25365a 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -5,8 +5,7 @@ use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
-use serde::de::Error as DeError;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
use crate::error::Error;
use crate::netapp::*;
@@ -22,42 +21,61 @@ pub trait Message: SerializeMessage + Send + Sync {
/// A trait for de/serializing messages, with possible associated stream.
#[async_trait]
pub trait SerializeMessage: Sized {
- fn serialize_msg<S: Serializer>(
- &self,
- serializer: S,
- ) -> Result<(S::Ok, Option<AssociatedStream>), S::Error>;
+ type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
- async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
- deserializer: D,
- stream: AssociatedStream,
- ) -> Result<Self, D::Error>;
+ // TODO should return Result
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
+
+ // TODO should return Result
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
}
+pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
+
#[async_trait]
impl<T> SerializeMessage for T
where
- T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
+ T: AutoSerialize,
{
- fn serialize_msg<S: Serializer>(
- &self,
- serializer: S,
- ) -> Result<(S::Ok, Option<AssociatedStream>), S::Error> {
- self.serialize(serializer).map(|r| (r, None))
+ type SerializableSelf = Self;
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
+ (self.clone(), None)
+ }
+
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
+ // TODO verify no stream
+ ser_self
+ }
+}
+
+impl AutoSerialize for () {}
+
+#[async_trait]
+impl<T, E> SerializeMessage for Result<T, E>
+where
+ T: SerializeMessage + Send,
+ E: SerializeMessage + Send,
+{
+ type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
+
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
+ match self {
+ Ok(ok) => {
+ let (msg, stream) = ok.serialize_msg();
+ (Ok(msg), stream)
+ }
+ Err(err) => {
+ let (msg, stream) = err.serialize_msg();
+ (Err(msg), stream)
+ }
+ }
}
- async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
- deserializer: D,
- mut stream: AssociatedStream,
- ) -> Result<Self, D::Error> {
- use futures::StreamExt;
-
- let res = Self::deserialize(deserializer)?;
- if stream.next().await.is_some() {
- return Err(D::Error::custom(
- "failed to deserialize: found associated stream when none expected",
- ));
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
+ match ser_self {
+ Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
+ Err(err) => Err(E::deserialize_msg(err, stream).await),
}
- Ok(res)
}
}
@@ -139,7 +157,7 @@ where
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error>
where
- B: Borrow<M>,
+ B: Borrow<M> + Send + Sync,
{
if *target == self.netapp.id {
match self.handler.load_full() {
@@ -202,8 +220,8 @@ where
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
- let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf);
- let req = M::deserialize_msg(&mut deser, stream).await?;
+ let req = rmp_serde::decode::from_read_ref(buf)?;
+ let req = M::deserialize_msg(req, stream).await;
let res = h.handle(&req, from).await;
let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes)
diff --git a/src/netapp.rs b/src/netapp.rs
index e9efa2e..27f17e6 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16];
/// Value of the Netapp version used in the version tag
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>,
pub server_port: u16,
}
+impl AutoSerialize for HelloMessage {}
+
impl Message for HelloMessage {
type Response = ();
}
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 012c5a0..7dfc5c4 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
// -- Protocol messages --
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
struct PingMessage {
pub id: u64,
pub peer_list_hash: hash::Digest,
@@ -39,7 +39,9 @@ impl Message for PingMessage {
type Response = PingMessage;
}
-#[derive(Serialize, Deserialize)]
+impl AutoSerialize for PingMessage {}
+
+#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage {
pub list: Vec<(NodeID, SocketAddr)>,
}
@@ -48,6 +50,8 @@ impl Message for PeerListMessage {
type Response = PeerListMessage;
}
+impl AutoSerialize for PeerListMessage {}
+
// -- Algorithm data structures --
#[derive(Debug)]
diff --git a/src/proto.rs b/src/proto.rs
index ca1a3d2..073a317 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -151,9 +151,10 @@ impl Stream for DataReader {
}
let mut body = [0; MAX_CHUNK_LENGTH as usize];
- body[..buf.len()].copy_from_slice(&buf);
+ let len = buf.len();
+ body[..len].copy_from_slice(buf);
buf.clear();
- Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize)))
+ Poll::Ready(Some((body, len)))
}
}
}
diff --git a/src/util.rs b/src/util.rs
index 4333080..02b4e7d 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -8,6 +8,8 @@ use futures::Stream;
use log::info;
+use serde::Serialize;
+
use tokio::sync::watch;
/// A node's identifier, which is also its public cryptographic key
@@ -34,7 +36,8 @@ where
let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
- let (_, stream) = val.serialize_msg(&mut se)?;
+ let (val, stream) = val.serialize_msg();
+ val.serialize(&mut se)?;
Ok((wr, stream))
}