use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use bytes::{BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use futures::stream::StreamExt;
use crate::error::*;
use crate::stream::*;
use crate::util::*;
/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such a case, all of the messages in the send queue with the highest priority
/// will take turns to send individual chunks, in a round-robin fashion.
/// Once all highest priority messages are sent successfully, the messages with
/// the next highest priority will begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
///
/// A priority value is typically built by OR-ing a priority class
/// ([`PRIO_HIGH`], [`PRIO_NORMAL`] or [`PRIO_BACKGROUND`]) with a rank inside
/// that class ([`PRIO_PRIMARY`] or [`PRIO_SECONDARY`]), for example
/// `PRIO_HIGH | PRIO_SECONDARY`.
pub type RequestPriority = u8;
/// Priority class: high.
///
/// This is the numerically smallest class, and lower values mean higher priority.
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal.
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background.
///
/// This is the numerically largest class, i.e. the one with the lowest priority.
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class.
///
/// Its value is `0`, so OR-ing it with a class leaves the class value unchanged.
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`).
///
/// OR-ing it with a class yields a value one above `class | PRIO_PRIMARY`,
/// i.e. a slightly lower priority within the same class.
pub const PRIO_SECONDARY: RequestPriority = 0x01;
// ----
/// This trait should be implemented by all messages your application
/// wants to handle.
///
/// A message must be serializable and deserializable with `serde`, and must
/// be `Send + Sync`.
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
/// Type of the response associated with this message. It is subject to
/// the same serialization and thread-safety bounds as the message itself.
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
// ----
/// The `Req<M>` is a helper object used to create requests and attach
/// a streaming body to them. If the body is a fixed `Bytes` and not a `ByteStream`,
/// `Req<M>` is cheaply clonable to allow the request to be sent to different
/// peers (`Clone` will panic if the body is a `ByteStream`).
///
/// Internally, this is also used to encode and decode requests
/// from/to byte streams to be sent over the network.
pub struct Req<M: Message> {
/// Marker tying this request to its message type `M`.
pub(crate) _phantom: PhantomData<M>,
/// The message itself, kept behind an `Arc` so that cloning the request
/// does not copy the message.
pub(crate) msg: Arc<M>,
/// Serialized form of `msg`, when available. It is `None` for requests
/// built with `IntoReq::into_req_local`; `Req::into_enc` requires it to
/// be `Some`.
pub(crate) msg_ser: Option<Bytes>,
/// Body attached to the request (none, fixed bytes, or a stream).
pub(crate) body: BodyData,
}
impl<M: Message> Req<M> {
    /// Returns a reference to the message carried by this request.
    pub fn msg(&self) -> &M {
        &self.msg
    }

    /// Attaches a fixed, in-memory body to the request, replacing any body
    /// that was set before.
    pub fn with_fixed_body(self, b: Bytes) -> Self {
        Self {
            body: BodyData::Fixed(b),
            ..self
        }
    }

    /// Attaches a streaming body to the request, replacing any body that was
    /// set before. A request with a streaming body can no longer be cloned.
    pub fn with_streaming_body(self, b: ByteStream) -> Self {
        Self {
            body: BodyData::Stream(b),
            ..self
        }
    }

    /// Takes the body out of the request, leaving the request without a body.
    ///
    /// Returns `None` if there was no body. A fixed body is returned as a
    /// stream that yields it as a single item.
    pub fn take_stream(&mut self) -> Option<ByteStream> {
        std::mem::replace(&mut self.body, BodyData::None).into_stream()
    }

    /// Converts the request into its wire representation.
    ///
    /// # Panics
    ///
    /// Panics if the request has no serialized message, which is the case
    /// for requests built with `IntoReq::into_req_local`.
    pub(crate) fn into_enc(
        self,
        prio: RequestPriority,
        path: Bytes,
        telemetry_id: Bytes,
    ) -> ReqEnc {
        // Invariant: only requests built through `into_req` / `from_enc`
        // (which both fill in `msg_ser`) may be encoded for the network.
        let msg = self.msg_ser.expect(
            "Req::into_enc requires a serialized message (was this request built with into_req_local?)",
        );
        ReqEnc {
            prio,
            path,
            telemetry_id,
            msg,
            stream: self.body.into_stream(),
        }
    }

    /// Rebuilds a request from its wire representation by deserializing the
    /// message; the serialized bytes are kept alongside the decoded message.
    ///
    /// # Errors
    ///
    /// Returns the deserialization error if the message bytes cannot be
    /// decoded as an `M`.
    pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
        let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
        Ok(Req {
            _phantom: Default::default(),
            msg: Arc::new(msg),
            msg_ser: Some(enc.msg),
            body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None),
        })
    }
}
/// Conversion of a value into a `Req<M>`.
///
/// In this file it is implemented for every `Message` type `M` (wrapping the
/// message in a fresh request) and for `Req<M>` itself (identity).
pub trait IntoReq<M: Message> {
/// Converts `self` into a request. The implementation for plain messages
/// serializes the message and fails with the serialization error if that
/// is not possible.
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
/// Converts `self` into a request without any fallible step. The
/// implementation for plain messages does not serialize the message, so
/// the resulting request has no serialized form and must not be passed to
/// `Req::into_enc`.
// NOTE(review): presumably meant for requests handled by the local node — confirm against callers.
fn into_req_local(self) -> Req<M>;
}
impl<M: Message> IntoReq<M> for M {
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
let msg_ser = rmp_to_vec_all_named(&self)?;
Ok(Req {
_phantom: Default::default(),
msg: Arc::new(self),
msg_ser: Some(Bytes::from(msg_ser)),
body: BodyData::None,
})
}
fn into_req_local(self) -> Req<M> {
Req {
_phantom: Default::default(),
msg: Arc::new(self),
msg_ser: None,
body: BodyData::None,
}
}
}
/// A `Req<M>` is already a request: both conversions return it unchanged.
impl<M: Message> IntoReq<M> for Req<M> {
/// Returns `self` as is; this implementation never fails.
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
Ok(self)
}
/// Returns `self` as is. Whether a serialized message is present depends
/// on how the request was built in the first place.
fn into_req_local(self) -> Req<M> {
self
}
}
impl<M: Message> Clone for Req<M> {
fn clone(&self) -> Self {
let body = match &self.body {
BodyData::None => BodyData::None,
BodyData::Fixed(b) => BodyData::Fixed(b.clone()),
BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"),
};
Self {
_phantom: Default::default(),
msg: self.msg.clone(),
msg_ser: self.msg_ser.clone(),
body,
}
}
}
impl<M> fmt::Debug for Req<M>
where
    M: Message + fmt::Debug,
{
    /// Formats the request as `Req[<msg>]`, followed by `; body=<len>` for a
    /// fixed body or `; body=stream` for a streaming body before the
    /// closing bracket.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Req[{:?}", self.msg)?;
        match self.body {
            BodyData::Fixed(ref bytes) => write!(f, "; body={}]", bytes.len()),
            BodyData::Stream(_) => f.write_str("; body=stream]"),
            BodyData::None => f.write_str("]"),
        }
    }
}
// ----
/// The `Resp<M>` represents a full response from an RPC that may have
/// an attached body stream.
pub struct Resp<M: Message> {
/// Marker tying this response to the message type `M` it answers.
pub(crate) _phantom: PhantomData<M>,
/// The response value itself.
pub(crate) msg: M::Response,
/// Body attached to the response (none, fixed bytes, or a stream).
pub(crate) body: BodyData,
}
impl<M: Message> Resp<M> {
pub fn new(v: M::Response) -> Self {
Resp {
_phantom: Default::default(),
msg: v,
body: BodyData::None,
}
}
pub fn with_fixed_body(self, b: Bytes) -> Self {
Self {
body: BodyData::Fixed(b),
..self
}
}
pub fn with_streaming_body(self, b: ByteStream) -> Self {
Self {
body: BodyData::Stream(b),
..self
}
}
pub fn msg(&self) -> &M::Response {
&self.msg
}
pub fn into_msg(self) -> M::Response {
self.msg
}
pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
Ok(RespEnc::Success {
msg: rmp_to_vec_all_named(&self.msg)?.into(),
stream: self.body.into_stream(),
})
}
pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
match enc {
RespEnc::Success { msg, stream } => {
let msg = rmp_serde::decode::from_read_ref(&msg)?;
Ok(Self {
_phantom: Default::default(),
msg,
body: stream.map(BodyData::Stream).unwrap_or(BodyData::None),
})
}
RespEnc::Error { code, message } => Err(Error::Remote(code, message)),
}
}
}
impl<M> fmt::Debug for Resp<M>
where
    M: Message,
    <M as Message>::Response: fmt::Debug,
{
    /// Formats the response as `Resp[<msg>]`, followed by `; body=<len>` for
    /// a fixed body or `; body=stream` for a streaming body before the
    /// closing bracket.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Resp[{:?}", self.msg)?;
        match self.body {
            BodyData::Fixed(ref bytes) => write!(f, "; body={}]", bytes.len()),
            BodyData::Stream(_) => f.write_str("; body=stream]"),
            BodyData::None => f.write_str("]"),
        }
    }
}
// ----
/// Body that can be attached to a `Req` or a `Resp`.
pub(crate) enum BodyData {
/// No body.
None,
/// A body fully available in memory.
Fixed(Bytes),
/// A body produced incrementally by a stream.
Stream(ByteStream),
}
impl BodyData {
    /// Converts the body into an optional stream.
    ///
    /// `None` yields no stream, a streaming body is returned as is, and a
    /// fixed body becomes a stream producing that buffer as its single item.
    pub fn into_stream(self) -> Option<ByteStream> {
        let stream: ByteStream = match self {
            BodyData::None => return None,
            BodyData::Stream(s) => s,
            BodyData::Fixed(b) => Box::pin(futures::stream::once(async move { Ok(b) })),
        };
        Some(stream)
    }
}
// ---- ----
/// Encoding for requests into a ByteStream:
///
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
///
/// Because the path and telemetry id lengths are stored in a single byte,
/// each of them can be at most 255 bytes long.
pub(crate) struct ReqEnc {
/// Priority of the request.
pub(crate) prio: RequestPriority,
/// Path of the request, as raw bytes.
pub(crate) path: Bytes,
/// Telemetry identifier, as raw bytes.
pub(crate) telemetry_id: Bytes,
/// Serialized message.
pub(crate) msg: Bytes,
/// Stream attached after the header, if any.
pub(crate) stream: Option<ByteStream>,
}
impl ReqEnc {
pub(crate) fn encode(self) -> ByteStream {
let mut buf = BytesMut::with_capacity(64);
buf.put_u8(self.prio);
buf.put_u8(self.path.len() as u8);
buf.put(self.path);
buf.put_u8(self.telemetry_id.len() as u8);
buf.put(&self.telemetry_id[..]);
buf.put_u32(self.msg.len() as u32);
buf.put(&self.msg[..]);
let header = buf.freeze();
if let Some(stream) = self.stream {
Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
} else {
Box::pin(futures::stream::once(async move { Ok(header) }))
}
}
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
Self::decode_aux(stream).await.map_err(|_| Error::Framing)
}
pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
let mut reader = ByteStreamReader::new(stream);
let prio = reader.read_u8().await?;
let path_len = reader.read_u8().await?;
let path = reader.read_exact(path_len as usize).await?;
let telemetry_id_len = reader.read_u8().await?;
let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
let msg_len = reader.read_u32().await?;
let msg = reader.read_exact(msg_len as usize).await?;
Ok(Self {
prio,
path,
telemetry_id,
msg,
stream: Some(reader.into_stream()),
})
}
}
/// Encoding for responses into a ByteStream:
///
/// IF SUCCESS:
/// - 0: u8
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
///
/// IF ERROR:
/// - message length + 1: u8
/// - error code: u8
/// - message: [u8; message_length]
///
/// The first byte therefore tells the two cases apart: it is zero for a
/// success and non-zero for an error. Since it must fit in a `u8`, an error
/// message can be at most 254 bytes long on the wire.
pub(crate) enum RespEnc {
/// The request failed.
Error {
/// Error code.
code: u8,
/// Human-readable error message.
message: String,
},
/// The request succeeded.
Success {
/// Serialized response value.
msg: Bytes,
/// Stream attached after the header, if any.
stream: Option<ByteStream>,
},
}
impl RespEnc {
    /// Longest error message, in bytes, that fits the wire format: the first
    /// byte of an error response holds `message length + 1` as a `u8`.
    const MAX_ERROR_MESSAGE_LEN: usize = u8::MAX as usize - 1;

    /// Builds an error response from an `Error`, using its code and its
    /// `Display` output as the message.
    pub(crate) fn from_err(e: Error) -> Self {
        RespEnc::Error {
            code: e.code(),
            message: e.to_string(),
        }
    }

    /// Returns the longest prefix of `message` that is at most
    /// `MAX_ERROR_MESSAGE_LEN` bytes long and ends on a UTF-8 character
    /// boundary.
    fn truncate_error_message(message: &str) -> &str {
        if message.len() <= Self::MAX_ERROR_MESSAGE_LEN {
            return message;
        }
        let mut end = Self::MAX_ERROR_MESSAGE_LEN;
        // Index 0 is always a character boundary, so this terminates.
        while !message.is_char_boundary(end) {
            end -= 1;
        }
        &message[..end]
    }

    /// Encodes the response as a byte stream.
    ///
    /// A success is a header chunk (`0`, message length, message) followed
    /// by the attached stream, if any. An error is a single chunk
    /// (`message length + 1`, error code, message); error messages longer
    /// than 254 bytes are truncated so that the length prefix stays
    /// consistent with the bytes actually written.
    pub(crate) fn encode(self) -> ByteStream {
        match self {
            RespEnc::Success { msg, stream } => {
                debug_assert!(
                    (msg.len() as u64) <= u64::from(u32::MAX),
                    "serialized response is {} bytes long, but its length must fit in a u32",
                    msg.len()
                );
                // marker (1) + message length (4) + message
                let mut buf = BytesMut::with_capacity(1 + 4 + msg.len());
                buf.put_u8(0);
                buf.put_u32(msg.len() as u32);
                buf.put_slice(&msg[..]);
                let header = buf.freeze();
                match stream {
                    Some(body) => {
                        Box::pin(futures::stream::once(async move { Ok(header) }).chain(body))
                    }
                    None => Box::pin(futures::stream::once(async move { Ok(header) })),
                }
            }
            RespEnc::Error { code, message } => {
                // Without truncation, a long message would make the length
                // prefix wrap around; a wrapped prefix of 0 would even be
                // decoded as a success by the peer.
                let message = Self::truncate_error_message(&message);
                // marker (1) + error code (1) + message
                let mut buf = BytesMut::with_capacity(2 + message.len());
                // `message.len() <= 254`, so the sum fits in a u8 and is never 0.
                buf.put_u8(1 + message.len() as u8);
                buf.put_u8(code);
                buf.put_slice(message.as_bytes());
                let header = buf.freeze();
                Box::pin(futures::stream::once(async move { Ok(header) }))
            }
        }
    }

    /// Decodes a response from a byte stream.
    ///
    /// # Errors
    ///
    /// Any failure to read the header is reported as `Error::Framing`.
    pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
        Self::decode_aux(stream).await.map_err(|_| Error::Framing)
    }

    /// Reads the first byte to tell success (`0`) from error (non-zero),
    /// then reads the corresponding fields.
    ///
    /// For a success, the returned `stream` is always `Some`, whether or not
    /// the sender attached a body.
    pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
        let mut reader = ByteStreamReader::new(stream);
        let is_err = reader.read_u8().await?;
        if is_err > 0 {
            let code = reader.read_u8().await?;
            // `is_err` holds the message length plus one.
            let message = reader.read_exact(is_err as usize - 1).await?;
            // Decode lossily rather than discarding the whole message: a peer
            // may have cut the message in the middle of a multi-byte character.
            let message = String::from_utf8_lossy(&message[..]).into_owned();
            Ok(RespEnc::Error { code, message })
        } else {
            let msg_len = reader.read_u32().await?;
            let msg = reader.read_exact(msg_len as usize).await?;
            Ok(RespEnc::Success {
                msg,
                stream: Some(reader.into_stream()),
            })
        }
    }
}