aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/client.rs58
-rw-r--r--src/endpoint.rs30
-rw-r--r--src/lib.rs1
-rw-r--r--src/message.rs357
-rw-r--r--src/recv.rs2
-rw-r--r--src/send.rs2
-rw-r--r--src/server.rs53
-rw-r--r--src/stream.rs176
-rw-r--r--src/util.rs15
9 files changed, 429 insertions, 265 deletions
diff --git a/src/client.rs b/src/client.rs
index c878627..42eeaa3 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
+use bytes::Bytes;
use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
@@ -28,6 +29,7 @@ use crate::message::*;
use crate::netapp::*;
use crate::recv::*;
use crate::send::*;
+use crate::stream::*;
use crate::util::*;
pub(crate) struct ClientConn {
@@ -155,24 +157,16 @@ impl ClientConn {
.with_kind(SpanKind::Client)
.start(&tracer);
let propagator = BinaryPropagator::new();
- let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
+ let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
} else {
- let telemetry_id: Option<Vec<u8>> = None;
+ let telemetry_id: Bytes = Bytes::new();
}
};
// Encode request
- let body = req.msg_ser.unwrap().clone();
- let stream = req.body.into_stream();
-
- let request = QueryMessage {
- prio,
- path: path.as_bytes(),
- telemetry_id,
- body: &body[..],
- };
- let bytes = request.encode();
- drop(body);
+ let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
+ let req_msg_len = req_enc.msg.len();
+ let req_stream = req_enc.encode();
// Send request through
let (resp_send, resp_recv) = oneshot::channel();
@@ -181,17 +175,19 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
- if old_ch.send(unbounded().1).is_err() {
- debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
- }
+ let _ = old_ch.send(unbounded().1);
}
- trace!("request: query_send {}, {} bytes", id, bytes.len());
+ trace!(
+ "request: query_send {} (serialized message: {} bytes)",
+ id,
+ req_msg_len
+ );
#[cfg(feature = "telemetry")]
- span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
- query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
+ query_send.send((id, prio, req_stream))?;
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
@@ -202,28 +198,10 @@ impl ClientConn {
let stream = resp_recv.await?;
}
}
- let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
- if resp.is_empty() {
- return Err(Error::Message(
- "Response is 0 bytes, either a collision or a protocol error".into(),
- ));
- }
-
- trace!("request response {}: ", id);
-
- let code = resp[0];
- if code == 0 {
- let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
- Ok(Resp {
- _phantom: Default::default(),
- msg: ser_resp,
- body: BodyData::Stream(stream),
- })
- } else {
- let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
- Err(Error::Remote(code, msg))
- }
+ let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
+ trace!("request response {}", id);
+ Resp::from_enc(resp_enc)
}
}
diff --git a/src/endpoint.rs b/src/endpoint.rs
index ff626d8..d8dc6c4 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -158,12 +158,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait]
pub(crate) trait GenericEndpoint {
- async fn handle(
- &self,
- buf: &[u8],
- stream: ByteStream,
- from: NodeID,
- ) -> Result<(Vec<u8>, Option<ByteStream>), Error>;
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
@@ -180,30 +175,13 @@ where
M: Message + 'static,
H: StreamingEndpointHandler<M> + 'static,
{
- async fn handle(
- &self,
- buf: &[u8],
- stream: ByteStream,
- from: NodeID,
- ) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
- let req = rmp_serde::decode::from_read_ref(buf)?;
- let req = Req {
- _phantom: Default::default(),
- msg: Arc::new(req),
- msg_ser: None,
- body: BodyData::Stream(stream),
- };
+ let req = Req::from_enc(req_enc)?;
let res = h.handle(req, from).await;
- let Resp {
- msg,
- body,
- _phantom,
- } = res;
- let res_bytes = rmp_to_vec_all_named(&msg)?;
- Ok((res_bytes, body.into_stream()))
+ Ok(res.into_enc()?)
}
}
}
diff --git a/src/lib.rs b/src/lib.rs
index 1edb919..ce94682 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,6 +14,7 @@
//! Also check out the examples to learn how to use this crate.
pub mod error;
+pub mod stream;
pub mod util;
pub mod endpoint;
diff --git a/src/message.rs b/src/message.rs
index 5721318..ba06551 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -2,12 +2,13 @@ use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
-use bytes::Bytes;
+use bytes::{BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
-use futures::stream::{Stream, StreamExt};
+use futures::stream::StreamExt;
use crate::error::*;
+use crate::stream::*;
use crate::util::*;
/// Priority of a request (click to read more about priorities).
@@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
+// ----
+
+/// The Req<M> is a helper object used to create requests and attach them
+/// a streaming body. If the body is a fixed Bytes and not a ByteStream,
+/// Req<M> is cheaply clonable to allow the request to be sent to different
+/// peers (Clone will panic if the body is a ByteStream).
+///
+/// Internally, this is also used to encode and decode requests
+/// from/to byte streams to be sent over the network.
pub struct Req<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: Arc<M>,
@@ -52,30 +62,6 @@ pub struct Req<M: Message> {
pub(crate) body: BodyData,
}
-pub struct Resp<M: Message> {
- pub(crate) _phantom: PhantomData<M>,
- pub(crate) msg: M::Response,
- pub(crate) body: BodyData,
-}
-
-pub(crate) enum BodyData {
- None,
- Fixed(Bytes),
- Stream(ByteStream),
-}
-
-impl BodyData {
- pub fn into_stream(self) -> Option<ByteStream> {
- match self {
- BodyData::None => None,
- BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
- BodyData::Stream(s) => Some(s),
- }
- }
-}
-
-// ----
-
impl<M: Message> Req<M> {
pub fn msg(&self) -> &M {
&self.msg
@@ -94,6 +80,31 @@ impl<M: Message> Req<M> {
..self
}
}
+
+ pub(crate) fn into_enc(
+ self,
+ prio: RequestPriority,
+ path: Bytes,
+ telemetry_id: Bytes,
+ ) -> ReqEnc {
+ ReqEnc {
+ prio,
+ path,
+ telemetry_id,
+ msg: self.msg_ser.unwrap(),
+ stream: self.body.into_stream(),
+ }
+ }
+
+ pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
+ let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
+ Ok(Req {
+ _phantom: Default::default(),
+ msg: Arc::new(msg),
+ msg_ser: Some(enc.msg),
+ body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None),
+ })
+ }
}
pub trait IntoReq<M: Message> {
@@ -160,19 +171,14 @@ where
}
}
-impl<M> fmt::Debug for Resp<M>
-where
- M: Message,
- <M as Message>::Response: fmt::Debug,
-{
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
- write!(f, "Resp[{:?}", self.msg)?;
- match &self.body {
- BodyData::None => write!(f, "]"),
- BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
- BodyData::Stream(_) => write!(f, "; body=stream]"),
- }
- }
+// ----
+
+/// The Resp<M> represents a full response from a RPC that may have
+/// an attached body stream.
+pub struct Resp<M: Message> {
+ pub(crate) _phantom: PhantomData<M>,
+ pub(crate) msg: M::Response,
+ pub(crate) body: BodyData,
}
impl<M: Message> Resp<M> {
@@ -205,160 +211,213 @@ impl<M: Message> Resp<M> {
pub fn into_msg(self) -> M::Response {
self.msg
}
+
+ pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
+ Ok(RespEnc::Success {
+ msg: rmp_to_vec_all_named(&self.msg)?.into(),
+ stream: self.body.into_stream(),
+ })
+ }
+
+ pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
+ match enc {
+ RespEnc::Success { msg, stream } => {
+ let msg = rmp_serde::decode::from_read_ref(&msg)?;
+ Ok(Self {
+ _phantom: Default::default(),
+ msg,
+ body: stream.map(BodyData::Stream).unwrap_or(BodyData::None),
+ })
+ }
+ RespEnc::Error { code, message } => Err(Error::Remote(code, message)),
+ }
+ }
}
-// ---- ----
+impl<M> fmt::Debug for Resp<M>
+where
+ M: Message,
+ <M as Message>::Response: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Resp[{:?}", self.msg)?;
+ match &self.body {
+ BodyData::None => write!(f, "]"),
+ BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
+ BodyData::Stream(_) => write!(f, "; body=stream]"),
+ }
+ }
+}
-pub(crate) struct QueryMessage<'a> {
- pub(crate) prio: RequestPriority,
- pub(crate) path: &'a [u8],
- pub(crate) telemetry_id: Option<Vec<u8>>,
- pub(crate) body: &'a [u8],
+// ----
+
+pub(crate) enum BodyData {
+ None,
+ Fixed(Bytes),
+ Stream(ByteStream),
+}
+
+impl BodyData {
+ pub fn into_stream(self) -> Option<ByteStream> {
+ match self {
+ BodyData::None => None,
+ BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
+ BodyData::Stream(s) => Some(s),
+ }
+ }
}
-/// QueryMessage encoding:
+// ---- ----
+
+/// Encoding for requests into a ByteStream:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
-/// - body [u8; ..]
-impl<'a> QueryMessage<'a> {
- pub(crate) fn encode(self) -> Vec<u8> {
- let tel_len = match &self.telemetry_id {
- Some(t) => t.len(),
- None => 0,
- };
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+pub(crate) struct ReqEnc {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: Bytes,
+ pub(crate) telemetry_id: Bytes,
+ pub(crate) msg: Bytes,
+ pub(crate) stream: Option<ByteStream>,
+}
+
+impl ReqEnc {
+ pub(crate) fn encode(self) -> ByteStream {
+ let mut buf = BytesMut::with_capacity(64);
+
+ buf.put_u8(self.prio);
- let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
+ buf.put_u8(self.path.len() as u8);
+ buf.put(self.path);
- ret.push(self.prio);
+ buf.put_u8(self.telemetry_id.len() as u8);
+ buf.put(&self.telemetry_id[..]);
- ret.push(self.path.len() as u8);
- ret.extend_from_slice(self.path);
+ buf.put_u32(self.msg.len() as u32);
+ buf.put(&self.msg[..]);
- if let Some(t) = self.telemetry_id {
- ret.push(t.len() as u8);
- ret.extend(t);
+ let header = buf.freeze();
+
+ if let Some(stream) = self.stream {
+ Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
} else {
- ret.push(0u8);
+ Box::pin(futures::stream::once(async move { Ok(header) }))
}
+ }
- ret.extend_from_slice(self.body);
-
- ret
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream).await.map_err(|_| Error::Framing)
}
- pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
- if bytes.len() < 3 {
- return Err(Error::Message("Invalid protocol message".into()));
- }
+ pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
- let path_length = bytes[1] as usize;
- if bytes.len() < 3 + path_length {
- return Err(Error::Message("Invalid protocol message".into()));
- }
+ let prio = reader.read_u8().await?;
- let telemetry_id_len = bytes[2 + path_length] as usize;
- if bytes.len() < 3 + path_length + telemetry_id_len {
- return Err(Error::Message("Invalid protocol message".into()));
- }
+ let path_len = reader.read_u8().await?;
+ let path = reader.read_exact(path_len as usize).await?;
- let path = &bytes[2..2 + path_length];
- let telemetry_id = if telemetry_id_len > 0 {
- Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
- } else {
- None
- };
+ let telemetry_id_len = reader.read_u8().await?;
+ let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
- let body = &bytes[3 + path_length + telemetry_id_len..];
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
Ok(Self {
- prio: bytes[0],
+ prio,
path,
telemetry_id,
- body,
+ msg,
+ stream: Some(reader.into_stream()),
})
}
}
-// ---- ----
-
-pub(crate) struct Framing {
- direct: Vec<u8>,
- stream: Option<ByteStream>,
+/// Encoding for responses into a ByteStream:
+/// IF SUCCESS:
+/// - 0: u8
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+/// IF ERROR:
+/// - message length + 1: u8
+/// - error code: u8
+/// - message: [u8; message_length]
+pub(crate) enum RespEnc {
+ Error {
+ code: u8,
+ message: String,
+ },
+ Success {
+ msg: Bytes,
+ stream: Option<ByteStream>,
+ },
}
-impl Framing {
- pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
- assert!(direct.len() <= u32::MAX as usize);
- Framing { direct, stream }
+impl RespEnc {
+ pub(crate) fn from_err(e: Error) -> Self {
+ RespEnc::Error {
+ code: e.code(),
+ message: format!("{}", e),
+ }
}
- pub fn into_stream(self) -> ByteStream {
- use futures::stream;
- let len = self.direct.len() as u32;
- // required because otherwise the borrow-checker complains
- let Framing { direct, stream } = self;
+ pub(crate) fn encode(self) -> ByteStream {
+ match self {
+ RespEnc::Success { msg, stream } => {
+ let mut buf = BytesMut::with_capacity(64);
- let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) })
- .chain(stream::once(async move { Ok(direct.into()) }));
+ buf.put_u8(0);
- if let Some(stream) = stream {
- Box::pin(res.chain(stream))
- } else {
- Box::pin(res)
- }
- }
+ buf.put_u32(msg.len() as u32);
+ buf.put(&msg[..]);
- pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>(
- mut stream: S,
- ) -> Result<Self, Error> {
- let mut packet = stream
- .next()
- .await
- .ok_or(Error::Framing)?
- .map_err(|_| Error::Framing)?;
- if packet.len() < 4 {
- return Err(Error::Framing);
+ let header = buf.freeze();
+
+ if let Some(stream) = stream {
+ Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
+ } else {
+ Box::pin(futures::stream::once(async move { Ok(header) }))
+ }
+ }
+ RespEnc::Error { code, message } => {
+ let mut buf = BytesMut::with_capacity(64);
+ buf.put_u8(1 + message.len() as u8);
+ buf.put_u8(code);
+ buf.put(message.as_bytes());
+ let header = buf.freeze();
+ Box::pin(futures::stream::once(async move { Ok(header) }))
+ }
}
+ }
- let mut len = [0; 4];
- len.copy_from_slice(&packet[..4]);
- let len = u32::from_be_bytes(len);
- packet = packet.slice(4..);
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream).await.map_err(|_| Error::Framing)
+ }
- let mut buffer = Vec::new();
- let len = len as usize;
- loop {
- let max_cp = std::cmp::min(len - buffer.len(), packet.len());
+ pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
- buffer.extend_from_slice(&packet[..max_cp]);
- if buffer.len() == len {
- packet = packet.slice(max_cp..);
- break;
- }
- packet = stream
- .next()
- .await
- .ok_or(Error::Framing)?
- .map_err(|_| Error::Framing)?;
- }
+ let is_err = reader.read_u8().await?;
- let stream: ByteStream = if packet.is_empty() {
- Box::pin(stream)
+ if is_err > 0 {
+ let code = reader.read_u8().await?;
+ let message = reader.read_exact(is_err as usize - 1).await?;
+ let message = String::from_utf8(message.to_vec()).unwrap_or_default();
+ Ok(RespEnc::Error { code, message })
} else {
- Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
- };
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
- Ok(Framing {
- direct: buffer,
- stream: Some(stream),
- })
- }
-
- pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
- let Framing { direct, stream } = self;
- (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
+ Ok(RespEnc::Success {
+ msg,
+ stream: Some(reader.into_stream()),
+ })
+ }
}
}
diff --git a/src/recv.rs b/src/recv.rs
index abe7b9a..19288f2 100644
--- a/src/recv.rs
+++ b/src/recv.rs
@@ -10,7 +10,7 @@ use futures::AsyncReadExt;
use crate::error::*;
use crate::send::*;
-use crate::util::Packet;
+use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
diff --git a/src/send.rs b/src/send.rs
index cc28d7c..59805cf 100644
--- a/src/send.rs
+++ b/src/send.rs
@@ -14,7 +14,7 @@ use tokio::sync::mpsc;
use crate::error::*;
use crate::message::*;
-use crate::util::{ByteStream, Packet};
+use crate::stream::*;
// Messages are sent by chunks
// Chunk format:
diff --git a/src/server.rs b/src/server.rs
index 1f1c22a..ae1196c 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -28,6 +28,7 @@ use crate::message::*;
use crate::netapp::*;
use crate::recv::*;
use crate::send::*;
+use crate::stream::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@@ -121,17 +122,12 @@ impl ServerConn {
Ok(())
}
- async fn recv_handler_aux(
- self: &Arc<Self>,
- bytes: &[u8],
- stream: ByteStream,
- ) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
- let msg = QueryMessage::decode(bytes)?;
- let path = String::from_utf8(msg.path.to_vec())?;
+ async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
+ let path = String::from_utf8(req_enc.path.to_vec())?;
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
- endpoints.get(&path).map(|e| e.clone_endpoint())
+ endpoints.get(&path[..]).map(|e| e.clone_endpoint())
};
if let Some(handler) = handler_opt {
@@ -139,9 +135,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp");
- let mut span = if let Some(telemetry_id) = msg.telemetry_id {
+ let mut span = if !req_enc.telemetry_id.is_empty() {
let propagator = BinaryPropagator::new();
- let context = propagator.from_bytes(telemetry_id);
+ let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server)
@@ -156,13 +152,13 @@ impl ServerConn {
.start(&tracer)
};
span.set_attribute(KeyValue::new("path", path.to_string()));
- span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
- handler.handle(msg.body, stream, self.peer_id)
+ handler.handle(req_enc, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
- handler.handle(msg.body, stream, self.peer_id).await
+ handler.handle(req_enc, self.peer_id).await
}
}
} else {
@@ -181,32 +177,23 @@ impl RecvLoop for ServerConn {
let self2 = self.clone();
tokio::spawn(async move {
trace!("ServerConn recv_handler {}", id);
- let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
-
- let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
- let resp = self2.recv_handler_aux(&bytes[..], stream).await;
-
- let (resp_bytes, resp_stream) = match resp {
- Ok((rb, rs)) => {
- let mut resp_bytes = vec![0u8];
- resp_bytes.extend(rb);
- (resp_bytes, rs)
- }
- Err(e) => {
- let mut resp_bytes = vec![e.code()];
- resp_bytes.extend(e.to_string().into_bytes());
- (resp_bytes, None)
+ let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
+ Ok(req_enc) => {
+ let prio = req_enc.prio;
+ let resp = self2.recv_handler_aux(req_enc).await;
+
+ (prio, match resp {
+ Ok(resp_enc) => resp_enc,
+ Err(e) => RespEnc::from_err(e),
+ })
}
+ Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
};
trace!("ServerConn sending response to {}: ", id);
resp_send
- .send((
- id,
- prio,
- Framing::new(resp_bytes, resp_stream).into_stream(),
- ))
+ .send((id, prio, resp_enc.encode()))
.log_err("ServerConn recv_handler send resp bytes");
Ok::<_, Error>(())
});
diff --git a/src/stream.rs b/src/stream.rs
new file mode 100644
index 0000000..6c23f4a
--- /dev/null
+++ b/src/stream.rs
@@ -0,0 +1,176 @@
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::Bytes;
+
+use futures::Future;
+use futures::{Stream, StreamExt};
+
+/// A stream of associated data.
+///
+/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
+/// consecutive Vec may get merged, but Vec and error code may not be reordered
+///
+/// Error code 255 means the stream was cut before its end. Other codes have no predefined
+/// meaning, it's up to your application to define their semantic.
+pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
+
+pub type Packet = Result<Bytes, u8>;
+
+pub struct ByteStreamReader {
+ stream: ByteStream,
+ buf: VecDeque<Bytes>,
+ buf_len: usize,
+ eos: bool,
+ err: Option<u8>,
+}
+
+impl ByteStreamReader {
+ pub fn new(stream: ByteStream) -> Self {
+ ByteStreamReader {
+ stream,
+ buf: VecDeque::with_capacity(8),
+ buf_len: 0,
+ eos: false,
+ err: None,
+ }
+ }
+
+ pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: true,
+ }
+ }
+
+ pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: false,
+ }
+ }
+
+ pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
+ Ok(self.read_exact(1).await?[0])
+ }
+
+ pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
+ let bytes = self.read_exact(2).await?;
+ let mut b = [0u8; 2];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u16::from_be_bytes(b))
+ }
+
+ pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
+ let bytes = self.read_exact(4).await?;
+ let mut b = [0u8; 4];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u32::from_be_bytes(b))
+ }
+
+ pub fn into_stream(self) -> ByteStream {
+ let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
+ if let Some(err) = self.err {
+ Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
+ } else if self.eos {
+ Box::pin(buf_stream)
+ } else {
+ Box::pin(buf_stream.chain(self.stream))
+ }
+ }
+
+ fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
+ if self.buf_len >= read_len {
+ let mut slices = Vec::with_capacity(self.buf.len());
+ let mut taken = 0;
+ while taken < read_len {
+ let front = self.buf.pop_front().unwrap();
+ if taken + front.len() <= read_len {
+ taken += front.len();
+ self.buf_len -= front.len();
+ slices.push(front);
+ } else {
+ let front_take = read_len - taken;
+ slices.push(front.slice(..front_take));
+ self.buf.push_front(front.slice(front_take..));
+ self.buf_len -= front_take;
+ break;
+ }
+ }
+ Some(
+ slices
+ .iter()
+ .map(|x| &x[..])
+ .collect::<Vec<_>>()
+ .concat()
+ .into(),
+ )
+ } else {
+ None
+ }
+ }
+}
+
+pub enum ReadExactError {
+ UnexpectedEos,
+ Stream(u8),
+}
+
+#[pin_project::pin_project]
+pub struct ByteStreamReadExact<'a> {
+ #[pin]
+ reader: &'a mut ByteStreamReader,
+ read_len: usize,
+ fail_on_eos: bool,
+}
+
+impl<'a> Future for ByteStreamReadExact<'a> {
+ type Output = Result<Bytes, ReadExactError>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
+ let mut this = self.project();
+
+ loop {
+ if let Some(bytes) = this.reader.try_get(*this.read_len) {
+ return Poll::Ready(Ok(bytes));
+ }
+ if let Some(err) = this.reader.err {
+ return Poll::Ready(Err(ReadExactError::Stream(err)));
+ }
+ if this.reader.eos {
+ if *this.fail_on_eos {
+ return Poll::Ready(Err(ReadExactError::UnexpectedEos));
+ } else {
+ let bytes = Bytes::from(
+ this.reader
+ .buf
+ .iter()
+ .map(|x| &x[..])
+ .collect::<Vec<_>>()
+ .concat(),
+ );
+ this.reader.buf.clear();
+ this.reader.buf_len = 0;
+ return Poll::Ready(Ok(bytes));
+ }
+ }
+
+ match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) {
+ Some(Ok(slice)) => {
+ this.reader.buf_len += slice.len();
+ this.reader.buf.push_back(slice);
+ }
+ Some(Err(e)) => {
+ this.reader.err = Some(e);
+ this.reader.eos = true;
+ }
+ None => {
+ this.reader.eos = true;
+ }
+ }
+ }
+ }
+}
diff --git a/src/util.rs b/src/util.rs
index 01c392c..13cccb9 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,12 +1,9 @@
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
-use std::pin::Pin;
-use bytes::Bytes;
use log::info;
use serde::Serialize;
-use futures::Stream;
use tokio::sync::watch;
/// A node's identifier, which is also its public cryptographic key
@@ -16,18 +13,6 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
/// A network key
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
-/// A stream of associated data.
-///
-/// The Stream can continue after receiving an error.
-/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
-/// consecutive Vec may get merged, but Vec and error code may not be reordered
-///
-/// Error code 255 means the stream was cut before its end. Other codes have no predefined
-/// meaning, it's up to your application to define their semantic.
-pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
-
-pub type Packet = Result<Bytes, u8>;
-
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.
///