diff options
Diffstat (limited to 'src/message.rs')
-rw-r--r-- | src/message.rs | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..dbcc857 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,255 @@ +use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::error::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +} + +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +#[async_trait] +impl<T> SerializeMessage for T +where + T: AutoSerialize, +{ + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl<T, E> SerializeMessage for Result<T, E> +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), + } + } +} + +// ---- + +pub(crate) struct QueryMessage<'a> { + pub(crate) prio: RequestPriority, + pub(crate) path: &'a [u8], + pub(crate) telemetry_id: Option<Vec<u8>>, + pub(crate) body: &'a [u8], +} + +/// QueryMessage encoding: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - body [u8; ..] +impl<'a> QueryMessage<'a> { + pub(crate) fn encode(self) -> Vec<u8> { + let tel_len = match &self.telemetry_id { + Some(t) => t.len(), + None => 0, + }; + + let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + + ret.push(self.prio); + + ret.push(self.path.len() as u8); + ret.extend_from_slice(self.path); + + if let Some(t) = self.telemetry_id { + ret.push(t.len() as u8); + ret.extend(t); + } else { + ret.push(0u8); + } + + ret.extend_from_slice(self.body); + + ret + } + + pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> { + if bytes.len() < 3 { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path_length = bytes[1] as usize; + if bytes.len() < 3 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let telemetry_id_len = bytes[2 + path_length] as usize; + if bytes.len() < 3 + path_length + telemetry_id_len { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let telemetry_id = if telemetry_id_len > 0 { + Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) + } else { + None + }; + + let body = &bytes[3 + path_length + telemetry_id_len..]; + + Ok(Self { + prio: bytes[0], + path, + telemetry_id, + body, + }) + } +} + +pub(crate) struct Framing { + direct: Vec<u8>, + stream: Option<ByteStream>, +} + +impl Framing { + pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } + } + + pub fn into_stream(self) -> ByteStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; + + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); + + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) + } + } + + pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result<Self, Error> { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + } + + let stream: ByteStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec<u8>, ByteStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + } +} |