aboutsummaryrefslogblamecommitdiff
path: root/src/message.rs
blob: f92eb8c99c31254ec273e50e5c4ed344bd1d4b1d (plain) (tree)
1
2
3
4
5
6
                             
                 

                                    

                                         






































                                                                                 


                                                                            






















                                                                                    
 





                                                                                     

 


                                                                                       
 

                              
                         

                                     

                                                                             

         
                                                                                      




                                        
                            
 










                                                                                            
     
                                                                          
 








                                                                                       
 

                                                                                     
         
 
 
















                                                                                     


         
            









































































                                                                                                   

            
















                                                                          

                                                                                                






















                                                                                    
                                           







                                                                                     
                                                                

























                                                                                                
use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

use futures::stream::{Stream, StreamExt};

use crate::error::*;
use crate::util::*;

/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such cases, all of the messages in the send queue with the highest priority
/// take turns sending individual chunks, in a round-robin fashion.
/// Once all highest priority messages are sent successfully, the messages with
/// the next highest priority begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
///
/// A priority is typically built by combining a class (`PRIO_HIGH`,
/// `PRIO_NORMAL`, `PRIO_BACKGROUND`) with a rank within that class
/// (`PRIO_PRIMARY`, `PRIO_SECONDARY`) using bitwise OR,
/// e.g. `PRIO_HIGH | PRIO_SECONDARY`.
pub type RequestPriority = u8;

/// Priority class: high (numerically lowest class, so served first)
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background (numerically highest class, so served last)
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;

// ----

/// This trait should be implemented by all messages your application
/// wants to handle.
///
/// Both the message itself and its associated `Response` type must be
/// (de)serializable through [`SerializeMessage`], and both must be
/// `Send + Sync`.
pub trait Message: SerializeMessage + Send + Sync {
	/// The response type associated with this message.
	type Response: SerializeMessage + Send + Sync;
}

/// A trait for de/serializing messages, with possible associated stream.
///
/// A message is split into a part that is serialized with serde
/// (`SerializableSelf`) and an optional byte stream carried alongside it.
pub trait SerializeMessage: Sized {
	/// The part of the message that is serialized with serde.
	type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;

	/// Splits the message into its serializable part and an optional byte
	/// stream (`None` when the message carries no stream).
	fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);

	/// Rebuilds a message from its deserialized part and the byte stream
	/// received with it.
	///
	/// NOTE(review): `stream` is not optional here, even though `into_parts`
	/// may produce `None`; presumably callers pass an empty stream in that
	/// case — confirm against the call sites.
	fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
}

// ----

/// A `Result` of messages is itself a message: the `Ok` side forwards to the
/// inner message (including its stream), while the `Err` side is serialized
/// as-is and never carries a stream.
impl<T, E> SerializeMessage for Result<T, E>
where
	T: SerializeMessage + Send,
	E: Serialize + for<'de> Deserialize<'de> + Send,
{
	type SerializableSelf = Result<T::SerializableSelf, E>;

	fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
		// Errors have no stream: short-circuit them first.
		let inner = match self {
			Err(err) => return (Err(err), None),
			Ok(inner) => inner,
		};
		let (ser, stream) = inner.into_parts();
		(Ok(ser), stream)
	}

	fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
		// Only the `Ok` side consumes the stream; on `Err` it is simply dropped.
		ser_self.map(|ser| T::from_parts(ser, stream))
	}
}

// ---

/// Marker trait for messages that are plain serde values, without any
/// associated byte stream. Implementing it provides `SerializeMessage`
/// for free through the blanket impl below.
pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}

impl<T: SimpleMessage> SerializeMessage for T {
	// A simple message is its own serializable form.
	type SerializableSelf = Self;

	fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
		let no_stream: Option<ByteStream> = None;
		(self, no_stream)
	}

	fn from_parts(ser_self: Self::SerializableSelf, _unused_stream: ByteStream) -> Self {
		// TODO: check that the received stream is actually empty.
		ser_self
	}
}

impl SimpleMessage for () {}

impl<T: SimpleMessage> SimpleMessage for std::sync::Arc<T> {}

// ----

/// A message made of a serializable header `T` and a fixed, in-memory body.
///
/// When sent, the body is emitted as a single-chunk byte stream. This type is
/// send-only: it cannot be rebuilt from received parts (its `from_parts`
/// panics); use `WithStreamingBody` on the receiving side.
#[derive(Clone)]
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
	pub T,
	pub Bytes,
);

impl<T> SerializeMessage for WithFixedBody<T>
where
	T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
	type SerializableSelf = T;

	/// Returns the header and a stream yielding the whole body as one chunk.
	fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
		let WithFixedBody(msg, body) = self;
		(
			msg,
			Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
		)
	}

	/// # Panics
	///
	/// Always panics: a `WithFixedBody` cannot be reconstructed from parts.
	fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
		panic!("Cannot reconstruct a WithFixedBody type from parts");
	}
}

/// A message made of a serializable header `T` and an arbitrary byte stream
/// sent after it.
pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
	pub T,
	pub ByteStream,
);

