diff options
-rw-r--r-- | src/client.rs | 58 | ||||
-rw-r--r-- | src/endpoint.rs | 30 | ||||
-rw-r--r-- | src/lib.rs | 1 | ||||
-rw-r--r-- | src/message.rs | 357 | ||||
-rw-r--r-- | src/recv.rs | 2 | ||||
-rw-r--r-- | src/send.rs | 2 | ||||
-rw-r--r-- | src/server.rs | 53 | ||||
-rw-r--r-- | src/stream.rs | 176 | ||||
-rw-r--r-- | src/util.rs | 15 |
9 files changed, 429 insertions, 265 deletions
diff --git a/src/client.rs b/src/client.rs index c878627..42eeaa3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use async_trait::async_trait; +use bytes::Bytes; use log::{debug, error, trace}; use futures::channel::mpsc::{unbounded, UnboundedReceiver}; @@ -28,6 +29,7 @@ use crate::message::*; use crate::netapp::*; use crate::recv::*; use crate::send::*; +use crate::stream::*; use crate::util::*; pub(crate) struct ClientConn { @@ -155,24 +157,16 @@ impl ClientConn { .with_kind(SpanKind::Client) .start(&tracer); let propagator = BinaryPropagator::new(); - let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec()); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); } else { - let telemetry_id: Option<Vec<u8>> = None; + let telemetry_id: Bytes = Bytes::new(); } }; // Encode request - let body = req.msg_ser.unwrap().clone(); - let stream = req.body.into_stream(); - - let request = QueryMessage { - prio, - path: path.as_bytes(), - telemetry_id, - body: &body[..], - }; - let bytes = request.encode(); - drop(body); + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let req_stream = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -181,17 +175,19 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(unbounded().1).is_err() { - debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); - } + let _ = old_ch.send(unbounded().1); } - trace!("request: query_send {}, {} bytes", id, bytes.len()); + trace!( + "request: query_send {} (serialized message: {} bytes)", + id, + req_msg_len + ); #[cfg(feature = "telemetry")] - span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; + query_send.send((id, prio, req_stream))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -202,28 +198,10 @@ impl ClientConn { let stream = resp_recv.await?; } } - let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); - if resp.is_empty() { - return Err(Error::Message( - "Response is 0 bytes, either a collision or a protocol error".into(), - )); - } - - trace!("request response {}: ", id); - - let code = resp[0]; - if code == 0 { - let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(Resp { - _phantom: Default::default(), - msg: ser_resp, - body: BodyData::Stream(stream), - }) - } else { - let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); - Err(Error::Remote(code, msg)) - } + let resp_enc = RespEnc::decode(Box::pin(stream)).await?; + trace!("request response {}", id); + Resp::from_enc(resp_enc) } } diff --git a/src/endpoint.rs b/src/endpoint.rs index ff626d8..d8dc6c4 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -158,12 +158,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle( - &self, - buf: &[u8], - stream: ByteStream, - from: NodeID, - ) -> Result<(Vec<u8>, Option<ByteStream>), Error>; + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -180,30 +175,13 @@ where M: Message + 'static, H: StreamingEndpointHandler<M> + 'static, { - async fn handle( - &self, - buf: &[u8], - stream: ByteStream, - from: NodeID, - ) -> Result<(Vec<u8>, Option<ByteStream>), Error> { + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref(buf)?; - let req = Req { - _phantom: Default::default(), - msg: Arc::new(req), - msg_ser: None, - body: BodyData::Stream(stream), - }; + let req = Req::from_enc(req_enc)?; let res = h.handle(req, from).await; - let Resp { - msg, - body, - _phantom, - } = res; - let res_bytes = rmp_to_vec_all_named(&msg)?; - Ok((res_bytes, body.into_stream())) + Ok(res.into_enc()?) } } } @@ -14,6 +14,7 @@ //! Also check out the examples to learn how to use this crate. pub mod error; +pub mod stream; pub mod util; pub mod endpoint; diff --git a/src/message.rs b/src/message.rs index 5721318..ba06551 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,12 +2,13 @@ use std::fmt; use std::marker::PhantomData; use std::sync::Arc; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use serde::{Deserialize, Serialize}; -use futures::stream::{Stream, StreamExt}; +use futures::stream::StreamExt; use crate::error::*; +use crate::stream::*; use crate::util::*; /// Priority of a request (click to read more about priorities). @@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; } +// ---- + +/// The Req<M> is a helper object used to create requests and attach them +/// a streaming body. If the body is a fixed Bytes and not a ByteStream, +/// Req<M> is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the body is a ByteStream). +/// +/// Internally, this is also used to encode and decode requests +/// from/to byte streams to be sent over the network. pub struct Req<M: Message> { pub(crate) _phantom: PhantomData<M>, pub(crate) msg: Arc<M>, @@ -52,30 +62,6 @@ pub struct Req<M: Message> { pub(crate) body: BodyData, } -pub struct Resp<M: Message> { - pub(crate) _phantom: PhantomData<M>, - pub(crate) msg: M::Response, - pub(crate) body: BodyData, -} - -pub(crate) enum BodyData { - None, - Fixed(Bytes), - Stream(ByteStream), -} - -impl BodyData { - pub fn into_stream(self) -> Option<ByteStream> { - match self { - BodyData::None => None, - BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), - BodyData::Stream(s) => Some(s), - } - } -} - -// ---- - impl<M: Message> Req<M> { pub fn msg(&self) -> &M { &self.msg @@ -94,6 +80,31 @@ impl<M: Message> Req<M> { ..self } } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.body.into_stream(), + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } } pub trait IntoReq<M: Message> { @@ -160,19 +171,14 @@ where } } -impl<M> fmt::Debug for Resp<M> -where - M: Message, - <M as Message>::Response: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(f, "Resp[{:?}", self.msg)?; - match &self.body { - BodyData::None => write!(f, "]"), - BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), - BodyData::Stream(_) => write!(f, "; body=stream]"), - } - } +// ---- + +/// The Resp<M> represents a full response from a RPC that may have +/// an attached body stream. +pub struct Resp<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, } impl<M: Message> Resp<M> { @@ -205,160 +211,213 @@ impl<M: Message> Resp<M> { pub fn into_msg(self) -> M::Response { self.msg } + + pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> { + Ok(RespEnc::Success { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.body.into_stream(), + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> { + match enc { + RespEnc::Success { msg, stream } => { + let msg = rmp_serde::decode::from_read_ref(&msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + body: stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } + RespEnc::Error { code, message } => Err(Error::Remote(code, message)), + } + } } -// ---- ---- +impl<M> fmt::Debug for Resp<M> +where + M: Message, + <M as Message>::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option<Vec<u8>>, - pub(crate) body: &'a [u8], +// ---- + +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl BodyData { + pub fn into_stream(self) -> Option<ByteStream> { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } + } } -/// QueryMessage encoding: +// ---- ---- + +/// Encoding for requests into a ByteStream: /// - priority: u8 /// - path length: u8 /// - path: [u8; path length] /// - telemetry id length: u8 /// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec<u8> { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option<ByteStream>, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> ByteStream { + let mut buf = BytesMut::with_capacity(64); + + buf.put_u8(self.prio); - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + buf.put_u8(self.path.len() as u8); + buf.put(self.path); - ret.push(self.prio); + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); + buf.put_u32(self.msg.len() as u32); + buf.put(&self.msg[..]); - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); + let header = buf.freeze(); + + if let Some(stream) = self.stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) } else { - ret.push(0u8); + Box::pin(futures::stream::once(async move { Ok(header) })) } + } - ret.extend_from_slice(self.body); - - ret + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) } - pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } + pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } + let prio = reader.read_u8().await?; - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; - let body = &bytes[3 + path_length + telemetry_id_len..]; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; Ok(Self { - prio: bytes[0], + prio, path, telemetry_id, - body, + msg, + stream: Some(reader.into_stream()), }) } } -// ---- ---- - -pub(crate) struct Framing { - direct: Vec<u8>, - stream: Option<ByteStream>, +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) enum RespEnc { + Error { + code: u8, + message: String, + }, + Success { + msg: Bytes, + stream: Option<ByteStream>, + }, } -impl Framing { - pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } +impl RespEnc { + pub(crate) fn from_err(e: Error) -> Self { + RespEnc::Error { + code: e.code(), + message: format!("{}", e), + } } - pub fn into_stream(self) -> ByteStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; + pub(crate) fn encode(self) -> ByteStream { + match self { + RespEnc::Success { msg, stream } => { + let mut buf = BytesMut::with_capacity(64); - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) }) - .chain(stream::once(async move { Ok(direct.into()) })); + buf.put_u8(0); - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } + buf.put_u32(msg.len() as u32); + buf.put(&msg[..]); - pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>( - mut stream: S, - ) -> Result<Self, Error> { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); + let header = buf.freeze(); + + if let Some(stream) = stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) + } else { + Box::pin(futures::stream::once(async move { Ok(header) })) + } + } + RespEnc::Error { code, message } => { + let mut buf = BytesMut::with_capacity(64); + buf.put_u8(1 + message.len() as u8); + buf.put_u8(code); + buf.put(message.as_bytes()); + let header = buf.freeze(); + Box::pin(futures::stream::once(async move { Ok(header) })) + } } + } - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet = packet.slice(4..); + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) + } - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet = packet.slice(max_cp..); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } + let is_err = reader.read_u8().await?; - let stream: ByteStream = if packet.is_empty() { - Box::pin(stream) + if is_err > 0 { + let code = reader.read_u8().await?; + let message = reader.read_exact(is_err as usize - 1).await?; + let message = String::from_utf8(message.to_vec()).unwrap_or_default(); + Ok(RespEnc::Error { code, message }) } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec<u8>, ByteStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + Ok(RespEnc::Success { + msg, + stream: Some(reader.into_stream()), + }) + } } } diff --git a/src/recv.rs b/src/recv.rs index abe7b9a..19288f2 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -10,7 +10,7 @@ use futures::AsyncReadExt; use crate::error::*; use crate::send::*; -use crate::util::Packet; +use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data diff --git a/src/send.rs b/src/send.rs index cc28d7c..59805cf 100644 --- a/src/send.rs +++ b/src/send.rs @@ -14,7 +14,7 @@ use tokio::sync::mpsc; use crate::error::*; use crate::message::*; -use crate::util::{ByteStream, Packet}; +use crate::stream::*; // Messages are sent by chunks // Chunk format: diff --git a/src/server.rs b/src/server.rs index 1f1c22a..ae1196c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -28,6 +28,7 @@ use crate::message::*; use crate::netapp::*; use crate::recv::*; use crate::send::*; +use crate::stream::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -121,17 +122,12 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux( - self: &Arc<Self>, - bytes: &[u8], - stream: ByteStream, - ) -> Result<(Vec<u8>, Option<ByteStream>), Error> { - let msg = QueryMessage::decode(bytes)?; - let path = String::from_utf8(msg.path.to_vec())?; + async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> { + let path = String::from_utf8(req_enc.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); - endpoints.get(&path).map(|e| e.clone_endpoint()) + endpoints.get(&path[..]).map(|e| e.clone_endpoint()) }; if let Some(handler) = handler_opt { @@ -139,9 +135,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if let Some(telemetry_id) = msg.telemetry_id { + let mut span = if !req_enc.telemetry_id.is_empty() { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(telemetry_id); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -156,13 +152,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); - handler.handle(msg.body, stream, self.peer_id) + handler.handle(req_enc, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, stream, self.peer_id).await + handler.handle(req_enc, self.peer_id).await } } } else { @@ -181,32 +177,23 @@ impl RecvLoop for ServerConn { let self2 = self.clone(); tokio::spawn(async move { trace!("ServerConn recv_handler {}", id); - let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); - - let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..], stream).await; - - let (resp_bytes, resp_stream) = match resp { - Ok((rb, rs)) => { - let mut resp_bytes = vec![0u8]; - resp_bytes.extend(rb); - (resp_bytes, rs) - } - Err(e) => { - let mut resp_bytes = vec![e.code()]; - resp_bytes.extend(e.to_string().into_bytes()); - (resp_bytes, None) + let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await { + Ok(req_enc) => { + let prio = req_enc.prio; + let resp = self2.recv_handler_aux(req_enc).await; + + (prio, match resp { + Ok(resp_enc) => resp_enc, + Err(e) => RespEnc::from_err(e), + }) } + Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), }; trace!("ServerConn sending response to {}: ", id); resp_send - .send(( - id, - prio, - Framing::new(resp_bytes, resp_stream).into_stream(), - )) + .send((id, prio, resp_enc.encode())) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..6c23f4a --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,176 @@ +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; + +/// A stream of associated data. +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; + +pub type Packet = Result<Bytes, u8>; + +pub struct ByteStreamReader { + stream: ByteStream, + buf: VecDeque<Bytes>, + buf_len: usize, + eos: bool, + err: Option<u8>, +} + +impl ByteStreamReader { + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: VecDeque::with_capacity(8), + buf_len: 0, + eos: false, + err: None, + } + } + + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> { + Ok(self.read_exact(1).await?[0]) + } + + pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + fn try_get(&mut self, read_len: usize) -> Option<Bytes> { + if self.buf_len >= read_len { + let mut slices = Vec::with_capacity(self.buf.len()); + let mut taken = 0; + while taken < read_len { + let front = self.buf.pop_front().unwrap(); + if taken + front.len() <= read_len { + taken += front.len(); + self.buf_len -= front.len(); + slices.push(front); + } else { + let front_take = read_len - taken; + slices.push(front.slice(..front_take)); + self.buf.push_front(front.slice(front_take..)); + self.buf_len -= front_take; + break; + } + } + Some( + slices + .iter() + .map(|x| &x[..]) + .collect::<Vec<_>>() + .concat() + .into(), + ) + } else { + None + } + } +} + +pub enum ReadExactError { + UnexpectedEos, + Stream(u8), +} + +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result<Bytes, ReadExactError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = this.reader.err { + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + let bytes = Bytes::from( + this.reader + .buf + .iter() + .map(|x| &x[..]) + .collect::<Vec<_>>() + .concat(), + ); + this.reader.buf.clear(); + this.reader.buf_len = 0; + return Poll::Ready(Ok(bytes)); + } + } + + match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { + Some(Ok(slice)) => { + this.reader.buf_len += slice.len(); + this.reader.buf.push_back(slice); + } + Some(Err(e)) => { + this.reader.err = Some(e); + this.reader.eos = true; + } + None => { + this.reader.eos = true; + } + } + } + } +} diff --git a/src/util.rs b/src/util.rs index 01c392c..13cccb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,12 +1,9 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; -use std::pin::Pin; -use bytes::Bytes; use log::info; use serde::Serialize; -use futures::Stream; use tokio::sync::watch; /// A node's identifier, which is also its public cryptographic key @@ -16,18 +13,6 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; -/// A stream of associated data. -/// -/// The Stream can continue after receiving an error. -/// When sent through Netapp, the Vec may be split in smaller chunk in such a way -/// consecutive Vec may get merged, but Vec and error code may not be reordered -/// -/// Error code 255 means the stream was cut before its end. Other codes have no predefined -/// meaning, it's up to your application to define their semantic. -pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; - -pub type Packet = Result<Bytes, u8>; - /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. /// |