diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 15 | ||||
-rw-r--r-- | src/endpoint.rs | 45 | ||||
-rw-r--r-- | src/message.rs | 205 | ||||
-rw-r--r-- | src/netapp.rs | 5 | ||||
-rw-r--r-- | src/peering/basalt.rs | 17 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 12 |
6 files changed, 209 insertions, 90 deletions
diff --git a/src/client.rs b/src/client.rs index 9d572aa..c878627 100644 --- a/src/client.rs +++ b/src/client.rs @@ -135,10 +135,10 @@ impl ClientConn { pub(crate) async fn call<T>( self: Arc<Self>, - rq: T, + req: Req<T>, path: &str, prio: RequestPriority, - ) -> Result<<T as Message>::Response, Error> + ) -> Result<Resp<T>, Error> where T: Message, { @@ -162,9 +162,8 @@ impl ClientConn { }; // Encode request - let (rq, stream) = rq.into_parts(); - let body = rmp_to_vec_all_named(&rq)?; - drop(rq); + let body = req.msg_ser.unwrap().clone(); + let stream = req.body.into_stream(); let request = QueryMessage { prio, @@ -216,7 +215,11 @@ impl ClientConn { let code = resp[0]; if code == 0 { let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(T::Response::from_parts(ser_resp, stream)) + Ok(Resp { + _phantom: Default::default(), + msg: ser_resp, + body: BodyData::Stream(stream), + }) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 97a7644..8ee64a5 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -18,7 +18,7 @@ pub trait EndpointHandler<M>: Send + Sync where M: Message, { - async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response; + async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>; } /// If one simply wants to use an endpoint in a client fashion, @@ -27,7 +27,7 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl<M: Message + 'static> EndpointHandler<M> for () { - async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { + async fn handle(self: &Arc<()>, _m: Req<M>, _from: NodeID) -> Resp<M> { panic!("This endpoint should not have a local handler."); } } @@ -80,16 +80,19 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call_full<T>( &self, target: &NodeID, - req: M, + req: T, prio: RequestPriority, - ) -> Result<<M as Message>::Response, Error> { + ) -> Result<Resp<M>, Error> + where + T: IntoReq<M>, + { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req, self.netapp.id).await), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), } } else { let conn = self @@ -104,10 +107,21 @@ where "Not connected: {}", hex::encode(&target[..8]) ))), - Some(c) => c.call(req, self.path.as_str(), prio).await, + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, } } } + + /// Call this endpoint on a remote node, without the possibility + /// of adding or receiving a body + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<<M as Message>::Response, Error> { + Ok(self.call_full(target, req, prio).await?.into_msg()) + } } // ---- Internal stuff ---- @@ -148,11 +162,20 @@ where None => Err(Error::NoHandler), Some(h) => { let req = rmp_serde::decode::from_read_ref(buf)?; - let req = M::from_parts(req, stream); + let req = Req { + _phantom: Default::default(), + msg: Arc::new(req), + msg_ser: None, + body: BodyData::Stream(stream), + }; let res = h.handle(req, from).await; - let (res, res_stream) = res.into_parts(); - let res_bytes = rmp_to_vec_all_named(&res)?; - Ok((res_bytes, res_stream)) + let Resp { + msg, + body, + _phantom, + } = res; + let res_bytes = rmp_to_vec_all_named(&msg)?; + Ok((res_bytes, body.into_stream())) } } } diff --git a/src/message.rs b/src/message.rs index 22cae6a..d918c29 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,7 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -37,94 +41,169 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +} + +pub struct Req<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: Arc<M>, + pub(crate) msg_ser: Option<Bytes>, + pub(crate) body: BodyData, } -/// A trait for de/serializing messages, with possible associated stream. -/// This is default-implemented by anything that can already be serialized -/// and deserialized. Adapters are provided that implement this for -/// adding a body, either from a fixed Bytes buffer (which allows the thing -/// to be Clone), or from a streaming byte stream. -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; +pub struct Resp<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, +} - fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>); +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +impl BodyData { + pub fn into_stream(self) -> Option<ByteStream> { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } + } } // ---- -impl<T> SerializeMessage for T -where - T: Serialize + for<'de> Deserialize<'de> + Send, -{ - type SerializableSelf = Self; - fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) { - (self, None) +impl<M: Message> Req<M> { + pub fn msg(&self) -> &M { + &self.msg } - fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - // TODO verify no stream - ser_self + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } } } -// ---- +pub trait IntoReq<M: Message> { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>; + fn into_req_local(self) -> Req<M>; +} -/// An adapter that adds a body from a fixed Bytes buffer to a serializable message, -/// implementing the SerializeMessage trait. This allows for the SerializeMessage object -/// to be cloned, which is usefull for requests that must be sent to multiple servers. -/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable -/// part is also easily clonable (e.g. by wrapping it in an Arc). -/// Note that this CANNOT be used for a response type, as it cannot be reconstructed -/// from a remote stream. -#[derive(Clone)] -pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>( - pub T, - pub Bytes, -); - -impl<T> SerializeMessage for WithFixedBody<T> -where - T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static, -{ - type SerializableSelf = T; - - fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) { - let body = self.1; - ( - self.0, - Some(Box::pin(futures::stream::once(async move { Ok(body) }))), - ) +impl<M: Message> IntoReq<M> for M { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + body: BodyData::None, + }) } + fn into_req_local(self) -> Req<M> { + Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: None, + body: BodyData::None, + } + } +} - fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - panic!("Cannot use a WithFixedBody as a response type"); +impl<M: Message> IntoReq<M> for Req<M> { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req<M> { + self } } -/// An adapter that adds a body from a ByteStream. This is usefull for receiving -/// responses to requests that contain attached byte streams. This type is -/// not clonable. -pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>( - pub T, - pub ByteStream, -); +impl<M: Message> Clone for Req<M> { + fn clone(&self) -> Self { + let body = match &self.body { + BodyData::None => BodyData::None, + BodyData::Fixed(b) => BodyData::Fixed(b.clone()), + BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"), + }; + Self { + _phantom: Default::default(), + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + body, + } + } +} -impl<T> SerializeMessage for WithStreamingBody<T> +impl<M> fmt::Debug for Req<M> where - T: Serialize + for<'de> Deserialize<'de> + Send, + M: Message + fmt::Debug, { - type SerializableSelf = T; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} + +impl<M> fmt::Debug for Resp<M> +where + M: Message, + <M as Message>::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} + +impl<M: Message> Resp<M> { + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + body: BodyData::None, + } + } + + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } + } - fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) { - (self.0, Some(self.1)) + pub fn msg(&self) -> &M::Response { + &self.msg } - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - WithStreamingBody(ser_self, stream) + pub fn into_msg(self) -> M::Response { + self.msg } } diff --git a/src/netapp.rs b/src/netapp.rs index 32a5c23..0cebac0 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -404,6 +404,7 @@ impl NetApp { PRIO_NORMAL, ) .await + .map(|_| ()) .log_err("Sending hello message"); }); } @@ -432,7 +433,8 @@ impl NetApp { #[async_trait] impl EndpointHandler<HelloMessage> for NetApp { - async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) { + async fn handle(self: &Arc<Self>, msg: Req<HelloMessage>, from: NodeID) -> Resp<HelloMessage> { + let msg = msg.msg(); debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -441,5 +443,6 @@ impl EndpointHandler<HelloMessage> for NetApp { h(from, remote_addr, true); } } + Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index d7bc6a8..71dea84 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,15 +468,24 @@ impl Basalt { #[async_trait] impl EndpointHandler<PullMessage> for Basalt { - async fn handle(self: &Arc<Self>, _pullmsg: PullMessage, _from: NodeID) -> PushMessage { - self.make_push_message() + async fn handle( + self: &Arc<Self>, + _pullmsg: Req<PullMessage>, + _from: NodeID, + ) -> Resp<PullMessage> { + Resp::new(self.make_push_message()) } } #[async_trait] impl EndpointHandler<PushMessage> for Basalt { - async fn handle(self: &Arc<Self>, pushmsg: PushMessage, _from: NodeID) { - self.handle_peer_list(&pushmsg.peers[..]); + async fn handle( + self: &Arc<Self>, + pushmsg: Req<PushMessage>, + _from: NodeID, + ) -> Resp<PushMessage> { + self.handle_peer_list(&pushmsg.msg().peers[..]); + Resp::new(()) } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index f8348af..9b7b666 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,13 +583,14 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { - async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage { + async fn handle(self: &Arc<Self>, ping: Req<PingMessage>, from: NodeID) -> Resp<PingMessage> { + let ping = ping.msg(); let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - ping_resp + Resp::new(ping_resp) } } @@ -597,11 +598,12 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy { async fn handle( self: &Arc<Self>, - peer_list: PeerListMessage, + peer_list: Req<PeerListMessage>, _from: NodeID, - ) -> PeerListMessage { + ) -> Resp<PeerListMessage> { + let peer_list = peer_list.msg(); self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - PeerListMessage { list: peer_list } + Resp::new(PeerListMessage { list: peer_list }) } } |