use bytes::Bytes;
use serde::{Deserialize, Serialize};
use futures::stream::{Stream, StreamExt};
use crate::error::*;
use crate::util::*;
/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such a case, all of the messages in the send queue with the highest priority
/// take turns sending individual chunks, in a round-robin fashion.
/// Once all highest-priority messages have been sent successfully, the messages with
/// the next highest priority begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
///
/// A priority can be built by combining a priority class (`PRIO_HIGH`, `PRIO_NORMAL`,
/// `PRIO_BACKGROUND`) with a rank inside that class (`PRIO_PRIMARY`, `PRIO_SECONDARY`)
/// using bitwise OR, e.g. `PRIO_HIGH | PRIO_SECONDARY`.
pub type RequestPriority = u8;
/// Priority class: high
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class (ex: `PRIO_HIGH | PRIO_PRIMARY`, which equals `PRIO_HIGH`)
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;
// ----
/// Trait to implement for every message type that your application
/// wants to send and handle.
///
/// Each message type names, through `Response`, the type of the reply
/// associated with it.
pub trait Message: Send + Sync + SerializeMessage {
    /// Type of the response associated with this message.
    type Response: Send + Sync + SerializeMessage;
}
/// De/serialization of a message, possibly together with an attached byte stream.
///
/// Every type that can already be serialized and deserialized with serde gets this
/// trait through a blanket implementation. The `WithFixedBody` and
/// `WithStreamingBody` adapters implement it for messages that carry a body: the
/// first takes the body from a fixed `Bytes` buffer (which keeps the message
/// `Clone`), the second from a streaming byte stream.
pub trait SerializeMessage
where
    Self: Sized,
{
    /// The serde-serializable part of the message.
    type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;

    /// Splits the message into its serializable part and its optional attached stream.
    fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);

    /// Rebuilds a message from its deserialized part and the stream received with it.
    fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
}
// ----
/// Blanket implementation: any type serde can serialize and deserialize is a
/// message without an attached stream, and is its own serializable part.
impl<M> SerializeMessage for M
where
    M: Serialize + for<'de> Deserialize<'de> + Send,
{
    type SerializableSelf = M;

    fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
        let no_stream = None;
        (self, no_stream)
    }

    fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
        // The received stream is dropped without being inspected.
        // TODO verify no stream
        ser_self
    }
}
// ----
/// Adapter attaching a body, held in a fixed `Bytes` buffer, to a serializable
/// message, implementing the `SerializeMessage` trait.
///
/// Because both parts can be cloned, the resulting message is `Clone`, which is
/// useful for requests that must be sent to multiple servers. Cloning the body is
/// cheap thanks to `Bytes`; make sure the serializable part is also cheap to clone
/// (e.g. by wrapping it in an `Arc`).
///
/// This adapter CANNOT be used as a response type: it cannot be rebuilt from a
/// stream received from a remote peer, and `from_parts` panics.
#[derive(Clone)]
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
    pub T,
    pub Bytes,
);

impl<T> SerializeMessage for WithFixedBody<T>
where
    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
    type SerializableSelf = T;

    fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
        let WithFixedBody(message, body) = self;
        // The whole body is emitted as a single item of the attached stream.
        let body_stream: ByteStream = Box::pin(futures::stream::once(async move { Ok(body) }));
        (message, Some(body_stream))
    }

    fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
        panic!("Cannot use a WithFixedBody as a response type");
    }
}
/// Adapter attaching a body, provided as a `ByteStream`, to a serializable message.
///
/// This is useful for receiving responses to requests that carry attached byte
/// streams. Since a stream cannot be duplicated, this type is not cloneable.
pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
    pub T,
    pub ByteStream,
);

impl<T> SerializeMessage for WithStreamingBody<T>
where
    T: Serialize + for<'de> Deserialize<'de> + Send,
{
    type SerializableSelf = T;

    fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
        let WithStreamingBody(message, body) = self;
        (message, Some(body))
    }

    fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
        Self(ser_self, stream)
    }
}
// ---- ----
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
pub(crate) path: &'a [u8],
pub(crate) telemetry_id: Option<Vec<u8>>,
pub(crate) body: &'a [u8],
}
/// QueryMessage encoding:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - body [u8; ..]
impl<'a> QueryMessage<'a> {
pub(crate) fn encode(self) -> Vec<u8> {
let tel_len = match &self.telemetry_id {
Some(t) => t.len(),
None => 0,
};
let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
ret.push(self.prio);
ret.push(self.path.len() as u8);
ret.extend_from_slice(self.path);
if let Some(t) = self.telemetry_id {
ret.push(t.len() as u8);
ret.extend(t);
} else {
ret.push(0u8);
}
ret.extend_from_slice(self.body);
ret
}
pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
if bytes.len() < 3 {
return Err(Error::Message("Invalid protocol message".into()));
}
let path_length = bytes[1] as usize;
if bytes.len() < 3 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let telemetry_id_len = bytes[2 + path_length] as usize;
if bytes.len() < 3 + path_length + telemetry_id_len {
return Err(Error::Message("Invalid protocol message".into()));
}
let path = &bytes[2..2 + path_length];
let telemetry_id = if telemetry_id_len > 0 {
Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
} else {
None
};
let body = &bytes[3 + path_length + telemetry_id_len..];
Ok(Self {
prio: bytes[0],
path,
telemetry_id,
body,
})
}
}
// ---- ----
pub(crate) struct Framing {
direct: Vec<u8>,
stream: Option<ByteStream>,
}
impl Framing {
pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
assert!(direct.len() <= u32::MAX as usize);
Framing { direct, stream }
}
pub fn into_stream(self) -> ByteStream {
use futures::stream;
let len = self.direct.len() as u32;
// required because otherwise the borrow-checker complains
let Framing { direct, stream } = self;
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) })
.chain(stream::once(async move { Ok(direct.into()) }));
if let Some(stream) = stream {
Box::pin(res.chain(stream))
} else {
Box::pin(res)
}
}
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
mut stream: S,
) -> Result<Self, Error> {
let mut packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
if packet.len() < 4 {
return Err(Error::Framing);
}
let mut len = [0; 4];
len.copy_from_slice(&packet[..4]);
let len = u32::from_be_bytes(len);
packet = packet.slice(4..);
let mut buffer = Vec::new();
let len = len as usize;
loop {
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
buffer.extend_from_slice(&packet[..max_cp]);
if buffer.len() == len {
packet = packet.slice(max_cp..);
break;
}
packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
}
let stream: ByteStream = if packet.is_empty() {
Box::pin(stream)
} else {
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
};
Ok(Framing {
direct: buffer,
stream: Some(stream),
})
}
pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
let Framing { direct, stream } = self;
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
}
}