aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/basalt.rs10
-rw-r--r--src/client.rs15
-rw-r--r--src/endpoint.rs45
-rw-r--r--src/message.rs205
-rw-r--r--src/netapp.rs5
-rw-r--r--src/peering/basalt.rs17
-rw-r--r--src/peering/fullmesh.rs12
7 files changed, 216 insertions, 93 deletions
diff --git a/examples/basalt.rs b/examples/basalt.rs
index dd56cd7..3841786 100644
--- a/examples/basalt.rs
+++ b/examples/basalt.rs
@@ -159,11 +159,15 @@ impl Example {
#[async_trait]
impl EndpointHandler<ExampleMessage> for Example {
- async fn handle(self: &Arc<Self>, msg: ExampleMessage, _from: NodeID) -> ExampleResponse {
+ async fn handle(
+ self: &Arc<Self>,
+ msg: Req<ExampleMessage>,
+ _from: NodeID,
+ ) -> Resp<ExampleMessage> {
debug!("Got example message: {:?}, sending example response", msg);
- ExampleResponse {
+ Resp::new(ExampleResponse {
example_field: false,
- }
+ })
}
}
diff --git a/src/client.rs b/src/client.rs
index 9d572aa..c878627 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -135,10 +135,10 @@ impl ClientConn {
pub(crate) async fn call<T>(
self: Arc<Self>,
- rq: T,
+ req: Req<T>,
path: &str,
prio: RequestPriority,
- ) -> Result<<T as Message>::Response, Error>
+ ) -> Result<Resp<T>, Error>
where
T: Message,
{
@@ -162,9 +162,8 @@ impl ClientConn {
};
// Encode request
- let (rq, stream) = rq.into_parts();
- let body = rmp_to_vec_all_named(&rq)?;
- drop(rq);
+ let body = req.msg_ser.unwrap().clone();
+ let stream = req.body.into_stream();
let request = QueryMessage {
prio,
@@ -216,7 +215,11 @@ impl ClientConn {
let code = resp[0];
if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
- Ok(T::Response::from_parts(ser_resp, stream))
+ Ok(Resp {
+ _phantom: Default::default(),
+ msg: ser_resp,
+ body: BodyData::Stream(stream),
+ })
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 97a7644..8ee64a5 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -18,7 +18,7 @@ pub trait EndpointHandler<M>: Send + Sync
where
M: Message,
{
- async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
+ async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
}
/// If one simply wants to use an endpoint in a client fashion,
@@ -27,7 +27,7 @@ where
/// it will panic if it is ever made to handle request.
#[async_trait]
impl<M: Message + 'static> EndpointHandler<M> for () {
- async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
+ async fn handle(self: &Arc<()>, _m: Req<M>, _from: NodeID) -> Resp<M> {
panic!("This endpoint should not have a local handler.");
}
}
@@ -80,16 +80,19 @@ where
/// Call this endpoint on a remote node (or on the local node,
/// for that matter)
- pub async fn call(
+ pub async fn call_full<T>(
&self,
target: &NodeID,
- req: M,
+ req: T,
prio: RequestPriority,
- ) -> Result<<M as Message>::Response, Error> {
+ ) -> Result<Resp<M>, Error>
+ where
+ T: IntoReq<M>,
+ {
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
- Some(h) => Ok(h.handle(req, self.netapp.id).await),
+ Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await),
}
} else {
let conn = self
@@ -104,10 +107,21 @@ where
"Not connected: {}",
hex::encode(&target[..8])
))),
- Some(c) => c.call(req, self.path.as_str(), prio).await,
+ Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await,
}
}
}
+
+ /// Call this endpoint on a remote node, without the possibility
+ /// of adding or receiving a body
+ pub async fn call(
+ &self,
+ target: &NodeID,
+ req: M,
+ prio: RequestPriority,
+ ) -> Result<<M as Message>::Response, Error> {
+ Ok(self.call_full(target, req, prio).await?.into_msg())
+ }
}
// ---- Internal stuff ----
@@ -148,11 +162,20 @@ where
None => Err(Error::NoHandler),
Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?;
- let req = M::from_parts(req, stream);
+ let req = Req {
+ _phantom: Default::default(),
+ msg: Arc::new(req),
+ msg_ser: None,
+ body: BodyData::Stream(stream),
+ };
let res = h.handle(req, from).await;
- let (res, res_stream) = res.into_parts();
- let res_bytes = rmp_to_vec_all_named(&res)?;
- Ok((res_bytes, res_stream))
+ let Resp {
+ msg,
+ body,
+ _phantom,
+ } = res;
+ let res_bytes = rmp_to_vec_all_named(&msg)?;
+ Ok((res_bytes, body.into_stream()))
}
}
}
diff --git a/src/message.rs b/src/message.rs
index 22cae6a..d918c29 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -1,3 +1,7 @@
+use std::fmt;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
use bytes::Bytes;
use serde::{Deserialize, Serialize};
@@ -37,94 +41,169 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
/// This trait should be implemented by all messages your application
/// wants to handle
-pub trait Message: SerializeMessage + Send + Sync {
- type Response: SerializeMessage + Send + Sync;
+pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
+ type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
+}
+
+pub struct Req<M: Message> {
+ pub(crate) _phantom: PhantomData<M>,
+ pub(crate) msg: Arc<M>,
+ pub(crate) msg_ser: Option<Bytes>,
+ pub(crate) body: BodyData,
}
-/// A trait for de/serializing messages, with possible associated stream.
-/// This is default-implemented by anything that can already be serialized
-/// and deserialized. Adapters are provided that implement this for
-/// adding a body, either from a fixed Bytes buffer (which allows the thing
-/// to be Clone), or from a streaming byte stream.
-pub trait SerializeMessage: Sized {
- type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
+pub struct Resp<M: Message> {
+ pub(crate) _phantom: PhantomData<M>,
+ pub(crate) msg: M::Response,
+ pub(crate) body: BodyData,
+}
- fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
+pub(crate) enum BodyData {
+ None,
+ Fixed(Bytes),
+ Stream(ByteStream),
+}
- fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
+impl BodyData {
+ pub fn into_stream(self) -> Option<ByteStream> {
+ match self {
+ BodyData::None => None,
+ BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
+ BodyData::Stream(s) => Some(s),
+ }
+ }
}
// ----
-impl<T> SerializeMessage for T
-where
- T: Serialize + for<'de> Deserialize<'de> + Send,
-{
- type SerializableSelf = Self;
- fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
- (self, None)
+impl<M: Message> Req<M> {
+ pub fn msg(&self) -> &M {
+ &self.msg
}
- fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
- // TODO verify no stream
- ser_self
+ pub fn with_fixed_body(self, b: Bytes) -> Self {
+ Self {
+ body: BodyData::Fixed(b),
+ ..self
+ }
+ }
+
+ pub fn with_streaming_body(self, b: ByteStream) -> Self {
+ Self {
+ body: BodyData::Stream(b),
+ ..self
+ }
}
}
-// ----
+pub trait IntoReq<M: Message> {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
+ fn into_req_local(self) -> Req<M>;
+}
-/// An adapter that adds a body from a fixed Bytes buffer to a serializable message,
-/// implementing the SerializeMessage trait. This allows for the SerializeMessage object
-/// to be cloned, which is usefull for requests that must be sent to multiple servers.
-/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable
-/// part is also easily clonable (e.g. by wrapping it in an Arc).
-/// Note that this CANNOT be used for a response type, as it cannot be reconstructed
-/// from a remote stream.
-#[derive(Clone)]
-pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
- pub T,
- pub Bytes,
-);
-
-impl<T> SerializeMessage for WithFixedBody<T>
-where
- T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
-{
- type SerializableSelf = T;
-
- fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
- let body = self.1;
- (
- self.0,
- Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
- )
+impl<M: Message> IntoReq<M> for M {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ let msg_ser = rmp_to_vec_all_named(&self)?;
+ Ok(Req {
+ _phantom: Default::default(),
+ msg: Arc::new(self),
+ msg_ser: Some(Bytes::from(msg_ser)),
+ body: BodyData::None,
+ })
}
+ fn into_req_local(self) -> Req<M> {
+ Req {
+ _phantom: Default::default(),
+ msg: Arc::new(self),
+ msg_ser: None,
+ body: BodyData::None,
+ }
+ }
+}
- fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
- panic!("Cannot use a WithFixedBody as a response type");
+impl<M: Message> IntoReq<M> for Req<M> {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ Ok(self)
+ }
+ fn into_req_local(self) -> Req<M> {
+ self
}
}
-/// An adapter that adds a body from a ByteStream. This is usefull for receiving
-/// responses to requests that contain attached byte streams. This type is
-/// not clonable.
-pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
- pub T,
- pub ByteStream,
-);
+impl<M: Message> Clone for Req<M> {
+ fn clone(&self) -> Self {
+ let body = match &self.body {
+ BodyData::None => BodyData::None,
+ BodyData::Fixed(b) => BodyData::Fixed(b.clone()),
+ BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"),
+ };
+ Self {
+ _phantom: Default::default(),
+ msg: self.msg.clone(),
+ msg_ser: self.msg_ser.clone(),
+ body,
+ }
+ }
+}
-impl<T> SerializeMessage for WithStreamingBody<T>
+impl<M> fmt::Debug for Req<M>
where
- T: Serialize + for<'de> Deserialize<'de> + Send,
+ M: Message + fmt::Debug,
{
- type SerializableSelf = T;
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Req[{:?}", self.msg)?;
+ match &self.body {
+ BodyData::None => write!(f, "]"),
+ BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
+ BodyData::Stream(_) => write!(f, "; body=stream]"),
+ }
+ }
+}
+
+impl<M> fmt::Debug for Resp<M>
+where
+ M: Message,
+ <M as Message>::Response: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Resp[{:?}", self.msg)?;
+ match &self.body {
+ BodyData::None => write!(f, "]"),
+ BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
+ BodyData::Stream(_) => write!(f, "; body=stream]"),
+ }
+ }
+}
+
+impl<M: Message> Resp<M> {
+ pub fn new(v: M::Response) -> Self {
+ Resp {
+ _phantom: Default::default(),
+ msg: v,
+ body: BodyData::None,
+ }
+ }
+
+ pub fn with_fixed_body(self, b: Bytes) -> Self {
+ Self {
+ body: BodyData::Fixed(b),
+ ..self
+ }
+ }
+
+ pub fn with_streaming_body(self, b: ByteStream) -> Self {
+ Self {
+ body: BodyData::Stream(b),
+ ..self
+ }
+ }
- fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
- (self.0, Some(self.1))
+ pub fn msg(&self) -> &M::Response {
+ &self.msg
}
- fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
- WithStreamingBody(ser_self, stream)
+ pub fn into_msg(self) -> M::Response {
+ self.msg
}
}
diff --git a/src/netapp.rs b/src/netapp.rs
index 32a5c23..0cebac0 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -404,6 +404,7 @@ impl NetApp {
PRIO_NORMAL,
)
.await
+ .map(|_| ())
.log_err("Sending hello message");
});
}
@@ -432,7 +433,8 @@ impl NetApp {
#[async_trait]
impl EndpointHandler<HelloMessage> for NetApp {
- async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
+ async fn handle(self: &Arc<Self>, msg: Req<HelloMessage>, from: NodeID) -> Resp<HelloMessage> {
+ let msg = msg.msg();
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&from) {
@@ -441,5 +443,6 @@ impl EndpointHandler<HelloMessage> for NetApp {
h(from, remote_addr, true);
}
}
+ Resp::new(())
}
}
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index d7bc6a8..71dea84 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -468,15 +468,24 @@ impl Basalt {
#[async_trait]
impl EndpointHandler<PullMessage> for Basalt {
- async fn handle(self: &Arc<Self>, _pullmsg: PullMessage, _from: NodeID) -> PushMessage {
- self.make_push_message()
+ async fn handle(
+ self: &Arc<Self>,
+ _pullmsg: Req<PullMessage>,
+ _from: NodeID,
+ ) -> Resp<PullMessage> {
+ Resp::new(self.make_push_message())
}
}
#[async_trait]
impl EndpointHandler<PushMessage> for Basalt {
- async fn handle(self: &Arc<Self>, pushmsg: PushMessage, _from: NodeID) {
- self.handle_peer_list(&pushmsg.peers[..]);
+ async fn handle(
+ self: &Arc<Self>,
+ pushmsg: Req<PushMessage>,
+ _from: NodeID,
+ ) -> Resp<PushMessage> {
+ self.handle_peer_list(&pushmsg.msg().peers[..]);
+ Resp::new(())
}
}
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index f8348af..9b7b666 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -583,13 +583,14 @@ impl FullMeshPeeringStrategy {
#[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
- async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
+ async fn handle(self: &Arc<Self>, ping: Req<PingMessage>, from: NodeID) -> Resp<PingMessage> {
+ let ping = ping.msg();
let ping_resp = PingMessage {
id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash,
};
debug!("Ping from {}", hex::encode(&from[..8]));
- ping_resp
+ Resp::new(ping_resp)
}
}
@@ -597,11 +598,12 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle(
self: &Arc<Self>,
- peer_list: PeerListMessage,
+ peer_list: Req<PeerListMessage>,
_from: NodeID,
- ) -> PeerListMessage {
+ ) -> Resp<PeerListMessage> {
+ let peer_list = peer_list.msg();
self.handle_peer_list(&peer_list.list[..]);
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
- PeerListMessage { list: peer_list }
+ Resp::new(PeerListMessage { list: peer_list })
}
}