diff options
Diffstat (limited to 'src/message.rs')
-rw-r--r-- | src/message.rs | 167 |
1 files changed, 102 insertions, 65 deletions
diff --git a/src/message.rs b/src/message.rs index 61d01d0..ca68cac 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; use serde::{Deserialize, Serialize}; use futures::stream::StreamExt; @@ -40,6 +41,24 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; // ---- +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(u64, u64); + +impl OrderTag { + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + /// This trait should be implemented by all messages your application /// wants to handle pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { @@ -56,6 +75,7 @@ pub struct Req<M: Message> { pub(crate) msg: Arc<M>, pub(crate) msg_ser: Option<Bytes>, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, } impl<M: Message> Req<M> { @@ -77,6 +97,13 @@ impl<M: Message> Req<M> { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M { &self.msg } @@ -97,6 +124,7 @@ impl<M: Message> Req<M> { telemetry_id, msg: self.msg_ser.unwrap(), stream: self.stream.into_stream(), + order_tag: self.order_tag, } } @@ -109,6 +137,7 @@ impl<M: Message> Req<M> { .stream .map(AttachedStream::Stream) .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, }) } } @@ -125,6 +154,7 @@ impl<M: Message> IntoReq<M> for M { msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), stream: AttachedStream::None, + order_tag: None, }) } fn into_req_local(self) -> Req<M> { @@ -132,6 +162,7 @@ impl<M: Message> IntoReq<M> for M { msg: Arc::new(self), msg_ser: None, stream: AttachedStream::None, + order_tag: None, } } } @@ -158,6 +189,7 @@ impl<M: Message> Clone for Req<M> { msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), stream, + order_tag: self.order_tag, } } } @@ -184,6 +216,7 @@ pub struct Resp<M: Message> { pub(crate) _phantom: PhantomData<M>, pub(crate) msg: M::Response, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, } impl<M: Message> Resp<M> { @@ -192,6 +225,7 @@ impl<M: Message> Resp<M> { _phantom: Default::default(), msg: v, stream: AttachedStream::None, + order_tag: None, } } @@ -209,6 +243,13 @@ impl<M: Message> Resp<M> { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M::Response { &self.msg } @@ -222,26 +263,24 @@ impl<M: Message> Resp<M> { } pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> { - Ok(RespEnc::Success { + Ok(RespEnc { msg: rmp_to_vec_all_named(&self.msg)?.into(), stream: self.stream.into_stream(), + order_tag: self.order_tag, }) } pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> { - match enc { - RespEnc::Success { msg, stream } => { - let msg = rmp_serde::decode::from_read_ref(&msg)?; - Ok(Self { - _phantom: Default::default(), - msg, - stream: stream - .map(AttachedStream::Stream) - .unwrap_or(AttachedStream::None), - }) - } - RespEnc::Error { code, message } => Err(Error::Remote(code, message)), - } + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) } } @@ -295,10 +334,11 @@ pub(crate) struct ReqEnc { pub(crate) telemetry_id: Bytes, pub(crate) msg: Bytes, pub(crate) stream: Option<ByteStream>, + pub(crate) order_tag: Option<OrderTag>, } impl ReqEnc { - pub(crate) fn encode(self) -> ByteStream { + pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) { let mut buf = BytesMut::with_capacity( self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, ); @@ -315,15 +355,18 @@ impl ReqEnc { let header = buf.freeze(); - if let Some(stream) = self.stream { + let res_stream: ByteStream = if let Some(stream) = self.stream { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) - } + }; + (res_stream, self.order_tag) } pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { @@ -346,6 +389,7 @@ impl ReqEnc { telemetry_id, msg, stream: Some(reader.into_stream()), + order_tag: None, }) } } @@ -360,74 +404,67 @@ impl ReqEnc { /// - message length + 1: u8 /// - error code: u8 /// - message: [u8; message_length] -pub(crate) enum RespEnc { - Error { - code: u8, - message: String, - }, - Success { - msg: Bytes, - stream: Option<ByteStream>, - }, +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option<ByteStream>, + order_tag: Option<OrderTag>, } impl RespEnc { - pub(crate) fn from_err(e: Error) -> Self { - RespEnc::Error { - code: e.code(), - message: format!("{}", e), - } - } - - pub(crate) fn encode(self) -> ByteStream { - match self { - RespEnc::Success { msg, stream } => { - let mut buf = BytesMut::with_capacity(msg.len() + 8); - - buf.put_u8(0); + pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); buf.put_u32(msg.len() as u32); - let header = buf.freeze(); - if let Some(stream) = stream { + let res_stream: ByteStream = if let Some(stream) = stream { Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) - } + }; + (res_stream, order_tag) } - RespEnc::Error { code, message } => { - let mut buf = BytesMut::with_capacity(message.len() + 8); - buf.put_u8(1 + message.len() as u8); - buf.put_u8(code); - buf.put(message.as_bytes()); - let header = buf.freeze(); - Box::pin(futures::stream::once(async move { Ok(header) })) + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) } } } pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { let mut reader = ByteStreamReader::new(stream); - let is_err = reader.read_u8().await?; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - if is_err > 0 { - let code = reader.read_u8().await?; - let message = reader.read_exact(is_err as usize - 1).await?; - let message = String::from_utf8(message.to_vec()).unwrap_or_default(); - Ok(RespEnc::Error { code, message }) - } else { - let msg_len = reader.read_u32().await?; - let msg = reader.read_exact(msg_len as usize).await?; + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} - Ok(RespEnc::Success { - msg, - stream: Some(reader.into_stream()), - }) - } +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, } } |