diff options
Diffstat (limited to 'src/endpoint.rs')
-rw-r--r-- | src/endpoint.rs | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/src/endpoint.rs b/src/endpoint.rs new file mode 100644 index 0000000..0e1f5c8 --- /dev/null +++ b/src/endpoint.rs @@ -0,0 +1,125 @@ +use std::marker::PhantomData; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; + +use serde::{Deserialize, Serialize}; + +use crate::error::Error; +use crate::netapp::*; +use crate::proto::*; +use crate::util::*; + +/// This trait should be implemented by all messages your application +/// wants to handle (click to read more). +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +} + +pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; + +#[async_trait] +pub trait EndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response; +} + +pub struct Endpoint<M, H> +where + M: Message, + H: EndpointHandler<M>, +{ + phantom: PhantomData<M>, + netapp: Arc<NetApp>, + path: String, + handler: ArcSwapOption<H>, +} + +impl<M, H> Endpoint<M, H> +where + M: Message, + H: EndpointHandler<M>, +{ + pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { + Self { + phantom: PhantomData::default(), + netapp, + path, + handler: ArcSwapOption::from(None), + } + } + pub fn set_handler(&self, h: Arc<H>) { + self.handler.swap(Some(h)); + } + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<<M as Message>::Response, Error> { + if *target == self.netapp.id { + match self.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => Ok(h.handle(req, self.netapp.id).await), + } + } else { + let conn = self + .netapp + .client_conns + .read() + .unwrap() + .get(target) + .cloned(); + match conn { + None => Err(Error::Message(format!( + "Not connected: {}", + hex::encode(target) + ))), + Some(c) => c.call(req, self.path.as_str(), prio).await, + } + } + } +} + +#[async_trait] +pub(crate) trait GenericEndpoint { + async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>; + fn clear_handler(&self); + fn clone_endpoint(&self) -> DynEndpoint; +} + +#[derive(Clone)] +pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) +where + M: Message, + H: EndpointHandler<M>; + +#[async_trait] +impl<M, H> GenericEndpoint for EndpointArc<M, H> +where + M: Message + 'static, + H: EndpointHandler<M> + 'static, +{ + async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> { + match self.0.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => { + let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let res = h.handle(req, from).await; + let res_bytes = rmp_to_vec_all_named(&res)?; + Ok(res_bytes) + } + } + } + + fn clear_handler(&self) { + self.0.handler.swap(None); + } + + fn clone_endpoint(&self) -> DynEndpoint { + Box::new(Self(self.0.clone())) + } +} |