diff options
Diffstat (limited to 'src/message.rs')
-rw-r--r-- | src/message.rs | 357 |
1 files changed, 208 insertions, 149 deletions
diff --git a/src/message.rs b/src/message.rs index 5721318..ba06551 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,12 +2,13 @@ use std::fmt; use std::marker::PhantomData; use std::sync::Arc; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use serde::{Deserialize, Serialize}; -use futures::stream::{Stream, StreamExt}; +use futures::stream::StreamExt; use crate::error::*; +use crate::stream::*; use crate::util::*; /// Priority of a request (click to read more about priorities). @@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; } +// ---- + +/// The Req<M> is a helper object used to create requests and attach them +/// a streaming body. If the body is a fixed Bytes and not a ByteStream, +/// Req<M> is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the body is a ByteStream). +/// +/// Internally, this is also used to encode and decode requests +/// from/to byte streams to be sent over the network. pub struct Req<M: Message> { pub(crate) _phantom: PhantomData<M>, pub(crate) msg: Arc<M>, @@ -52,30 +62,6 @@ pub struct Req<M: Message> { pub(crate) body: BodyData, } -pub struct Resp<M: Message> { - pub(crate) _phantom: PhantomData<M>, - pub(crate) msg: M::Response, - pub(crate) body: BodyData, -} - -pub(crate) enum BodyData { - None, - Fixed(Bytes), - Stream(ByteStream), -} - -impl BodyData { - pub fn into_stream(self) -> Option<ByteStream> { - match self { - BodyData::None => None, - BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), - BodyData::Stream(s) => Some(s), - } - } -} - -// ---- - impl<M: Message> Req<M> { pub fn msg(&self) -> &M { &self.msg @@ -94,6 +80,31 @@ impl<M: Message> Req<M> { ..self } } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.body.into_stream(), + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } } pub trait IntoReq<M: Message> { @@ -160,19 +171,14 @@ where } } -impl<M> fmt::Debug for Resp<M> -where - M: Message, - <M as Message>::Response: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(f, "Resp[{:?}", self.msg)?; - match &self.body { - BodyData::None => write!(f, "]"), - BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), - BodyData::Stream(_) => write!(f, "; body=stream]"), - } - } +// ---- + +/// The Resp<M> represents a full response from a RPC that may have +/// an attached body stream. +pub struct Resp<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, } impl<M: Message> Resp<M> { @@ -205,160 +211,213 @@ impl<M: Message> Resp<M> { pub fn into_msg(self) -> M::Response { self.msg } + + pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> { + Ok(RespEnc::Success { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.body.into_stream(), + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> { + match enc { + RespEnc::Success { msg, stream } => { + let msg = rmp_serde::decode::from_read_ref(&msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + body: stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } + RespEnc::Error { code, message } => Err(Error::Remote(code, message)), + } + } } -// ---- ---- +impl<M> fmt::Debug for Resp<M> +where + M: Message, + <M as Message>::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option<Vec<u8>>, - pub(crate) body: &'a [u8], +// ---- + +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl BodyData { + pub fn into_stream(self) -> Option<ByteStream> { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } + } } -/// QueryMessage encoding: +// ---- ---- + +/// Encoding for requests into a ByteStream: /// - priority: u8 /// - path length: u8 /// - path: [u8; path length] /// - telemetry id length: u8 /// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec<u8> { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option<ByteStream>, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> ByteStream { + let mut buf = BytesMut::with_capacity(64); + + buf.put_u8(self.prio); - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + buf.put_u8(self.path.len() as u8); + buf.put(self.path); - ret.push(self.prio); + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); + buf.put_u32(self.msg.len() as u32); + buf.put(&self.msg[..]); - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); + let header = buf.freeze(); + + if let Some(stream) = self.stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) } else { - ret.push(0u8); + Box::pin(futures::stream::once(async move { Ok(header) })) } + } - ret.extend_from_slice(self.body); - - ret + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) } - pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } + pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } + let prio = reader.read_u8().await?; - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; - let body = &bytes[3 + path_length + telemetry_id_len..]; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; Ok(Self { - prio: bytes[0], + prio, path, telemetry_id, - body, + msg, + stream: Some(reader.into_stream()), }) } } -// ---- ---- - -pub(crate) struct Framing { - direct: Vec<u8>, - stream: Option<ByteStream>, +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) enum RespEnc { + Error { + code: u8, + message: String, + }, + Success { + msg: Bytes, + stream: Option<ByteStream>, + }, } -impl Framing { - pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } +impl RespEnc { + pub(crate) fn from_err(e: Error) -> Self { + RespEnc::Error { + code: e.code(), + message: format!("{}", e), + } } - pub fn into_stream(self) -> ByteStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; + pub(crate) fn encode(self) -> ByteStream { + match self { + RespEnc::Success { msg, stream } => { + let mut buf = BytesMut::with_capacity(64); - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) }) - .chain(stream::once(async move { Ok(direct.into()) })); + buf.put_u8(0); - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } + buf.put_u32(msg.len() as u32); + buf.put(&msg[..]); - pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>( - mut stream: S, - ) -> Result<Self, Error> { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); + let header = buf.freeze(); + + if let Some(stream) = stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) + } else { + Box::pin(futures::stream::once(async move { Ok(header) })) + } + } + RespEnc::Error { code, message } => { + let mut buf = BytesMut::with_capacity(64); + buf.put_u8(1 + message.len() as u8); + buf.put_u8(code); + buf.put(message.as_bytes()); + let header = buf.freeze(); + Box::pin(futures::stream::once(async move { Ok(header) })) + } } + } - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet = packet.slice(4..); + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) + } - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet = packet.slice(max_cp..); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } + let is_err = reader.read_u8().await?; - let stream: ByteStream = if packet.is_empty() { - Box::pin(stream) + if is_err > 0 { + let code = reader.read_u8().await?; + let message = reader.read_exact(is_err as usize - 1).await?; + let message = String::from_utf8(message.to_vec()).unwrap_or_default(); + Ok(RespEnc::Error { code, message }) } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec<u8>, ByteStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + Ok(RespEnc::Success { + msg, + stream: Some(reader.into_stream()), + }) + } } } |