diff options
-rw-r--r-- | examples/basalt.rs | 10 | ||||
-rw-r--r-- | src/endpoint.rs | 58 | ||||
-rw-r--r-- | src/message.rs | 2 | ||||
-rw-r--r-- | src/netapp.rs | 6 | ||||
-rw-r--r-- | src/peering/basalt.rs | 17 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 12 | ||||
-rw-r--r-- | src/util.rs | 2 |
7 files changed, 59 insertions, 48 deletions
diff --git a/examples/basalt.rs b/examples/basalt.rs index 3841786..a5a25c3 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -159,15 +159,11 @@ impl Example { #[async_trait] impl EndpointHandler<ExampleMessage> for Example { - async fn handle( - self: &Arc<Self>, - msg: Req<ExampleMessage>, - _from: NodeID, - ) -> Resp<ExampleMessage> { + async fn handle(self: &Arc<Self>, msg: &ExampleMessage, _from: NodeID) -> ExampleResponse { debug!("Got example message: {:?}, sending example response", msg); - Resp::new(ExampleResponse { + ExampleResponse { example_field: false, - }) + } } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 8ee64a5..ff626d8 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -10,11 +10,13 @@ use crate::netapp::*; use crate::util::*; /// This trait should be implemented by an object of your application -/// that can handle a message of type `M`. +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. /// /// The handler object should be in an Arc, see `Endpoint::set_handler` #[async_trait] -pub trait EndpointHandler<M>: Send + Sync +pub trait StreamingEndpointHandler<M>: Send + Sync where M: Message, { @@ -27,11 +29,34 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl<M: Message + 'static> EndpointHandler<M> for () { - async fn handle(self: &Arc<()>, _m: Req<M>, _from: NodeID) -> Resp<M> { + async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } +// ---- + +#[async_trait] +pub trait EndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> <M as Message>::Response; +} + +#[async_trait] +impl<T, M> StreamingEndpointHandler<M> for T +where + T: EndpointHandler<M>, + M: Message + 'static, +{ + async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M> { + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) + } +} + +// ---- + /// This struct represents an endpoint for message of type `M`. /// /// Creating a new endpoint is done by calling `NetApp::endpoint`. @@ -41,13 +66,13 @@ impl<M: Message + 'static> EndpointHandler<M> for () { /// An `Endpoint` is used both to send requests to remote nodes, /// and to specify the handler for such requests on the local node. /// The type `H` represents the type of the handler object for -/// endpoint messages (see `EndpointHandler`). +/// endpoint messages (see `StreamingEndpointHandler`). pub struct Endpoint<M, H> where M: Message, - H: EndpointHandler<M>, + H: StreamingEndpointHandler<M>, { - phantom: PhantomData<M>, + _phantom: PhantomData<M>, netapp: Arc<NetApp>, path: String, handler: ArcSwapOption<H>, @@ -56,11 +81,11 @@ where impl<M, H> Endpoint<M, H> where M: Message, - H: EndpointHandler<M>, + H: StreamingEndpointHandler<M>, { pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { Self { - phantom: PhantomData::default(), + _phantom: PhantomData::default(), netapp, path, handler: ArcSwapOption::from(None), @@ -79,8 +104,10 @@ where } /// Call this endpoint on a remote node (or on the local node, - /// for that matter) - pub async fn call_full<T>( + /// for that matter). This function invokes the full version that + /// allows to attach a streaming body to the request and to + /// receive such a body attached to the response. + pub async fn call_streaming<T>( &self, target: &NodeID, req: T, @@ -112,15 +139,16 @@ where } } - /// Call this endpoint on a remote node, without the possibility - /// of adding or receiving a body + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. pub async fn call( &self, target: &NodeID, req: M, prio: RequestPriority, ) -> Result<<M as Message>::Response, Error> { - Ok(self.call_full(target, req, prio).await?.into_msg()) + Ok(self.call_streaming(target, req, prio).await?.into_msg()) } } @@ -144,13 +172,13 @@ pub(crate) trait GenericEndpoint { pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) where M: Message, - H: EndpointHandler<M>; + H: StreamingEndpointHandler<M>; #[async_trait] impl<M, H> GenericEndpoint for EndpointArc<M, H> where M: Message + 'static, - H: EndpointHandler<M> + 'static, + H: StreamingEndpointHandler<M> + 'static, { async fn handle( &self, diff --git a/src/message.rs b/src/message.rs index d918c29..5721318 100644 --- a/src/message.rs +++ b/src/message.rs @@ -311,7 +311,7 @@ impl Framing { } } - pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>( + pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>( mut stream: S, ) -> Result<Self, Error> { let mut packet = stream diff --git a/src/netapp.rs b/src/netapp.rs index 0cebac0..166f560 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -152,7 +152,7 @@ impl NetApp { pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>> where M: Message + 'static, - H: EndpointHandler<M> + 'static, + H: StreamingEndpointHandler<M> + 'static, { let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone())); let endpoint_arc = EndpointArc(endpoint.clone()); @@ -433,8 +433,7 @@ impl NetApp { #[async_trait] impl EndpointHandler<HelloMessage> for NetApp { - async fn handle(self: &Arc<Self>, msg: Req<HelloMessage>, from: NodeID) -> Resp<HelloMessage> { - let msg = msg.msg(); + async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) { debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -443,6 +442,5 @@ impl EndpointHandler<HelloMessage> for NetApp { h(from, remote_addr, true); } } - Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 71dea84..310077f 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,24 +468,15 @@ impl Basalt { #[async_trait] impl EndpointHandler<PullMessage> for Basalt { - async fn handle( - self: &Arc<Self>, - _pullmsg: Req<PullMessage>, - _from: NodeID, - ) -> Resp<PullMessage> { - Resp::new(self.make_push_message()) + async fn handle(self: &Arc<Self>, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage { + self.make_push_message() } } #[async_trait] impl EndpointHandler<PushMessage> for Basalt { - async fn handle( - self: &Arc<Self>, - pushmsg: Req<PushMessage>, - _from: NodeID, - ) -> Resp<PushMessage> { - self.handle_peer_list(&pushmsg.msg().peers[..]); - Resp::new(()) + async fn handle(self: &Arc<Self>, pushmsg: &PushMessage, _from: NodeID) { + self.handle_peer_list(&pushmsg.peers[..]); } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 9b7b666..ccbd0ba 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,14 +583,13 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { - async fn handle(self: &Arc<Self>, ping: Req<PingMessage>, from: NodeID) -> Resp<PingMessage> { - let ping = ping.msg(); + async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage { let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - Resp::new(ping_resp) + ping_resp } } @@ -598,12 +597,11 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy { async fn handle( self: &Arc<Self>, - peer_list: Req<PeerListMessage>, + peer_list: &PeerListMessage, _from: NodeID, - ) -> Resp<PeerListMessage> { - let peer_list = peer_list.msg(); + ) -> PeerListMessage { self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - Resp::new(PeerListMessage { list: peer_list }) + PeerListMessage { list: peer_list } } } diff --git a/src/util.rs b/src/util.rs index e7ecea8..01c392c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -24,7 +24,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send>>; +pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; pub type Packet = Result<Bytes, u8>; |