aboutsummaryrefslogtreecommitdiff
path: root/src/message.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/message.rs')
-rw-r--r--src/message.rs255
1 files changed, 255 insertions, 0 deletions
diff --git a/src/message.rs b/src/message.rs
new file mode 100644
index 0000000..dbcc857
--- /dev/null
+++ b/src/message.rs
@@ -0,0 +1,255 @@
+use async_trait::async_trait;
+use futures::stream::{Stream, StreamExt};
+use serde::{Deserialize, Serialize};
+
+use crate::error::*;
+use crate::util::*;
+
+/// Priority of a request (click to read more about priorities).
+///
+/// This priority value is used to priorize messages
+/// in the send queue of the client, and their responses in the send queue of the
+/// server. Lower values mean higher priority.
+///
+/// This mechanism is usefull for messages bigger than the maximum chunk size
+/// (set at `0x4000` bytes), such as large file transfers.
+/// In such case, all of the messages in the send queue with the highest priority
+/// will take turns to send individual chunks, in a round-robin fashion.
+/// Once all highest priority messages are sent successfully, the messages with
+/// the next highest priority will begin being sent in the same way.
+///
+/// The same priority value is given to a request and to its associated response.
+pub type RequestPriority = u8;
+
+/// Priority class: high
+pub const PRIO_HIGH: RequestPriority = 0x20;
+/// Priority class: normal
+pub const PRIO_NORMAL: RequestPriority = 0x40;
+/// Priority class: background
+pub const PRIO_BACKGROUND: RequestPriority = 0x80;
+/// Priority: primary among given class
+pub const PRIO_PRIMARY: RequestPriority = 0x00;
+/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
+pub const PRIO_SECONDARY: RequestPriority = 0x01;
+
+// ----
+
+/// This trait should be implemented by all messages your application
+/// wants to handle
+pub trait Message: SerializeMessage + Send + Sync {
+ type Response: SerializeMessage + Send + Sync;
+}
+
+/// A trait for de/serializing messages, with possible associated stream.
+#[async_trait]
+pub trait SerializeMessage: Sized {
+ type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
+
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>);
+
+ // TODO should return Result
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
+}
+
+pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
+
+#[async_trait]
+impl<T> SerializeMessage for T
+where
+ T: AutoSerialize,
+{
+ type SerializableSelf = Self;
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ (self.clone(), None)
+ }
+
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
+ // TODO verify no stream
+ ser_self
+ }
+}
+
+impl AutoSerialize for () {}
+
+#[async_trait]
+impl<T, E> SerializeMessage for Result<T, E>
+where
+ T: SerializeMessage + Send,
+ E: SerializeMessage + Send,
+{
+ type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
+
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ match self {
+ Ok(ok) => {
+ let (msg, stream) = ok.serialize_msg();
+ (Ok(msg), stream)
+ }
+ Err(err) => {
+ let (msg, stream) = err.serialize_msg();
+ (Err(msg), stream)
+ }
+ }
+ }
+
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
+ match ser_self {
+ Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
+ Err(err) => Err(E::deserialize_msg(err, stream).await),
+ }
+ }
+}
+
+// ----
+
+pub(crate) struct QueryMessage<'a> {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: &'a [u8],
+ pub(crate) telemetry_id: Option<Vec<u8>>,
+ pub(crate) body: &'a [u8],
+}
+
+/// QueryMessage encoding:
+/// - priority: u8
+/// - path length: u8
+/// - path: [u8; path length]
+/// - telemetry id length: u8
+/// - telemetry id: [u8; telemetry id length]
+/// - body [u8; ..]
+impl<'a> QueryMessage<'a> {
+ pub(crate) fn encode(self) -> Vec<u8> {
+ let tel_len = match &self.telemetry_id {
+ Some(t) => t.len(),
+ None => 0,
+ };
+
+ let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
+
+ ret.push(self.prio);
+
+ ret.push(self.path.len() as u8);
+ ret.extend_from_slice(self.path);
+
+ if let Some(t) = self.telemetry_id {
+ ret.push(t.len() as u8);
+ ret.extend(t);
+ } else {
+ ret.push(0u8);
+ }
+
+ ret.extend_from_slice(self.body);
+
+ ret
+ }
+
+ pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
+ if bytes.len() < 3 {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path_length = bytes[1] as usize;
+ if bytes.len() < 3 + path_length {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let telemetry_id_len = bytes[2 + path_length] as usize;
+ if bytes.len() < 3 + path_length + telemetry_id_len {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path = &bytes[2..2 + path_length];
+ let telemetry_id = if telemetry_id_len > 0 {
+ Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
+ } else {
+ None
+ };
+
+ let body = &bytes[3 + path_length + telemetry_id_len..];
+
+ Ok(Self {
+ prio: bytes[0],
+ path,
+ telemetry_id,
+ body,
+ })
+ }
+}
+
+pub(crate) struct Framing {
+ direct: Vec<u8>,
+ stream: Option<ByteStream>,
+}
+
+impl Framing {
+ pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
+ assert!(direct.len() <= u32::MAX as usize);
+ Framing { direct, stream }
+ }
+
+ pub fn into_stream(self) -> ByteStream {
+ use futures::stream;
+ let len = self.direct.len() as u32;
+ // required because otherwise the borrow-checker complains
+ let Framing { direct, stream } = self;
+
+ let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
+ .chain(stream::once(async move { Ok(direct) }));
+
+ if let Some(stream) = stream {
+ Box::pin(res.chain(stream))
+ } else {
+ Box::pin(res)
+ }
+ }
+
+ pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
+ mut stream: S,
+ ) -> Result<Self, Error> {
+ let mut packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
+ if packet.len() < 4 {
+ return Err(Error::Framing);
+ }
+
+ let mut len = [0; 4];
+ len.copy_from_slice(&packet[..4]);
+ let len = u32::from_be_bytes(len);
+ packet.drain(..4);
+
+ let mut buffer = Vec::new();
+ let len = len as usize;
+ loop {
+ let max_cp = std::cmp::min(len - buffer.len(), packet.len());
+
+ buffer.extend_from_slice(&packet[..max_cp]);
+ if buffer.len() == len {
+ packet.drain(..max_cp);
+ break;
+ }
+ packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
+ }
+
+ let stream: ByteStream = if packet.is_empty() {
+ Box::pin(stream)
+ } else {
+ Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
+ };
+
+ Ok(Framing {
+ direct: buffer,
+ stream: Some(stream),
+ })
+ }
+
+ pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
+ let Framing { direct, stream } = self;
+ (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
+ }
+}