impl<T> SerializeMessage for WithStreamingBody<T>
where
	T: Serialize + for<'de> Deserialize<'de> + Send,
{
	type SerializableSelf = T;

	/// Splits into the header and its (always present) stream.
	fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
		let WithStreamingBody(header, body) = self;
		(header, Some(body))
	}

	/// Reassembles the header with the received stream.
	fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
		Self(ser_self, stream)
	}
}

// ---- ----

/// A query message: priority, endpoint path, optional telemetry id and body.
///
/// Wire encoding:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8 (0 means no telemetry id)
/// - telemetry id: [u8; telemetry id length]
/// - body: [u8; ..] (all remaining bytes)
pub(crate) struct QueryMessage<'a> {
	pub(crate) prio: RequestPriority,
	pub(crate) path: &'a [u8],
	pub(crate) telemetry_id: Option<Vec<u8>>,
	pub(crate) body: &'a [u8],
}

impl<'a> QueryMessage<'a> {
	/// Encodes this message into its wire representation.
	///
	/// An empty telemetry id is encoded the same way as `None`.
	///
	/// # Panics
	///
	/// Panics if `path` or the telemetry id is longer than 255 bytes. Each of
	/// these lengths is encoded in a single byte; truncating it would make the
	/// receiver misparse the rest of the message.
	pub(crate) fn encode(self) -> Vec<u8> {
		let max_len = u8::MAX as usize;
		let tel: &[u8] = self.telemetry_id.as_deref().unwrap_or(&[]);

		assert!(
			self.path.len() <= max_len,
			"QueryMessage path is longer than 255 bytes"
		);
		assert!(
			tel.len() <= max_len,
			"QueryMessage telemetry id is longer than 255 bytes"
		);

		// 3 header bytes: priority, path length, telemetry id length.
		let mut ret = Vec::with_capacity(3 + self.path.len() + tel.len() + self.body.len());

		ret.push(self.prio);

		ret.push(self.path.len() as u8);
		ret.extend_from_slice(self.path);

		ret.push(tel.len() as u8);
		ret.extend_from_slice(tel);

		ret.extend_from_slice(self.body);

		ret
	}

	/// Decodes a message from its wire representation.
	///
	/// `path` and `body` borrow from `bytes`; a zero-length telemetry id
	/// decodes as `None`.
	///
	/// # Errors
	///
	/// Returns `Error::Message` if `bytes` is too short to hold the header
	/// fields and lengths it announces.
	pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
		if bytes.len() < 3 {
			return Err(Error::Message("Invalid protocol message".into()));
		}

		let path_length = bytes[1] as usize;
		// Offset of the telemetry id length byte, right after the path.
		let tel_len_offset = 2 + path_length;
		if bytes.len() < tel_len_offset + 1 {
			return Err(Error::Message("Invalid protocol message".into()));
		}

		let telemetry_id_len = bytes[tel_len_offset] as usize;
		let tel_offset = tel_len_offset + 1;
		let body_offset = tel_offset + telemetry_id_len;
		if bytes.len() < body_offset {
			return Err(Error::Message("Invalid protocol message".into()));
		}

		let path = &bytes[2..tel_len_offset];
		let telemetry_id = if telemetry_id_len > 0 {
			Some(bytes[tel_offset..body_offset].to_vec())
		} else {
			None
		};

		let body = &bytes[body_offset..];

		Ok(Self {
			prio: bytes[0],
			path,
			telemetry_id,
			body,
		})
	}
}

// ---- ----

pub(crate) struct Framing {
	direct: Vec<u8>,
	stream: Option<ByteStream>,
}

impl Framing {
	pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
		assert!(direct.len() <= u32::MAX as usize);
		Framing { direct, stream }
	}

	pub fn into_stream(self) -> ByteStream {
		use futures::stream;
		let len = self.direct.len() as u32;
		// required because otherwise the borrow-checker complains
		let Framing { direct, stream } = self;

		let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) })
			.chain(stream::once(async move { Ok(direct.into()) }));

		if let Some(stream) = stream {
			Box::pin(res.chain(stream))
		} else {
			Box::pin(res)
		}
	}

	pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
		mut stream: S,
	) -> Result<Self, Error> {
		let mut packet = stream
			.next()
			.await
			.ok_or(Error::Framing)?
			.map_err(|_| Error::Framing)?;
		if packet.len() < 4 {
			return Err(Error::Framing);
		}

		let mut len = [0; 4];
		len.copy_from_slice(&packet[..4]);
		let len = u32::from_be_bytes(len);
		packet = packet.slice(4..);

		let mut buffer = Vec::new();
		let len = len as usize;
		loop {
			let max_cp = std::cmp::min(len - buffer.len(), packet.len());

			buffer.extend_from_slice(&packet[..max_cp]);
			if buffer.len() == len {
				packet = packet.slice(max_cp..);
				break;
			}
			packet = stream
				.next()
				.await
				.ok_or(Error::Framing)?
				.map_err(|_| Error::Framing)?;
		}

		let stream: ByteStream = if packet.is_empty() {
			Box::pin(stream)
		} else {
			Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
		};

		Ok(Framing {
			direct: buffer,
			stream: Some(stream),
		})
	}

	pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
		let Framing { direct, stream } = self;
		(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
	}
}