use std::fmt; use std::marker::PhantomData; use std::sync::Arc; use bytes::Bytes; use serde::{Deserialize, Serialize}; use futures::stream::{Stream, StreamExt}; use crate::error::*; use crate::util::*; /// Priority of a request (click to read more about priorities). /// /// This priority value is used to priorize messages /// in the send queue of the client, and their responses in the send queue of the /// server. Lower values mean higher priority. /// /// This mechanism is usefull for messages bigger than the maximum chunk size /// (set at `0x4000` bytes), such as large file transfers. /// In such case, all of the messages in the send queue with the highest priority /// will take turns to send individual chunks, in a round-robin fashion. /// Once all highest priority messages are sent successfully, the messages with /// the next highest priority will begin being sent in the same way. /// /// The same priority value is given to a request and to its associated response. pub type RequestPriority = u8; /// Priority class: high pub const PRIO_HIGH: RequestPriority = 0x20; /// Priority class: normal pub const PRIO_NORMAL: RequestPriority = 0x40; /// Priority class: background pub const PRIO_BACKGROUND: RequestPriority = 0x80; /// Priority: primary among given class pub const PRIO_PRIMARY: RequestPriority = 0x00; /// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) pub const PRIO_SECONDARY: RequestPriority = 0x01; // ---- /// This trait should be implemented by all messages your application /// wants to handle pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; } pub struct Req { pub(crate) _phantom: PhantomData, pub(crate) msg: Arc, pub(crate) msg_ser: Option, pub(crate) body: BodyData, } pub struct Resp { pub(crate) _phantom: PhantomData, pub(crate) msg: M::Response, pub(crate) body: BodyData, } pub(crate) enum BodyData { None, Fixed(Bytes), Stream(ByteStream), } impl BodyData { pub fn into_stream(self) -> Option { match self { BodyData::None => None, BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), BodyData::Stream(s) => Some(s), } } } // ---- impl Req { pub fn msg(&self) -> &M { &self.msg } pub fn with_fixed_body(self, b: Bytes) -> Self { Self { body: BodyData::Fixed(b), ..self } } pub fn with_streaming_body(self, b: ByteStream) -> Self { Self { body: BodyData::Stream(b), ..self } } } pub trait IntoReq { fn into_req(self) -> Result, rmp_serde::encode::Error>; fn into_req_local(self) -> Req; } impl IntoReq for M { fn into_req(self) -> Result, rmp_serde::encode::Error> { let msg_ser = rmp_to_vec_all_named(&self)?; Ok(Req { _phantom: Default::default(), msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), body: BodyData::None, }) } fn into_req_local(self) -> Req { Req { _phantom: Default::default(), msg: Arc::new(self), msg_ser: None, body: BodyData::None, } } } impl IntoReq for Req { fn into_req(self) -> Result, rmp_serde::encode::Error> { Ok(self) } fn into_req_local(self) -> Req { self } } impl Clone for Req { fn clone(&self) -> Self { let body = match &self.body { BodyData::None => BodyData::None, BodyData::Fixed(b) => BodyData::Fixed(b.clone()), BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"), }; Self { _phantom: Default::default(), msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), body, } } } impl fmt::Debug for Req where M: Message + fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Req[{:?}", self.msg)?; match &self.body { BodyData::None => write!(f, "]"), BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), BodyData::Stream(_) => write!(f, "; body=stream]"), } } } impl fmt::Debug for Resp where M: Message, ::Response: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Resp[{:?}", self.msg)?; match &self.body { BodyData::None => write!(f, "]"), BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), BodyData::Stream(_) => write!(f, "; body=stream]"), } } } impl Resp { pub fn new(v: M::Response) -> Self { Resp { _phantom: Default::default(), msg: v, body: BodyData::None, } } pub fn with_fixed_body(self, b: Bytes) -> Self { Self { body: BodyData::Fixed(b), ..self } } pub fn with_streaming_body(self, b: ByteStream) -> Self { Self { body: BodyData::Stream(b), ..self } } pub fn msg(&self) -> &M::Response { &self.msg } pub fn into_msg(self) -> M::Response { self.msg } } // ---- ---- pub(crate) struct QueryMessage<'a> { pub(crate) prio: RequestPriority, pub(crate) path: &'a [u8], pub(crate) telemetry_id: Option>, pub(crate) body: &'a [u8], } /// QueryMessage encoding: /// - priority: u8 /// - path length: u8 /// - path: [u8; path length] /// - telemetry id length: u8 /// - telemetry id: [u8; telemetry id length] /// - body [u8; ..] impl<'a> QueryMessage<'a> { pub(crate) fn encode(self) -> Vec { let tel_len = match &self.telemetry_id { Some(t) => t.len(), None => 0, }; let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); ret.push(self.prio); ret.push(self.path.len() as u8); ret.extend_from_slice(self.path); if let Some(t) = self.telemetry_id { ret.push(t.len() as u8); ret.extend(t); } else { ret.push(0u8); } ret.extend_from_slice(self.body); ret } pub(crate) fn decode(bytes: &'a [u8]) -> Result { if bytes.len() < 3 { return Err(Error::Message("Invalid protocol message".into())); } let path_length = bytes[1] as usize; if bytes.len() < 3 + path_length { return Err(Error::Message("Invalid protocol message".into())); } let telemetry_id_len = bytes[2 + path_length] as usize; if bytes.len() < 3 + path_length + telemetry_id_len { return Err(Error::Message("Invalid protocol message".into())); } let path = &bytes[2..2 + path_length]; let telemetry_id = if telemetry_id_len > 0 { Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) } else { None }; let body = &bytes[3 + path_length + telemetry_id_len..]; Ok(Self { prio: bytes[0], path, telemetry_id, body, }) } } // ---- ---- pub(crate) struct Framing { direct: Vec, stream: Option, } impl Framing { pub fn new(direct: Vec, stream: Option) -> Self { assert!(direct.len() <= u32::MAX as usize); Framing { direct, stream } } pub fn into_stream(self) -> ByteStream { use futures::stream; let len = self.direct.len() as u32; // required because otherwise the borrow-checker complains let Framing { direct, stream } = self; let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) }) .chain(stream::once(async move { Ok(direct.into()) })); if let Some(stream) = stream { Box::pin(res.chain(stream)) } else { Box::pin(res) } } pub async fn from_stream + Unpin + Send + Sync + 'static>( mut stream: S, ) -> Result { let mut packet = stream .next() .await .ok_or(Error::Framing)? .map_err(|_| Error::Framing)?; if packet.len() < 4 { return Err(Error::Framing); } let mut len = [0; 4]; len.copy_from_slice(&packet[..4]); let len = u32::from_be_bytes(len); packet = packet.slice(4..); let mut buffer = Vec::new(); let len = len as usize; loop { let max_cp = std::cmp::min(len - buffer.len(), packet.len()); buffer.extend_from_slice(&packet[..max_cp]); if buffer.len() == len { packet = packet.slice(max_cp..); break; } packet = stream .next() .await .ok_or(Error::Framing)? .map_err(|_| Error::Framing)?; } let stream: ByteStream = if packet.is_empty() { Box::pin(stream) } else { Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) }; Ok(Framing { direct: buffer, stream: Some(stream), }) } pub fn into_parts(self) -> (Vec, ByteStream) { let Framing { direct, stream } = self; (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } }