From 368ba908794901bc793c6a087c02241be046bdf2 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 15:33:43 +0200 Subject: initial work on associated stream still require testing, and fixing a few kinks: - sending packets > 16k truncate them - send one more packet than it could at eos - probably update documentation /!\ contains breaking changes --- src/client.rs | 37 +++++--- src/endpoint.rs | 66 ++++++++++++-- src/proto.rs | 260 ++++++++++++++++++++++++++++++++++++++++++++++---------- src/server.rs | 38 ++++++--- src/test.rs | 1 + src/util.rs | 17 ++-- 6 files changed, 338 insertions(+), 81 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 8227e8f..bce7aca 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,10 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption)>>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex, AssociatedStream)>>>, } impl ClientConn { @@ -148,9 +148,11 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; + // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(1, atomic::Ordering::Relaxed); + .fetch_add(2, atomic::Ordering::Relaxed); + let stream_id = id + 1; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -166,7 +168,7 @@ impl ClientConn { }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; + let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; drop(rq); let request = QueryMessage { @@ -185,7 +187,10 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { + if old_ch + .send((vec![], Box::pin(futures::stream::empty()))) + .is_err() + { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -195,15 +200,20 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, bytes))?; + query_send.send((id, prio, Data::Full(bytes)))?; + if let Some(stream) = stream { + query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; + } else { + query_send.send((stream_id, prio, Data::Full(Vec::new())))?; + } cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let (resp, stream) = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let (resp, stream) = resp_recv.await?; } } @@ -217,10 +227,9 @@ impl ClientConn { let code = resp[0]; if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - ::Response, - >(&resp[1..])?) + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); + let res = T::Response::deserialize_msg(&mut deser, stream).await?; + Ok(res) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) @@ -232,12 +241,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec) { + fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream) { trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send((msg, stream)).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..81ed036 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use serde::de::Error as DeError; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::error::Error; use crate::netapp::*; @@ -14,8 +15,50 @@ use crate::util::*; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + fn serialize_msg( + &self, + serializer: S, + ) -> Result<(S::Ok, Option), S::Error>; + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + stream: AssociatedStream, + ) -> Result; +} + +#[async_trait] +impl SerializeMessage for T +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + fn serialize_msg( + &self, + serializer: S, + ) -> Result<(S::Ok, Option), S::Error> { + self.serialize(serializer).map(|r| (r, None)) + } + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + mut stream: AssociatedStream, + ) -> Result { + use futures::StreamExt; + + let res = Self::deserialize(deserializer)?; + if stream.next().await.is_some() { + return Err(D::Error::custom( + "failed to deserialize: found associated stream when none expected", + )); + } + Ok(res) + } } /// This trait should be implemented by an object of your application @@ -128,7 +171,12 @@ pub(crate) type DynEndpoint = Box; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error>; + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -145,11 +193,17 @@ where M: Message + 'static, H: EndpointHandler + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error> { + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); + let req = M::deserialize_msg(&mut deser, stream).await?; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/proto.rs b/src/proto.rs index e843bff..b45ff13 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use log::trace; +use log::{trace, warn}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::Stream; +use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -11,6 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; +use crate::util::AssociatedStream; /// Priority of a request (click to read more about priorities). /// @@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: Vec, - cursor: usize, + data: DataReader, +} + +pub(crate) enum Data { + Full(Vec), + Streaming(AssociatedStream), +} + +#[pin_project::pin_project(project = DataReaderProj)] +enum DataReader { + Full { + #[pin] + data: Vec, + pos: usize, + }, + Streaming { + #[pin] + reader: AssociatedStream, + }, +} + +impl From for DataReader { + fn from(data: Data) -> DataReader { + match data { + Data::Full(data) => DataReader::Full { data, pos: 0 }, + Data::Streaming(reader) => DataReader::Streaming { reader }, + } + } +} + +impl Stream for DataReader { + type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + DataReaderProj::Full { data, pos } => { + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); + let end = *pos + len; + + if len == 0 { + Poll::Ready(None) + } else { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..len].copy_from_slice(&data[*pos..end]); + *pos = end; + Poll::Ready(Some((body, len))) + } + } + DataReaderProj::Streaming { reader } => { + reader.poll_next(cx).map(|opt| { + opt.map(|v| { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); + // TODO this can throw away long vec, they should be splited instead + body[..len].copy_from_slice(&v[..len]); + (body, len) + }) + }) + } + } + } } struct SendQueue { @@ -108,7 +172,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync { let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else if let Some(mut item) = sending.pop() { trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor + "send_loop: sending bytes for {}", + item.id, ); + + let data = futures::select! { + data = item.data.next().fuse() => data, + default => { + // nothing to send yet; re-schedule and find something else to do + sending.push(item); + continue; + + // TODO if every SendQueueItem is waiting on data, use select_all to await + // something to do + // TODO find some way to not require sending empty last chunk + } + }; + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let data = match data.as_ref() { + Some((data, len)) => &data[..*len], + None => &[], + }; + + if !data.is_empty() { let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); + ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; + write.write_all(data).await?; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); + // this is always zero for now, but may be more when above TODO get fixed + let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - write.write_all(&item.data[item.cursor..]).await?; + write.write_all(data).await?; } + write.flush().await?; } else { let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else { should_exit = true; @@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync { } } +struct ChannelPair { + receiver: Option>>, + sender: Option>>, +} + +impl ChannelPair { + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } +} + +impl Default for ChannelPair { + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec); + fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut receiving: HashMap> = HashMap::new(); + let mut streams: HashMap< + RequestID, + ChannelPair, + > = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + if id & 1 == 0 { + // main stream + let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); + msg_bytes.extend_from_slice(&next_slice[..]); - if has_cont { - receiving.insert(id, msg_bytes); + if has_cont { + receiving.insert(id, msg_bytes); + } else { + let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); + + if let Some(receiver) = channel_pair.take_receiver() { + self.recv_handler(id, msg_bytes, Box::pin(receiver)); + } else { + warn!("Couldn't take receiver part of stream") + } + + channel_pair.insert_into(&mut streams, id | 1); + } } else { - self.recv_handler(id, msg_bytes); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } + + if !has_cont { + channel_pair.take_sender(); + } + + channel_pair.insert_into(&mut streams, id); } } Ok(()) @@ -236,38 +396,50 @@ mod test { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 5465307..6cd4056 100644 --- a/src/server.rs +++ b/src/server.rs @@ -55,7 +55,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption)>>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -123,7 +123,11 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc, bytes: &[u8]) -> Result, Error> { + async fn recv_handler_aux( + self: &Arc, + bytes: &[u8], + stream: AssociatedStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; @@ -156,11 +160,11 @@ impl ServerConn { span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(msg.body, stream, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(msg.body, stream, self.peer_id).await } } } else { @@ -173,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec) { + fn recv_handler(self: &Arc, id: RequestID, bytes: Vec, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); @@ -182,26 +186,36 @@ impl RecvLoop for ServerConn { let bytes: Bytes = bytes.into(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; + let resp = self2.recv_handler_aux(&bytes[..], stream).await; - let resp_bytes = match resp { - Ok(rb) => { + let (resp_bytes, resp_stream) = match resp { + Ok((rb, rs)) => { let mut resp_bytes = vec![0u8]; resp_bytes.extend(rb); - resp_bytes + (resp_bytes, rs) } Err(e) => { let mut resp_bytes = vec![e.code()]; resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes + (resp_bytes, None) } }; trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send((id, prio, Data::Full(resp_bytes))) + .log_err("ServerConn recv_handler send resp bytes"); + + if let Some(resp_stream) = resp_stream { + resp_send + .send((id + 1, prio, Data::Streaming(resp_stream))) + .log_err("ServerConn recv_handler send resp stream"); + } else { + resp_send + .send((id + 1, prio, Data::Full(Vec::new()))) + .log_err("ServerConn recv_handler send resp stream"); + } }); } } diff --git a/src/test.rs b/src/test.rs index 82c7ba6..ecd5450 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use crate::NodeID; #[tokio::test(flavor = "current_thread")] async fn test_with_basic_scheduler() { + pretty_env_logger::init(); run_test().await } diff --git a/src/util.rs b/src/util.rs index f4dfac7..4333080 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,10 @@ +use crate::endpoint::SerializeMessage; + use std::net::SocketAddr; use std::net::ToSocketAddrs; +use std::pin::Pin; -use serde::Serialize; +use futures::Stream; use log::info; @@ -14,21 +17,25 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +pub type AssociatedStream = Pin> + Send>>; + /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. /// /// Field names and variant names are included in the serialization. /// This is used internally by the netapp communication protocol. -pub fn rmp_to_vec_all_named(val: &T) -> Result, rmp_serde::encode::Error> +pub fn rmp_to_vec_all_named( + val: &T, +) -> Result<(Vec, Option), rmp_serde::encode::Error> where - T: Serialize + ?Sized, + T: SerializeMessage + ?Sized, { let mut wr = Vec::with_capacity(128); let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - val.serialize(&mut se)?; - Ok(wr) + let (_, stream) = val.serialize_msg(&mut se)?; + Ok((wr, stream)) } /// This async function returns only when a true signal was received -- cgit v1.2.3 From fb5462ecdb6b5731a63a902519d3ec9b1061b8dd Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 16:47:29 +0200 Subject: rechunk stream --- src/proto.rs | 159 +++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 94 insertions(+), 65 deletions(-) (limited to 'src') diff --git a/src/proto.rs b/src/proto.rs index b45ff13..ca1a3d2 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -53,7 +53,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -77,6 +77,10 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, + packet: Vec, + pos: usize, + buf: Vec, + eos: bool, }, } @@ -84,7 +88,13 @@ impl From for DataReader { fn from(data: Data) -> DataReader { match data { Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { reader }, + Data::Streaming(reader) => DataReader::Streaming { + reader, + packet: Vec::new(), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + }, } } } @@ -107,16 +117,43 @@ impl Stream for DataReader { Poll::Ready(Some((body, len))) } } - DataReaderProj::Streaming { reader } => { - reader.poll_next(cx).map(|opt| { - opt.map(|v| { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); - // TODO this can throw away long vec, they should be splited instead - body[..len].copy_from_slice(&v[..len]); - (body, len) - }) - }) + DataReaderProj::Streaming { + mut reader, + packet, + pos, + buf, + eos, + } => { + if *eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + loop { + let packet_left = packet.len() - *pos; + let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + buf.extend_from_slice(&packet[*pos..*pos + to_read]); + *pos += to_read; + if buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { + *packet = p; + *pos = 0; + } else { + *eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..buf.len()].copy_from_slice(&buf); + buf.clear(); + Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) } } } @@ -196,10 +233,7 @@ pub(crate) trait SendLoop: Sync { data: data.into(), }); } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop: sending bytes for {}", - item.id, - ); + trace!("send_loop: sending bytes for {}", item.id,); let data = futures::select! { data = item.data.next().fuse() => data, @@ -210,7 +244,6 @@ pub(crate) trait SendLoop: Sync { // TODO if every SendQueueItem is waiting on data, use select_all to await // something to do - // TODO find some way to not require sending empty last chunk } }; @@ -222,7 +255,7 @@ pub(crate) trait SendLoop: Sync { None => &[], }; - if !data.is_empty() { + if data.len() == MAX_CHUNK_LENGTH as usize { let size_header = ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; @@ -231,7 +264,6 @@ pub(crate) trait SendLoop: Sync { sending.push(item); } else { - // this is always zero for now, but may be more when above TODO get fixed let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; @@ -267,38 +299,38 @@ pub(crate) trait SendLoop: Sync { } struct ChannelPair { - receiver: Option>>, - sender: Option>>, + receiver: Option>>, + sender: Option>>, } impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() - } - - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } - - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } - - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); - } - } + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } } impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), - } - } + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } } /// The RecvLoop trait, which is implemented both by the client and the server @@ -317,10 +349,7 @@ pub(crate) trait RecvLoop: Sync + 'static { R: AsyncReadExt + Unpin + Send + Sync, { let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap< - RequestID, - ChannelPair, - > = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -345,7 +374,7 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: read {} bytes", next_slice.len()); if id & 1 == 0 { - // main stream + // main stream let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); msg_bytes.extend_from_slice(&next_slice[..]); @@ -357,30 +386,30 @@ pub(crate) trait RecvLoop: Sync + 'static { if let Some(receiver) = channel_pair.take_receiver() { self.recv_handler(id, msg_bytes, Box::pin(receiver)); } else { - warn!("Couldn't take receiver part of stream") - } + warn!("Couldn't take receiver part of stream") + } - channel_pair.insert_into(&mut streams, id | 1); + channel_pair.insert_into(&mut streams, id | 1); } } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } if !has_cont { - channel_pair.take_sender(); - } + channel_pair.take_sender(); + } - channel_pair.insert_into(&mut streams, id); + channel_pair.insert_into(&mut streams, id); } } Ok(()) -- cgit v1.2.3 From 4745e7c4ba5665d3303ae567087781778cec9c34 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Wed, 8 Jun 2022 00:30:56 +0200 Subject: further work on streams most changes still required are related to error handling --- src/client.rs | 5 ++-- src/endpoint.rs | 80 ++++++++++++++++++++++++++++++------------------- src/netapp.rs | 4 ++- src/peering/fullmesh.rs | 8 +++-- src/proto.rs | 5 ++-- src/util.rs | 5 +++- 6 files changed, 67 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index bce7aca..bc16fb1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -227,9 +227,8 @@ impl ClientConn { let code = resp[0]; if code == 0 { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); - let res = T::Response::deserialize_msg(&mut deser, stream).await?; - Ok(res) + let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; + Ok(T::Response::deserialize_msg(ser_resp, stream).await) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 81ed036..c25365a 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,8 +5,7 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::de::Error as DeError; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use crate::error::Error; use crate::netapp::*; @@ -22,42 +21,61 @@ pub trait Message: SerializeMessage + Send + Sync { /// A trait for de/serializing messages, with possible associated stream. #[async_trait] pub trait SerializeMessage: Sized { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error>; + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - stream: AssociatedStream, - ) -> Result; + // TODO should return Result + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; } +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + #[async_trait] impl SerializeMessage for T where - T: Serialize + for<'de> Deserialize<'de> + Send + Sync, + T: AutoSerialize, { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error> { - self.serialize(serializer).map(|r| (r, None)) + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } } - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - mut stream: AssociatedStream, - ) -> Result { - use futures::StreamExt; - - let res = Self::deserialize(deserializer)?; - if stream.next().await.is_some() { - return Err(D::Error::custom( - "failed to deserialize: found associated stream when none expected", - )); + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), } - Ok(res) } } @@ -139,7 +157,7 @@ where prio: RequestPriority, ) -> Result<::Response, Error> where - B: Borrow, + B: Borrow + Send + Sync, { if *target == self.netapp.id { match self.handler.load_full() { @@ -202,8 +220,8 @@ where match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); - let req = M::deserialize_msg(&mut deser, stream).await?; + let req = rmp_serde::decode::from_read_ref(buf)?; + let req = M::deserialize_msg(req, stream).await; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..27f17e6 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct HelloMessage { pub server_addr: Option, pub server_port: u16, } +impl AutoSerialize for HelloMessage {} + impl Message for HelloMessage { type Response = (); } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 012c5a0..7dfc5c4 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3; // -- Protocol messages -- -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] struct PingMessage { pub id: u64, pub peer_list_hash: hash::Digest, @@ -39,7 +39,9 @@ impl Message for PingMessage { type Response = PingMessage; } -#[derive(Serialize, Deserialize)] +impl AutoSerialize for PingMessage {} + +#[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, } @@ -48,6 +50,8 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } +impl AutoSerialize for PeerListMessage {} + // -- Algorithm data structures -- #[derive(Debug)] diff --git a/src/proto.rs b/src/proto.rs index ca1a3d2..073a317 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -151,9 +151,10 @@ impl Stream for DataReader { } let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..buf.len()].copy_from_slice(&buf); + let len = buf.len(); + body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) + Poll::Ready(Some((body, len))) } } } diff --git a/src/util.rs b/src/util.rs index 4333080..02b4e7d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -8,6 +8,8 @@ use futures::Stream; use log::info; +use serde::Serialize; + use tokio::sync::watch; /// A node's identifier, which is also its public cryptographic key @@ -34,7 +36,8 @@ where let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - let (_, stream) = val.serialize_msg(&mut se)?; + let (val, stream) = val.serialize_msg(); + val.serialize(&mut se)?; Ok((wr, stream)) } -- cgit v1.2.3 From 5d7541e13a4c3640f0dc8aead595b51775fc0ac8 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 17:44:07 +0200 Subject: wait for any ready stream instead of the highest priority one --- src/endpoint.rs | 2 +- src/proto.rs | 185 +++++++++++++++++++++++++++++++++++--------------------- src/util.rs | 8 +++ 3 files changed, 124 insertions(+), 71 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index c25365a..c430d4e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -42,7 +42,7 @@ where (self.clone(), None) } - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { // TODO verify no stream ser_self } diff --git a/src/proto.rs b/src/proto.rs index 073a317..417b508 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,7 +7,7 @@ use log::{trace, warn}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::Stream; -use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; +use futures::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -53,7 +53,8 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +const ERROR_MARKER: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -99,8 +100,29 @@ impl From for DataReader { } } +struct DataReaderItem { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actuall lenght of data + len: usize, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + may_have_more: false, + } + } +} + impl Stream for DataReader { - type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -114,7 +136,11 @@ impl Stream for DataReader { let mut body = [0; MAX_CHUNK_LENGTH as usize]; body[..len].copy_from_slice(&data[*pos..end]); *pos = end; - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: end < data.len(), + })) } } DataReaderProj::Streaming { @@ -154,7 +180,11 @@ impl Stream for DataReader { let len = buf.len(); body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: !*eos, + })) } } } @@ -181,6 +211,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -196,6 +228,54 @@ impl SendQueue { fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } } /// The SendLoop trait, which is implemented both by the client and the server @@ -219,77 +299,42 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else if let Some(mut item) = sending.pop() { - trace!("send_loop: sending bytes for {}", item.id,); - - let data = futures::select! { - data = item.data.next().fuse() => data, - default => { - // nothing to send yet; re-schedule and find something else to do - sending.push(item); - continue; - - // TODO if every SendQueueItem is waiting on data, use select_all to await - // something to do - } - }; - - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); - let data = match data.as_ref() { - Some((data, len)) => &data[..*len], - None => &[], - }; + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; - if data.len() == MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); - write.write_all(&size_header[..]).await?; + let body = &data.data[..data.len]; - write.write_all(data).await?; + let size_header = if data.may_have_more { + ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) + } else { + ChunkLength::to_be_bytes(data.len as u16) + }; - sending.push(item); - } else { - let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - } - - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } - } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; + write.write_all(body).await?; + write.flush().await?; } } } diff --git a/src/util.rs b/src/util.rs index 02b4e7d..3ee0cb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -19,6 +19,14 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +/// A stream of associated data. +/// +/// The Stream can continue after receiving an error. +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// The error code have no predefined meaning, it's up to you application to define their +/// semantic. pub type AssociatedStream = Pin> + Send>>; /// Utility function: encodes any serializable value in MessagePack binary format -- cgit v1.2.3 From 0fec85b47a1bc679d2684994bfae6ef0fe7d4911 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 18:42:27 +0200 Subject: start supporting sending error on stream --- src/proto.rs | 99 +++++++++++++++++++++++++++++++++++++++++++----------------- src/util.rs | 2 +- 2 files changed, 72 insertions(+), 29 deletions(-) (limited to 'src') diff --git a/src/proto.rs b/src/proto.rs index 417b508..e3f9be8 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -78,7 +78,7 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, - packet: Vec, + packet: Result, u8>, pos: usize, buf: Vec, eos: bool, @@ -91,7 +91,7 @@ impl From for DataReader { Data::Full(data) => DataReader::Full { data, pos: 0 }, Data::Streaming(reader) => DataReader::Streaming { reader, - packet: Vec::new(), + packet: Ok(Vec::new()), pos: 0, buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), eos: false, @@ -100,11 +100,18 @@ impl From for DataReader { } } +enum DataFrame { + Data { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actual lenght of data + len: usize, + }, + Error(u8), +} + struct DataReaderItem { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actuall lenght of data - len: usize, + data: DataFrame, /// whethere there may be more data comming from this stream. Can be used for some /// optimization. It's an error to set it to false if there is more data, but it is correct /// (albeit sub-optimal) to set it to true if there is nothing coming after @@ -114,11 +121,34 @@ struct DataReaderItem { impl DataReaderItem { fn empty_last() -> Self { DataReaderItem { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, + data: DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + }, may_have_more: false, } } + + fn header(&self) -> [u8; 2] { + let continuation = if self.may_have_more { + CHUNK_HAS_CONTINUATION + } else { + 0 + }; + let len = match self.data { + DataFrame::Data { len, .. } => len as u16, + DataFrame::Error(e) => e as u16 | ERROR_MARKER, + }; + + ChunkLength::to_be_bytes(len | continuation) + } + + fn data(&self) -> &[u8] { + match self.data { + DataFrame::Data { ref data, len } => &data[..len], + DataFrame::Error(_) => &[], + } + } } impl Stream for DataReader { @@ -137,15 +167,14 @@ impl Stream for DataReader { body[..len].copy_from_slice(&data[*pos..end]); *pos = end; Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: end < data.len(), })) } } DataReaderProj::Streaming { mut reader, - packet, + packet: res_packet, pos, buf, eos, @@ -156,6 +185,17 @@ impl Stream for DataReader { return Poll::Ready(None); } loop { + let packet = match res_packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *res_packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); + } + }; let packet_left = packet.len() - *pos; let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); let to_read = std::cmp::min(buf_left, packet_left); @@ -168,8 +208,13 @@ impl Stream for DataReader { // we don't have a full buf, packet is empty; try receive more if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *packet = p; + *res_packet = p; *pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if res_packet.is_err() && !buf.is_empty() { + break; + } } else { *eos = true; break; @@ -181,8 +226,7 @@ impl Stream for DataReader { body[..len].copy_from_slice(buf); buf.clear(); Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: !*eos, })) } @@ -211,8 +255,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -324,16 +368,8 @@ pub(crate) trait SendLoop: Sync { let header_id = RequestID::to_be_bytes(id); write.write_all(&header_id[..]).await?; - let body = &data.data[..data.len]; - - let size_header = if data.may_have_more { - ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) - } else { - ChunkLength::to_be_bytes(data.len as u16) - }; - - write.write_all(&size_header[..]).await?; - write.write_all(body).await?; + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; write.flush().await?; } } @@ -413,7 +449,13 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; + let is_error = (size & ERROR_MARKER) != 0; + let size = if !is_error { + size & !CHUNK_HAS_CONTINUATION + } else { + 0 + }; + // TODO propagate errors let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; @@ -430,7 +472,8 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); if let Some(receiver) = channel_pair.take_receiver() { - self.recv_handler(id, msg_bytes, Box::pin(receiver)); + use futures::StreamExt; + self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); } else { warn!("Couldn't take receiver part of stream") } diff --git a/src/util.rs b/src/util.rs index 3ee0cb9..76d7ecf 100644 --- a/src/util.rs +++ b/src/util.rs @@ -27,7 +27,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// The error code have no predefined meaning, it's up to you application to define their /// semantic. -pub type AssociatedStream = Pin> + Send>>; +pub type AssociatedStream = Pin, u8>> + Send>>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- cgit v1.2.3 From d3d18b8e8bde5fee81022fd050d5f4c114262fcf Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 20 Jun 2022 23:40:31 +0200 Subject: use a framing protocol instead of even/odd channel --- src/client.rs | 32 ++--- src/endpoint.rs | 1 - src/error.rs | 4 + src/proto.rs | 364 ++++++++++++++++++++++++++------------------------------ src/server.rs | 26 ++-- 5 files changed, 193 insertions(+), 234 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index bc16fb1..a630f87 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,11 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption>, + query_send: + ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex, AssociatedStream)>>>, + inflight: Mutex>>, } impl ClientConn { @@ -148,11 +149,9 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; - // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(2, atomic::Ordering::Relaxed); - let stream_id = id + 1; + .fetch_add(1, atomic::Ordering::Relaxed); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -187,10 +186,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch - .send((vec![], Box::pin(futures::stream::empty()))) - .is_err() - { + if old_ch.send(Box::pin(futures::stream::empty())).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -200,22 +196,18 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, Data::Full(bytes)))?; - if let Some(stream) = stream { - query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; - } else { - query_send.send((stream_id, prio, Data::Full(Vec::new())))?; - } + query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let (resp, stream) = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let (resp, stream) = resp_recv.await?; + let stream = resp_recv.await?; } } + let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); if resp.is_empty() { return Err(Error::Message( @@ -240,12 +232,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send((msg, stream)).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index c430d4e..f31141d 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -23,7 +23,6 @@ pub trait Message: SerializeMessage + Send + Sync { pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - // TODO should return Result fn serialize_msg(&self) -> (Self::SerializableSelf, Option); // TODO should return Result diff --git a/src/error.rs b/src/error.rs index 99acdd1..7911c29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,9 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + #[error(display = "{}", _0)] Message(String), @@ -50,6 +53,7 @@ impl Error { Self::RMPEncode(_) => 10, Self::RMPDecode(_) => 11, Self::UTF8(_) => 12, + Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, Self::Handshake(_) => 30, diff --git a/src/proto.rs b/src/proto.rs index e3f9be8..d6dc35a 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,11 +3,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::{trace, warn}; +use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::Stream; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -63,39 +63,24 @@ struct SendQueueItem { data: DataReader, } -pub(crate) enum Data { - Full(Vec), - Streaming(AssociatedStream), +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: AssociatedStream, + packet: Result, u8>, + pos: usize, + buf: Vec, + eos: bool, } -#[pin_project::pin_project(project = DataReaderProj)] -enum DataReader { - Full { - #[pin] - data: Vec, - pos: usize, - }, - Streaming { - #[pin] - reader: AssociatedStream, - packet: Result, u8>, - pos: usize, - buf: Vec, - eos: bool, - }, -} - -impl From for DataReader { - fn from(data: Data) -> DataReader { - match data { - Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { - reader, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - }, +impl From for DataReader { + fn from(data: AssociatedStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, } } } @@ -155,82 +140,60 @@ impl Stream for DataReader { type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - DataReaderProj::Full { data, pos } => { - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); - let end = *pos + len; - - if len == 0 { - Poll::Ready(None) - } else { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..len].copy_from_slice(&data[*pos..end]); - *pos = end; - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: end < data.len(), - })) + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; } - DataReaderProj::Streaming { - mut reader, - packet: res_packet, - pos, - buf, - eos, - } => { - if *eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - loop { - let packet = match res_packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *res_packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *pos; - let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - buf.extend_from_slice(&packet[*pos..*pos + to_read]); - *pos += to_read; - if buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *res_packet = p; - *pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if res_packet.is_err() && !buf.is_empty() { - break; - } - } else { - *eos = true; - break; - } + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = buf.len(); - body[..len].copy_from_slice(buf); - buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*eos, - })) + } else { + *this.eos = true; + break; } } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) } } @@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync { } } -struct ChannelPair { - receiver: Option>>, - sender: Option>>, +pub(crate) struct Framing { + direct: Vec, + stream: Option, } -impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } } - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) } } -} -impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), + pub async fn from_stream, u8>> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } } @@ -424,14 +431,13 @@ impl Default for ChannelPair { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap = HashMap::new(); + let mut streams: HashMap, u8>>> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; - let size = if !is_error { - size & !CHUNK_HAS_CONTINUATION + let packet = if is_error { + Err(size as u8) } else { - 0 + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) }; - // TODO propagate errors - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - - if id & 1 == 0 { - // main stream - let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); - - if let Some(receiver) = channel_pair.take_receiver() { - use futures::StreamExt; - self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); - } else { - warn!("Couldn't take receiver part of stream") - } - channel_pair.insert_into(&mut streams, id | 1); - } + let sender = if let Some(send) = streams.remove(&(id)) { + send } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + let (send, recv) = unbounded(); + self.recv_handler(id, Box::pin(recv)); + send + }; - if !has_cont { - channel_pair.take_sender(); - } + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + let _ = sender.unbounded_send(packet); - channel_pair.insert_into(&mut streams, id); + if has_cont { + streams.insert(id, sender); } } Ok(()) @@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Result, u8>; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::, u8>>()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 6cd4056..86e5156 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,6 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; -use bytes::Bytes; use log::{debug, trace}; #[cfg(feature = "telemetry")] @@ -55,7 +54,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -177,13 +176,13 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + trace!("ServerConn recv_handler {}", id); + let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; let resp = self2.recv_handler_aux(&bytes[..], stream).await; @@ -204,18 +203,13 @@ impl RecvLoop for ServerConn { trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, Data::Full(resp_bytes))) + .send(( + id, + prio, + Framing::new(resp_bytes, resp_stream).into_stream(), + )) .log_err("ServerConn recv_handler send resp bytes"); - - if let Some(resp_stream) = resp_stream { - resp_send - .send((id + 1, prio, Data::Streaming(resp_stream))) - .log_err("ServerConn recv_handler send resp stream"); - } else { - resp_send - .send((id + 1, prio, Data::Full(Vec::new()))) - .log_err("ServerConn recv_handler send resp stream"); - } + Ok::<_, Error>(()) }); } } -- cgit v1.2.3 From cdff8ae1beab44a22d0eb0eb00c624e49971b6ca Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 18 Jul 2022 15:21:13 +0200 Subject: add detection of premature eos --- src/client.rs | 7 ++++--- src/proto.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++------------ src/server.rs | 3 ++- src/util.rs | 8 +++++--- 4 files changed, 58 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index a630f87..6d49f5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -41,7 +42,7 @@ pub(crate) struct ClientConn { ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>, + inflight: Mutex>>>, } impl ClientConn { @@ -186,7 +187,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(Box::pin(futures::stream::empty())).is_err() { + if old_ch.send(unbounded().1).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -232,7 +233,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result, u8>, + packet: Packet, pos: usize, buf: Vec, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream, u8>> + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + 'static>( mut stream: S, ) -> Result { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap, u8>>> = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result, u8>; + type Item = Packet; let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::, u8>>()); + Box::pin(futures::stream::empty::()); stream.into() } diff --git a/src/server.rs b/src/server.rs index 86e5156..8075484 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,6 +19,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; +use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -176,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); diff --git a/src/util.rs b/src/util.rs index 76d7ecf..186678d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -25,9 +25,11 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// -/// The error code have no predefined meaning, it's up to you application to define their -/// semantic. -pub type AssociatedStream = Pin, u8>> + Send>>; +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type AssociatedStream = Pin + Send>>; + +pub type Packet = Result, u8>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- cgit v1.2.3 From f35fa7d18d9e0f51bed311355ec1310b1d311ab3 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 17:34:53 +0200 Subject: Move things around --- src/client.rs | 18 +- src/endpoint.rs | 78 +----- src/error.rs | 2 +- src/lib.rs | 5 +- src/message.rs | 255 ++++++++++++++++++++ src/netapp.rs | 2 +- src/peering/basalt.rs | 3 +- src/peering/fullmesh.rs | 3 +- src/proto.rs | 617 ------------------------------------------------ src/proto2.rs | 75 ------ src/recv.rs | 114 +++++++++ src/send.rs | 410 ++++++++++++++++++++++++++++++++ src/server.rs | 32 ++- src/util.rs | 12 +- 14 files changed, 820 insertions(+), 806 deletions(-) create mode 100644 src/message.rs delete mode 100644 src/proto.rs delete mode 100644 src/proto2.rs create mode 100644 src/recv.rs create mode 100644 src/send.rs (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 6d49f5c..663a3e4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,9 +5,12 @@ use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, error, trace}; use futures::channel::mpsc::{unbounded, UnboundedReceiver}; +use futures::io::AsyncReadExt; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -21,25 +24,18 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: - ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>>, diff --git a/src/endpoint.rs b/src/endpoint.rs index f31141d..e6b2236 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,79 +5,11 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - use crate::error::Error; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; use crate::util::*; -/// This trait should be implemented by all messages your application -/// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; -} - -/// A trait for de/serializing messages, with possible associated stream. -#[async_trait] -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option); - - // TODO should return Result - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; -} - -pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} - -#[async_trait] -impl SerializeMessage for T -where - T: AutoSerialize, -{ - type SerializableSelf = Self; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - (self.clone(), None) - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { - // TODO verify no stream - ser_self - } -} - -impl AutoSerialize for () {} - -#[async_trait] -impl SerializeMessage for Result -where - T: SerializeMessage + Send, - E: SerializeMessage + Send, -{ - type SerializableSelf = Result; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.serialize_msg(); - (Ok(msg), stream) - } - Err(err) => { - let (msg, stream) = err.serialize_msg(); - (Err(msg), stream) - } - } - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), - Err(err) => Err(E::deserialize_msg(err, stream).await), - } - } -} - /// This trait should be implemented by an object of your application /// that can handle a message of type `M`. /// @@ -191,9 +123,9 @@ pub(crate) trait GenericEndpoint { async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error>; + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -213,9 +145,9 @@ where async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error> { + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { diff --git a/src/error.rs b/src/error.rs index 7911c29..665647c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ -use err_derive::Error; use std::io; +use err_derive::Error; use log::error; #[derive(Debug, Error)] diff --git a/src/lib.rs b/src/lib.rs index cb24337..1edb919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,10 +17,11 @@ pub mod error; pub mod util; pub mod endpoint; -pub mod proto; +pub mod message; mod client; -mod proto2; +mod recv; +mod send; mod server; pub mod netapp; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..dbcc857 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,255 @@ +use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::error::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +} + +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +#[async_trait] +impl SerializeMessage for T +where + T: AutoSerialize, +{ + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), + } + } +} + +// ---- + +pub(crate) struct QueryMessage<'a> { + pub(crate) prio: RequestPriority, + pub(crate) path: &'a [u8], + pub(crate) telemetry_id: Option>, + pub(crate) body: &'a [u8], +} + +/// QueryMessage encoding: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - body [u8; ..] +impl<'a> QueryMessage<'a> { + pub(crate) fn encode(self) -> Vec { + let tel_len = match &self.telemetry_id { + Some(t) => t.len(), + None => 0, + }; + + let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + + ret.push(self.prio); + + ret.push(self.path.len() as u8); + ret.extend_from_slice(self.path); + + if let Some(t) = self.telemetry_id { + ret.push(t.len() as u8); + ret.extend(t); + } else { + ret.push(0u8); + } + + ret.extend_from_slice(self.body); + + ret + } + + pub(crate) fn decode(bytes: &'a [u8]) -> Result { + if bytes.len() < 3 { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path_length = bytes[1] as usize; + if bytes.len() < 3 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let telemetry_id_len = bytes[2 + path_length] as usize; + if bytes.len() < 3 + path_length + telemetry_id_len { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let telemetry_id = if telemetry_id_len > 0 { + Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) + } else { + None + }; + + let body = &bytes[3 + path_length + telemetry_id_len..]; + + Ok(Self { + prio: bytes[0], + path, + telemetry_id, + body, + }) + } +} + +pub(crate) struct Framing { + direct: Vec, + stream: Option, +} + +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } + } + + pub fn into_stream(self) -> ByteStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; + + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); + + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) + } + } + + pub async fn from_stream + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + } + + let stream: ByteStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, ByteStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + } +} diff --git a/src/netapp.rs b/src/netapp.rs index 27f17e6..dd22d90 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -20,7 +20,7 @@ use tokio::sync::{mpsc, watch}; use crate::client::*; use crate::endpoint::*; use crate::error::*; -use crate::proto::*; +use crate::message::*; use crate::server::*; use crate::util::*; diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 7f77995..98977a3 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -14,8 +14,9 @@ use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; +use crate::send::*; use crate::NodeID; // -- Protocol messages -- diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 7dfc5c4..5b489ae 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash; use crate::endpoint::*; use crate::error::*; use crate::netapp::*; -use crate::proto::*; + +use crate::message::*; use crate::NodeID; const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 92d8d80..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use log::trace; - -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures::{Stream, StreamExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; -use crate::util::{AssociatedStream, Packet}; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -const ERROR_MARKER: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: DataReader, -} - -#[pin_project::pin_project] -struct DataReader { - #[pin] - reader: AssociatedStream, - packet: Packet, - pos: usize, - buf: Vec, - eos: bool, -} - -impl From for DataReader { - fn from(data: AssociatedStream) -> DataReader { - DataReader { - reader: data, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - } - } -} - -enum DataFrame { - Data { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actual lenght of data - len: usize, - }, - Error(u8), -} - -struct DataReaderItem { - data: DataFrame, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, -} - -impl DataReaderItem { - fn empty_last() -> Self { - DataReaderItem { - data: DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - }, - may_have_more: false, - } - } - - fn header(&self) -> [u8; 2] { - let continuation = if self.may_have_more { - CHUNK_HAS_CONTINUATION - } else { - 0 - }; - let len = match self.data { - DataFrame::Data { len, .. } => len as u16, - DataFrame::Error(e) => e as u16 | ERROR_MARKER, - }; - - ChunkLength::to_be_bytes(len | continuation) - } - - fn data(&self) -> &[u8] { - match self.data { - DataFrame::Data { ref data, len } => &data[..len], - DataFrame::Error(_) => &[], - } - } -} - -impl Stream for DataReader { - type Item = DataReaderItem; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if *this.eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - - loop { - let packet = match this.packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *this.packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *this.pos; - let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - this.buf - .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); - *this.pos += to_read; - if this.buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { - *this.packet = p; - *this.pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if this.packet.is_err() && !this.buf.is_empty() { - break; - } - } else { - *this.eos = true; - break; - } - } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = this.buf.len(); - body[..len].copy_from_slice(this.buf); - this.buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*this.eos, - })) - } -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] - fn pop(&mut self) -> Option { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - - // this is like an async fn, but hand implemented - fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { - SendQueuePollNextReady { queue: self } - } -} - -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataReaderItem); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - for i in 0..self.queue.items.len() { - let (_prio, items_at_prio) = &mut self.queue.items[i]; - - for _ in 0..items_at_prio.len() { - let mut item = items_at_prio.pop_front().unwrap(); - - match Pin::new(&mut item.data).poll_next(ctx) { - Poll::Pending => items_at_prio.push_back(item), - Poll::Ready(Some(data)) => { - let id = item.id; - if data.may_have_more { - self.queue.push(item); - } else { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - } - return Poll::Ready((id, data)); - } - Poll::Ready(None) => { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - return Poll::Ready((item.id, DataReaderItem::empty_last())); - } - } - } - } - // TODO what do we do if self.queue is empty? We won't get scheduled again. - Poll::Pending - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop( - self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, - mut write: BoxStreamWrite, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - let recv_fut = msg_recv.recv(); - futures::pin_mut!(recv_fut); - let send_fut = sending.next_ready(); - - // recv_fut is cancellation-safe according to tokio doc, - // send_fut is cancellation-safe as implemented above? - use futures::future::Either; - match futures::future::select(recv_fut, send_fut).await { - Either::Left((sth, _send_fut)) => { - if let Some((id, prio, data)) = sth { - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; - }; - } - Either::Right(((id, data), _recv_fut)) => { - trace!("send_loop: sending bytes for {}", id); - - let header_id = RequestID::to_be_bytes(id); - write.write_all(&header_id[..]).await?; - - write.write_all(&data.header()).await?; - write.write_all(data.data()).await?; - write.flush().await?; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -pub(crate) struct Framing { - direct: Vec, - stream: Option, -} - -impl Framing { - pub fn new(direct: Vec, stream: Option) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } - } - - pub fn into_stream(self) -> AssociatedStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; - - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); - - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } - - pub async fn from_stream + Unpin + Send + 'static>( - mut stream: S, - ) -> Result { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); - } - - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet.drain(..4); - - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); - - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet.drain(..max_cp); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } - - let stream: AssociatedStream = if packet.is_empty() { - Box::pin(stream) - } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; - - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec, AssociatedStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) - } -} - -/// Structure to warn when the sender is dropped before end of stream was reached, like when -/// connection to some remote drops while transmitting data -struct Sender { - inner: UnboundedSender, - closed: bool, -} - -impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } - } - - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); - } - - fn end(&mut self) { - self.closed = true; - } -} - -impl Drop for Sender { - fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); - } - self.inner.close_channel(); - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); - - async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut streams: HashMap = HashMap::new(); - loop { - trace!("recv_loop: reading packet"); - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; - let packet = if is_error { - Err(size as u8) - } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) - }; - - let mut sender = if let Some(send) = streams.remove(&(id)) { - send - } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); - Sender::new(send) - }; - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); - - if has_cont { - streams.insert(id, sender); - } else { - sender.end(); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn empty_data() -> DataReader { - type Item = Packet; - let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::()); - stream.into() - } - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: empty_data(), - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: empty_data(), - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: empty_data(), - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: empty_data(), - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} diff --git a/src/proto2.rs b/src/proto2.rs deleted file mode 100644 index 7210781..0000000 --- a/src/proto2.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::error::*; -use crate::proto::*; - -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option>, - pub(crate) body: &'a [u8], -} - -/// QueryMessage encoding: -/// - priority: u8 -/// - path length: u8 -/// - path: [u8; path length] -/// - telemetry id length: u8 -/// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; - - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); - - ret.push(self.prio); - - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); - - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); - } else { - ret.push(0u8); - } - - ret.extend_from_slice(self.body); - - ret - } - - pub(crate) fn decode(bytes: &'a [u8]) -> Result { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; - - let body = &bytes[3 + path_length + telemetry_id_len..]; - - Ok(Self { - prio: bytes[0], - path, - telemetry_id, - body, - }) - } -} diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..628612b --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use std::sync::Arc; + +use log::trace; + +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::AsyncReadExt; + +use async_trait::async_trait; + +use crate::error::*; + +use crate::send::*; +use crate::util::Packet; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); + + async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap = HashMap::new(); + loop { + trace!("recv_loop: reading packet"); + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + trace!("recv_loop: got header id: {:04x}", id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + trace!("recv_loop: got header size: {:04x}", size); + + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let is_error = (size & ERROR_MARKER) != 0; + let packet = if is_error { + Err(size as u8) + } else { + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = unbounded(); + self.recv_handler(id, recv); + Sender::new(send) + }; + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + sender.send(packet); + + if has_cont { + streams.insert(id, sender); + } else { + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/send.rs b/src/send.rs new file mode 100644 index 0000000..330d41d --- /dev/null +++ b/src/send.rs @@ -0,0 +1,410 @@ +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use log::trace; + +use futures::AsyncWriteExt; +use futures::Stream; +use kuska_handshake::async_std::BoxStreamWrite; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::message::*; +use crate::util::{ByteStream, Packet}; + +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag +// when this is not the last chunk of the message +// - [u8; chunk_length] chunk data + +pub(crate) type RequestID = u32; +pub(crate) type ChunkLength = u16; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; +pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + data: DataReader, +} + +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: ByteStream, + packet: Packet, + pos: usize, + buf: Vec, + eos: bool, +} + +impl From for DataReader { + fn from(data: ByteStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + } + } +} + +enum DataFrame { + Data { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actual lenght of data + len: usize, + }, + Error(u8), +} + +struct DataReaderItem { + data: DataFrame, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + }, + may_have_more: false, + } + } + + fn header(&self) -> [u8; 2] { + let continuation = if self.may_have_more { + CHUNK_HAS_CONTINUATION + } else { + 0 + }; + let len = match self.data { + DataFrame::Data { len, .. } => len as u16, + DataFrame::Error(e) => e as u16 | ERROR_MARKER, + }; + + ChunkLength::to_be_bytes(len | continuation) + } + + fn data(&self) -> &[u8] { + match self.data { + DataFrame::Data { ref data, len } => &data[..len], + DataFrame::Error(_) => &[], + } + } +} + +impl Stream for DataReader { + type Item = DataReaderItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); + } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; + } + } else { + *this.eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) + } +} + +struct SendQueue { + items: VecDeque<(u8, VecDeque)>, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: VecDeque::with_capacity(64), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, VecDeque::new())); + i + } + }; + self.items[pos_prio].1.push_back(item); + } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] + fn pop(&mut self) -> Option { + match self.items.pop_front() { + None => None, + Some((prio, mut items_at_prio)) => { + let ret = items_at_prio.pop_front(); + if !items_at_prio.is_empty() { + self.items.push_front((prio, items_at_prio)); + } + ret.or_else(|| self.pop()) + } + } + } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } +} + +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop( + self: Arc, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, + mut write: BoxStreamWrite, + ) -> Result<(), Error> + where + W: AsyncWriteExt + Unpin + Send + Sync, + { + let mut sending = SendQueue::new(); + let mut should_exit = false; + while !should_exit || !sending.is_empty() { + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; + } + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; + } + } + } + + let _ = write.goodbye().await; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn empty_data() -> DataReader { + type Item = Packet; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::()); + stream.into() + } + + #[test] + fn test_priority_queue() { + let i1 = SendQueueItem { + id: 1, + prio: PRIO_NORMAL, + data: empty_data(), + }; + let i2 = SendQueueItem { + id: 2, + prio: PRIO_HIGH, + data: empty_data(), + }; + let i2bis = SendQueueItem { + id: 20, + prio: PRIO_HIGH, + data: empty_data(), + }; + let i3 = SendQueueItem { + id: 3, + prio: PRIO_HIGH | PRIO_SECONDARY, + data: empty_data(), + }; + let i4 = SendQueueItem { + id: 4, + prio: PRIO_BACKGROUND | PRIO_SECONDARY, + data: empty_data(), + }; + let i5 = SendQueueItem { + id: 5, + prio: PRIO_BACKGROUND | PRIO_PRIMARY, + data: empty_data(), + }; + + let mut q = SendQueue::new(); + + q.push(i1); // 1 + let a = q.pop().unwrap(); // empty -> 1 + assert_eq!(a.id, 1); + assert!(q.pop().is_none()); + + q.push(a); // 1 + q.push(i2); // 2 1 + q.push(i2bis); // [2 20] 1 + let a = q.pop().unwrap(); // 20 1 -> 2 + assert_eq!(a.id, 2); + let b = q.pop().unwrap(); // 1 -> 20 + assert_eq!(b.id, 20); + let c = q.pop().unwrap(); // empty -> 1 + assert_eq!(c.id, 1); + assert!(q.pop().is_none()); + + q.push(a); // 2 + q.push(b); // [2 20] + q.push(c); // [2 20] 1 + q.push(i3); // [2 20] 3 1 + q.push(i4); // [2 20] 3 1 4 + q.push(i5); // [2 20] 3 1 5 4 + + let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 + assert_eq!(a.id, 2); + q.push(a); // [20 2] 3 1 5 4 + + let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 + assert_eq!(a.id, 20); + let b = q.pop().unwrap(); // 3 1 5 4 -> 2 + assert_eq!(b.id, 2); + q.push(b); // 2 3 1 5 4 + let b = q.pop().unwrap(); // 3 1 5 4 -> 2 + assert_eq!(b.id, 2); + let c = q.pop().unwrap(); // 1 5 4 -> 3 + assert_eq!(c.id, 3); + q.push(b); // 2 1 5 4 + let b = q.pop().unwrap(); // 1 5 4 -> 2 + assert_eq!(b.id, 2); + let e = q.pop().unwrap(); // 5 4 -> 1 + assert_eq!(e.id, 1); + let f = q.pop().unwrap(); // 4 -> 5 + assert_eq!(f.id, 5); + let g = q.pop().unwrap(); // empty -> 4 + assert_eq!(g.id, 4); + assert!(q.pop().is_none()); + } +} diff --git a/src/server.rs b/src/server.rs index 8075484..1f1c22a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,17 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, trace}; +use futures::channel::mpsc::UnboundedReceiver; +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; + #[cfg(feature = "telemetry")] use opentelemetry::{ trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer}, @@ -14,22 +23,11 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::channel::mpsc::UnboundedReceiver; -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -126,8 +124,8 @@ impl ServerConn { async fn recv_handler_aux( self: &Arc, bytes: &[u8], - stream: AssociatedStream, - ) -> Result<(Vec, Option), Error> { + stream: ByteStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; diff --git a/src/util.rs b/src/util.rs index 186678d..6fbafe6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,17 +1,15 @@ -use crate::endpoint::SerializeMessage; - use std::net::SocketAddr; use std::net::ToSocketAddrs; use std::pin::Pin; -use futures::Stream; - use log::info; - use serde::Serialize; +use futures::Stream; use tokio::sync::watch; +use crate::message::SerializeMessage; + /// A node's identifier, which is also its public cryptographic key pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; /// A node's secret key @@ -27,7 +25,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type AssociatedStream = Pin + Send>>; +pub type ByteStream = Pin + Send>>; pub type Packet = Result, u8>; @@ -38,7 +36,7 @@ pub type Packet = Result, u8>; /// This is used internally by the netapp communication protocol. pub fn rmp_to_vec_all_named( val: &T, -) -> Result<(Vec, Option), rmp_serde::encode::Error> +) -> Result<(Vec, Option), rmp_serde::encode::Error> where T: SerializeMessage + ?Sized, { -- cgit v1.2.3 From 9dffa812c43470ee8a29c23c3a1be73085e25843 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 17:59:15 +0200 Subject: Refactor send.rs --- src/recv.rs | 5 +-- src/send.rs | 116 +++++++++++++++++++++++++++++++----------------------------- 2 files changed, 61 insertions(+), 60 deletions(-) (limited to 'src') diff --git a/src/recv.rs b/src/recv.rs index 628612b..f5221e6 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -1,16 +1,13 @@ use std::collections::HashMap; - use std::sync::Arc; +use async_trait::async_trait; use log::trace; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::AsyncReadExt; -use async_trait::async_trait; - use crate::error::*; - use crate::send::*; use crate::util::Packet; diff --git a/src/send.rs b/src/send.rs index 330d41d..0179eb2 100644 --- a/src/send.rs +++ b/src/send.rs @@ -24,6 +24,7 @@ use crate::util::{ByteStream, Packet}; pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; + pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; @@ -62,53 +63,58 @@ enum DataFrame { data: [u8; MAX_CHUNK_LENGTH as usize], /// actual lenght of data len: usize, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, }, + /// An error code automatically signals the end of the stream Error(u8), } -struct DataReaderItem { - data: DataFrame, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, -} - -impl DataReaderItem { +impl DataFrame { fn empty_last() -> Self { - DataReaderItem { - data: DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - }, + DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, may_have_more: false, } } fn header(&self) -> [u8; 2] { - let continuation = if self.may_have_more { - CHUNK_HAS_CONTINUATION - } else { - 0 - }; - let len = match self.data { - DataFrame::Data { len, .. } => len as u16, - DataFrame::Error(e) => e as u16 | ERROR_MARKER, + let header_u16 = match self { + DataFrame::Data { + len, + may_have_more: false, + .. + } => *len as u16, + DataFrame::Data { + len, + may_have_more: true, + .. + } => *len as u16 | CHUNK_HAS_CONTINUATION, + DataFrame::Error(e) => *e as u16 | ERROR_MARKER, }; - - ChunkLength::to_be_bytes(len | continuation) + ChunkLength::to_be_bytes(header_u16) } fn data(&self) -> &[u8] { - match self.data { - DataFrame::Data { ref data, len } => &data[..len], + match self { + DataFrame::Data { ref data, len, .. } => &data[..*len], DataFrame::Error(_) => &[], } } + + fn may_have_more(&self) -> bool { + match self { + DataFrame::Data { may_have_more, .. } => *may_have_more, + DataFrame::Error(_) => false, + } + } } impl Stream for DataReader { - type Item = DataReaderItem; + type Item = DataFrame; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); @@ -125,10 +131,8 @@ impl Stream for DataReader { Err(e) => { let e = *e; *this.packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); + *this.eos = true; + return Poll::Ready(Some(DataFrame::Error(e))); } }; let packet_left = packet.len() - *this.pos; @@ -161,8 +165,9 @@ impl Stream for DataReader { let len = this.buf.len(); body[..len].copy_from_slice(this.buf); this.buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, + Poll::Ready(Some(DataFrame::Data { + data: body, + len, may_have_more: !*this.eos, })) } @@ -218,38 +223,37 @@ struct SendQueuePollNextReady<'a> { } impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataReaderItem); + type Output = (RequestID, DataFrame); fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { for i in 0..self.queue.items.len() { let (_prio, items_at_prio) = &mut self.queue.items[i]; - for _ in 0..items_at_prio.len() { - let mut item = items_at_prio.pop_front().unwrap(); - + let mut ready_item = None; + for (j, item) in items_at_prio.iter_mut().enumerate() { match Pin::new(&mut item.data).poll_next(ctx) { - Poll::Pending => items_at_prio.push_back(item), - Poll::Ready(Some(data)) => { - let id = item.id; - if data.may_have_more { - self.queue.push(item); - } else { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - } - return Poll::Ready((id, data)); - } - Poll::Ready(None) => { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - return Poll::Ready((item.id, DataReaderItem::empty_last())); + Poll::Pending => (), + Poll::Ready(ready_v) => { + ready_item = Some((j, ready_v)); + break; } } } + + if let Some((j, ready_v)) = ready_item { + let item = items_at_prio.remove(j).unwrap(); + let id = item.id; + if ready_v + .as_ref() + .map(|data| data.may_have_more()) + .unwrap_or(false) + { + items_at_prio.push_back(item); + } else if items_at_prio.is_empty() { + self.queue.items.remove(i); + } + return Poll::Ready((id, ready_v.unwrap_or_else(DataFrame::empty_last))); + } } // TODO what do we do if self.queue is empty? We won't get scheduled again. Poll::Pending -- cgit v1.2.3 From 26989bba1409bfc093e58ef98e75885b10ad7c1c Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 18:15:07 +0200 Subject: Use Bytes instead of Vec --- src/message.rs | 8 ++++---- src/recv.rs | 3 ++- src/send.rs | 5 +++-- src/util.rs | 3 ++- 4 files changed, 11 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index dbcc857..6d50254 100644 --- a/src/message.rs +++ b/src/message.rs @@ -192,8 +192,8 @@ impl Framing { // required because otherwise the borrow-checker complains let Framing { direct, stream } = self; - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) }) + .chain(stream::once(async move { Ok(direct.into()) })); if let Some(stream) = stream { Box::pin(res.chain(stream)) @@ -217,7 +217,7 @@ impl Framing { let mut len = [0; 4]; len.copy_from_slice(&packet[..4]); let len = u32::from_be_bytes(len); - packet.drain(..4); + packet = packet.slice(4..); let mut buffer = Vec::new(); let len = len as usize; @@ -226,7 +226,7 @@ impl Framing { buffer.extend_from_slice(&packet[..max_cp]); if buffer.len() == len { - packet.drain(..max_cp); + packet = packet.slice(max_cp..); break; } packet = stream diff --git a/src/recv.rs b/src/recv.rs index f5221e6..abe7b9a 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; +use bytes::Bytes; use log::trace; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; @@ -85,7 +86,7 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) + Ok(Bytes::from(next_slice)) }; let mut sender = if let Some(send) = streams.remove(&(id)) { diff --git a/src/send.rs b/src/send.rs index 0179eb2..660e85c 100644 --- a/src/send.rs +++ b/src/send.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use bytes::Bytes; use async_trait::async_trait; use log::trace; @@ -49,7 +50,7 @@ impl From for DataReader { fn from(data: ByteStream) -> DataReader { DataReader { reader: data, - packet: Ok(Vec::new()), + packet: Ok(Bytes::new()), pos: 0, buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), eos: false, @@ -130,7 +131,7 @@ impl Stream for DataReader { Ok(v) => v, Err(e) => { let e = *e; - *this.packet = Ok(Vec::new()); + *this.packet = Ok(Bytes::new()); *this.eos = true; return Poll::Ready(Some(DataFrame::Error(e))); } diff --git a/src/util.rs b/src/util.rs index 6fbafe6..e81a89c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,6 +4,7 @@ use std::pin::Pin; use log::info; use serde::Serialize; +use bytes::Bytes; use futures::Stream; use tokio::sync::watch; @@ -27,7 +28,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// meaning, it's up to your application to define their semantic. pub type ByteStream = Pin + Send>>; -pub type Packet = Result, u8>; +pub type Packet = Result; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- cgit v1.2.3 From 44bbc1c00c2532e08dff0d4a547b0a707e89f32d Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 19:05:51 +0200 Subject: Rename AutoSerialize into SimpleMessage and refactor a bit --- src/client.rs | 10 ++-- src/endpoint.rs | 22 ++++----- src/message.rs | 118 ++++++++++++++++++++++++++++++++++-------------- src/netapp.rs | 6 +-- src/peering/fullmesh.rs | 12 ++--- src/send.rs | 2 +- src/util.rs | 11 ++--- 7 files changed, 113 insertions(+), 68 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 663a3e4..cf80746 100644 --- a/src/client.rs +++ b/src/client.rs @@ -134,15 +134,14 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call( + pub(crate) async fn call( self: Arc, - rq: B, + rq: T, path: &str, prio: RequestPriority, ) -> Result<::Response, Error> where T: Message, - B: Borrow, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -164,7 +163,8 @@ impl ClientConn { }; // Encode request - let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; + let (rq, stream) = rq.into_parts(); + let body = rmp_to_vec_all_named(&rq)?; drop(rq); let request = QueryMessage { @@ -217,7 +217,7 @@ impl ClientConn { let code = resp[0]; if code == 0 { let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(T::Response::deserialize_msg(ser_resp, stream).await) + Ok(T::Response::from_parts(ser_resp, stream)) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index e6b2236..3f292d9 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -19,7 +19,7 @@ pub trait EndpointHandler: Send + Sync where M: Message, { - async fn handle(self: &Arc, m: &M, from: NodeID) -> M::Response; + async fn handle(self: &Arc, m: M, from: NodeID) -> M::Response; } /// If one simply wants to use an endpoint in a client fashion, @@ -28,7 +28,7 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { + async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } @@ -81,19 +81,16 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call( &self, target: &NodeID, - req: B, + req: M, prio: RequestPriority, - ) -> Result<::Response, Error> - where - B: Borrow + Send + Sync, - { + ) -> Result<::Response, Error> { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), + Some(h) => Ok(h.handle(req, self.netapp.id).await), } } else { let conn = self @@ -152,10 +149,11 @@ where None => Err(Error::NoHandler), Some(h) => { let req = rmp_serde::decode::from_read_ref(buf)?; - let req = M::deserialize_msg(req, stream).await; - let res = h.handle(&req, from).await; + let req = M::from_parts(req, stream); + let res = h.handle(req, from).await; + let (res, res_stream) = res.into_parts(); let res_bytes = rmp_to_vec_all_named(&res)?; - Ok(res_bytes) + Ok((res_bytes, res_stream)) } } } diff --git a/src/message.rs b/src/message.rs index 6d50254..f92eb8c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; -use futures::stream::{Stream, StreamExt}; +use bytes::Bytes; use serde::{Deserialize, Serialize}; +use futures::stream::{Stream, StreamExt}; + use crate::error::*; use crate::util::*; @@ -41,66 +43,112 @@ pub trait Message: SerializeMessage + Send + Sync { } /// A trait for de/serializing messages, with possible associated stream. -#[async_trait] pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + fn into_parts(self) -> (Self::SerializableSelf, Option); + + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +} + +// ---- + +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: Serialize + for<'de> Deserialize<'de> + Send, +{ + type SerializableSelf = Result; + + fn into_parts(self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.into_parts(); + (Ok(msg), stream) + } + Err(err) => (Err(err), None), + } + } - // TODO should return Result - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::from_parts(ok, stream)), + Err(err) => Err(err), + } + } } -pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} +// --- + +pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} -#[async_trait] impl SerializeMessage for T where - T: AutoSerialize, + T: SimpleMessage, { type SerializableSelf = Self; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - (self.clone(), None) + fn into_parts(self) -> (Self::SerializableSelf, Option) { + (self, None) } - async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { // TODO verify no stream ser_self } } -impl AutoSerialize for () {} +impl SimpleMessage for () {} -#[async_trait] -impl SerializeMessage for Result +impl SimpleMessage for std::sync::Arc {} + +// ---- + +#[derive(Clone)] +pub struct WithFixedBody Deserialize<'de> + Clone + Send + 'static>( + pub T, + pub Bytes, +); + +impl SerializeMessage for WithFixedBody where - T: SerializeMessage + Send, - E: SerializeMessage + Send, + T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static, { - type SerializableSelf = Result; + type SerializableSelf = T; + + fn into_parts(self) -> (Self::SerializableSelf, Option) { + let body = self.1; + ( + self.0, + Some(Box::pin(futures::stream::once(async move { Ok(body) }))), + ) + } - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.serialize_msg(); - (Ok(msg), stream) - } - Err(err) => { - let (msg, stream) = err.serialize_msg(); - (Err(msg), stream) - } - } + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + panic!("Cannot reconstruct a WithFixedBody type from parts"); } +} - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), - Err(err) => Err(E::deserialize_msg(err, stream).await), - } +pub struct WithStreamingBody Deserialize<'de> + Send>( + pub T, + pub ByteStream, +); + +impl SerializeMessage for WithStreamingBody +where + T: Serialize + for<'de> Deserialize<'de> + Send, +{ + type SerializableSelf = T; + + fn into_parts(self) -> (Self::SerializableSelf, Option) { + (self.0, Some(self.1)) + } + + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + WithStreamingBody(ser_self, stream) } } -// ---- +// ---- ---- pub(crate) struct QueryMessage<'a> { pub(crate) prio: RequestPriority, @@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> { } } +// ---- ---- + pub(crate) struct Framing { direct: Vec, stream: Option, diff --git a/src/netapp.rs b/src/netapp.rs index dd22d90..8365de0 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -38,7 +38,7 @@ pub(crate) struct HelloMessage { pub server_port: u16, } -impl AutoSerialize for HelloMessage {} +impl SimpleMessage for HelloMessage {} impl Message for HelloMessage { type Response = (); @@ -399,7 +399,7 @@ impl NetApp { hello_endpoint .call( &conn.peer_id, - &HelloMessage { + HelloMessage { server_addr, server_port, }, @@ -434,7 +434,7 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: &HelloMessage, from: NodeID) { + async fn handle(self: &Arc, msg: HelloMessage, from: NodeID) { debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 5b489ae..3eeebb3 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -40,7 +40,7 @@ impl Message for PingMessage { type Response = PingMessage; } -impl AutoSerialize for PingMessage {} +impl SimpleMessage for PingMessage {} #[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { @@ -51,7 +51,7 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } -impl AutoSerialize for PeerListMessage {} +impl SimpleMessage for PeerListMessage {} // -- Algorithm data structures -- @@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy { ping_time ); let ping_response = select! { - r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r, + r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())), }; @@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy { let pex_message = PeerListMessage { list: peer_list }; match self .peer_list_endpoint - .call(id, &pex_message, PRIO_BACKGROUND) + .call(id, pex_message, PRIO_BACKGROUND) .await { Err(e) => warn!("Error doing peer exchange: {}", e), @@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: &PingMessage, from: NodeID) -> PingMessage { + async fn handle(self: &Arc, ping: PingMessage, from: NodeID) -> PingMessage { let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, @@ -601,7 +601,7 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: &PeerListMessage, + peer_list: PeerListMessage, _from: NodeID, ) -> PeerListMessage { self.handle_peer_list(&peer_list.list[..]); diff --git a/src/send.rs b/src/send.rs index 660e85c..cc28d7c 100644 --- a/src/send.rs +++ b/src/send.rs @@ -3,8 +3,8 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::Bytes; use async_trait::async_trait; +use bytes::Bytes; use log::trace; use futures::AsyncWriteExt; diff --git a/src/util.rs b/src/util.rs index e81a89c..f860672 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,9 +2,9 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; use std::pin::Pin; +use bytes::Bytes; use log::info; use serde::Serialize; -use bytes::Bytes; use futures::Stream; use tokio::sync::watch; @@ -35,19 +35,16 @@ pub type Packet = Result; /// /// Field names and variant names are included in the serialization. /// This is used internally by the netapp communication protocol. -pub fn rmp_to_vec_all_named( - val: &T, -) -> Result<(Vec, Option), rmp_serde::encode::Error> +pub fn rmp_to_vec_all_named(val: &T) -> Result, rmp_serde::encode::Error> where - T: SerializeMessage + ?Sized, + T: Serialize + ?Sized, { let mut wr = Vec::with_capacity(128); let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - let (val, stream) = val.serialize_msg(); val.serialize(&mut se)?; - Ok((wr, stream)) + Ok(wr) } /// This async function returns only when a true signal was received -- cgit v1.2.3 From 7d148c7e764d563efa3bccc0f14f50867db38ef1 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 19:25:07 +0200 Subject: One possibility, but I don't like it --- src/client.rs | 1 - src/endpoint.rs | 1 - src/message.rs | 54 ++++++++++++++++--------------------------------- src/netapp.rs | 2 -- src/peering/basalt.rs | 9 ++++----- src/peering/fullmesh.rs | 4 ---- src/util.rs | 2 -- 7 files changed, 21 insertions(+), 52 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index cf80746..9d572aa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,3 @@ -use std::borrow::Borrow; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{self, AtomicU32}; diff --git a/src/endpoint.rs b/src/endpoint.rs index 3f292d9..97a7644 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,4 +1,3 @@ -use std::borrow::Borrow; use std::marker::PhantomData; use std::sync::Arc; diff --git a/src/message.rs b/src/message.rs index f92eb8c..22cae6a 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -43,6 +42,10 @@ pub trait Message: SerializeMessage + Send + Sync { } /// A trait for de/serializing messages, with possible associated stream. +/// This is default-implemented by anything that can already be serialized +/// and deserialized. Adapters are provided that implement this for +/// adding a body, either from a fixed Bytes buffer (which allows the thing +/// to be Clone), or from a streaming byte stream. pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; @@ -53,38 +56,9 @@ pub trait SerializeMessage: Sized { // ---- -impl SerializeMessage for Result -where - T: SerializeMessage + Send, - E: Serialize + for<'de> Deserialize<'de> + Send, -{ - type SerializableSelf = Result; - - fn into_parts(self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.into_parts(); - (Ok(msg), stream) - } - Err(err) => (Err(err), None), - } - } - - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::from_parts(ok, stream)), - Err(err) => Err(err), - } - } -} - -// --- - -pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} - impl SerializeMessage for T where - T: SimpleMessage, + T: Serialize + for<'de> Deserialize<'de> + Send, { type SerializableSelf = Self; fn into_parts(self) -> (Self::SerializableSelf, Option) { @@ -97,12 +71,15 @@ where } } -impl SimpleMessage for () {} - -impl SimpleMessage for std::sync::Arc {} - // ---- +/// An adapter that adds a body from a fixed Bytes buffer to a serializable message, +/// implementing the SerializeMessage trait. This allows for the SerializeMessage object +/// to be cloned, which is usefull for requests that must be sent to multiple servers. +/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable +/// part is also easily clonable (e.g. by wrapping it in an Arc). +/// Note that this CANNOT be used for a response type, as it cannot be reconstructed +/// from a remote stream. #[derive(Clone)] pub struct WithFixedBody Deserialize<'de> + Clone + Send + 'static>( pub T, @@ -123,11 +100,14 @@ where ) } - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - panic!("Cannot reconstruct a WithFixedBody type from parts"); + fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + panic!("Cannot use a WithFixedBody as a response type"); } } +/// An adapter that adds a body from a ByteStream. This is usefull for receiving +/// responses to requests that contain attached byte streams. This type is +/// not clonable. pub struct WithStreamingBody Deserialize<'de> + Send>( pub T, pub ByteStream, diff --git a/src/netapp.rs b/src/netapp.rs index 8365de0..32a5c23 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -38,8 +38,6 @@ pub(crate) struct HelloMessage { pub server_port: u16, } -impl SimpleMessage for HelloMessage {} - impl Message for HelloMessage { type Response = (); } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 98977a3..d7bc6a8 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -16,7 +16,6 @@ use tokio::sync::watch; use crate::endpoint::*; use crate::message::*; use crate::netapp::*; -use crate::send::*; use crate::NodeID; // -- Protocol messages -- @@ -332,7 +331,7 @@ impl Basalt { async fn do_pull(self: Arc, peer: NodeID) { match self .pull_endpoint - .call(&peer, &PullMessage {}, PRIO_NORMAL) + .call(&peer, PullMessage {}, PRIO_NORMAL) .await { Ok(resp) => { @@ -347,7 +346,7 @@ impl Basalt { async fn do_push(self: Arc, peer: NodeID) { let push_msg = self.make_push_message(); - match self.push_endpoint.call(&peer, &push_msg, PRIO_NORMAL).await { + match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await { Ok(_) => { trace!("KYEV PEXo {}", hex::encode(peer)); } @@ -469,14 +468,14 @@ impl Basalt { #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage { + async fn handle(self: &Arc, _pullmsg: PullMessage, _from: NodeID) -> PushMessage { self.make_push_message() } } #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, pushmsg: &PushMessage, _from: NodeID) { + async fn handle(self: &Arc, pushmsg: PushMessage, _from: NodeID) { self.handle_peer_list(&pushmsg.peers[..]); } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 3eeebb3..f8348af 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -40,8 +40,6 @@ impl Message for PingMessage { type Response = PingMessage; } -impl SimpleMessage for PingMessage {} - #[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, @@ -51,8 +49,6 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } -impl SimpleMessage for PeerListMessage {} - // -- Algorithm data structures -- #[derive(Debug)] diff --git a/src/util.rs b/src/util.rs index f860672..e7ecea8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -9,8 +9,6 @@ use serde::Serialize; use futures::Stream; use tokio::sync::watch; -use crate::message::SerializeMessage; - /// A node's identifier, which is also its public cryptographic key pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; /// A node's secret key -- cgit v1.2.3 From 4934ed726d51913afd97ca937d0ece39ef8b7371 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 20:22:56 +0200 Subject: Propose alternative API --- src/client.rs | 15 ++-- src/endpoint.rs | 45 ++++++++--- src/message.rs | 205 +++++++++++++++++++++++++++++++++--------------- src/netapp.rs | 5 +- src/peering/basalt.rs | 17 +++- src/peering/fullmesh.rs | 12 +-- 6 files changed, 209 insertions(+), 90 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 9d572aa..c878627 100644 --- a/src/client.rs +++ b/src/client.rs @@ -135,10 +135,10 @@ impl ClientConn { pub(crate) async fn call( self: Arc, - rq: T, + req: Req, path: &str, prio: RequestPriority, - ) -> Result<::Response, Error> + ) -> Result, Error> where T: Message, { @@ -162,9 +162,8 @@ impl ClientConn { }; // Encode request - let (rq, stream) = rq.into_parts(); - let body = rmp_to_vec_all_named(&rq)?; - drop(rq); + let body = req.msg_ser.unwrap().clone(); + let stream = req.body.into_stream(); let request = QueryMessage { prio, @@ -216,7 +215,11 @@ impl ClientConn { let code = resp[0]; if code == 0 { let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(T::Response::from_parts(ser_resp, stream)) + Ok(Resp { + _phantom: Default::default(), + msg: ser_resp, + body: BodyData::Stream(stream), + }) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 97a7644..8ee64a5 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -18,7 +18,7 @@ pub trait EndpointHandler: Send + Sync where M: Message, { - async fn handle(self: &Arc, m: M, from: NodeID) -> M::Response; + async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp; } /// If one simply wants to use an endpoint in a client fashion, @@ -27,7 +27,7 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { + async fn handle(self: &Arc<()>, _m: Req, _from: NodeID) -> Resp { panic!("This endpoint should not have a local handler."); } } @@ -80,16 +80,19 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call_full( &self, target: &NodeID, - req: M, + req: T, prio: RequestPriority, - ) -> Result<::Response, Error> { + ) -> Result, Error> + where + T: IntoReq, + { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req, self.netapp.id).await), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), } } else { let conn = self @@ -104,10 +107,21 @@ where "Not connected: {}", hex::encode(&target[..8]) ))), - Some(c) => c.call(req, self.path.as_str(), prio).await, + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, } } } + + /// Call this endpoint on a remote node, without the possibility + /// of adding or receiving a body + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<::Response, Error> { + Ok(self.call_full(target, req, prio).await?.into_msg()) + } } // ---- Internal stuff ---- @@ -148,11 +162,20 @@ where None => Err(Error::NoHandler), Some(h) => { let req = rmp_serde::decode::from_read_ref(buf)?; - let req = M::from_parts(req, stream); + let req = Req { + _phantom: Default::default(), + msg: Arc::new(req), + msg_ser: None, + body: BodyData::Stream(stream), + }; let res = h.handle(req, from).await; - let (res, res_stream) = res.into_parts(); - let res_bytes = rmp_to_vec_all_named(&res)?; - Ok((res_bytes, res_stream)) + let Resp { + msg, + body, + _phantom, + } = res; + let res_bytes = rmp_to_vec_all_named(&msg)?; + Ok((res_bytes, body.into_stream())) } } } diff --git a/src/message.rs b/src/message.rs index 22cae6a..d918c29 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,7 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -37,94 +41,169 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +} + +pub struct Req { + pub(crate) _phantom: PhantomData, + pub(crate) msg: Arc, + pub(crate) msg_ser: Option, + pub(crate) body: BodyData, } -/// A trait for de/serializing messages, with possible associated stream. -/// This is default-implemented by anything that can already be serialized -/// and deserialized. Adapters are provided that implement this for -/// adding a body, either from a fixed Bytes buffer (which allows the thing -/// to be Clone), or from a streaming byte stream. -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; +pub struct Resp { + pub(crate) _phantom: PhantomData, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, +} - fn into_parts(self) -> (Self::SerializableSelf, Option); +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +impl BodyData { + pub fn into_stream(self) -> Option { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } + } } // ---- -impl SerializeMessage for T -where - T: Serialize + for<'de> Deserialize<'de> + Send, -{ - type SerializableSelf = Self; - fn into_parts(self) -> (Self::SerializableSelf, Option) { - (self, None) +impl Req { + pub fn msg(&self) -> &M { + &self.msg } - fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - // TODO verify no stream - ser_self + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } } } -// ---- +pub trait IntoReq { + fn into_req(self) -> Result, rmp_serde::encode::Error>; + fn into_req_local(self) -> Req; +} -/// An adapter that adds a body from a fixed Bytes buffer to a serializable message, -/// implementing the SerializeMessage trait. This allows for the SerializeMessage object -/// to be cloned, which is usefull for requests that must be sent to multiple servers. -/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable -/// part is also easily clonable (e.g. by wrapping it in an Arc). -/// Note that this CANNOT be used for a response type, as it cannot be reconstructed -/// from a remote stream. -#[derive(Clone)] -pub struct WithFixedBody Deserialize<'de> + Clone + Send + 'static>( - pub T, - pub Bytes, -); - -impl SerializeMessage for WithFixedBody -where - T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static, -{ - type SerializableSelf = T; - - fn into_parts(self) -> (Self::SerializableSelf, Option) { - let body = self.1; - ( - self.0, - Some(Box::pin(futures::stream::once(async move { Ok(body) }))), - ) +impl IntoReq for M { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + body: BodyData::None, + }) } + fn into_req_local(self) -> Req { + Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: None, + body: BodyData::None, + } + } +} - fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - panic!("Cannot use a WithFixedBody as a response type"); +impl IntoReq for Req { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req { + self } } -/// An adapter that adds a body from a ByteStream. This is usefull for receiving -/// responses to requests that contain attached byte streams. This type is -/// not clonable. -pub struct WithStreamingBody Deserialize<'de> + Send>( - pub T, - pub ByteStream, -); +impl Clone for Req { + fn clone(&self) -> Self { + let body = match &self.body { + BodyData::None => BodyData::None, + BodyData::Fixed(b) => BodyData::Fixed(b.clone()), + BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"), + }; + Self { + _phantom: Default::default(), + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + body, + } + } +} -impl SerializeMessage for WithStreamingBody +impl fmt::Debug for Req where - T: Serialize + for<'de> Deserialize<'de> + Send, + M: Message + fmt::Debug, { - type SerializableSelf = T; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} + +impl fmt::Debug for Resp +where + M: Message, + ::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} + +impl Resp { + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + body: BodyData::None, + } + } + + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } + } - fn into_parts(self) -> (Self::SerializableSelf, Option) { - (self.0, Some(self.1)) + pub fn msg(&self) -> &M::Response { + &self.msg } - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - WithStreamingBody(ser_self, stream) + pub fn into_msg(self) -> M::Response { + self.msg } } diff --git a/src/netapp.rs b/src/netapp.rs index 32a5c23..0cebac0 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -404,6 +404,7 @@ impl NetApp { PRIO_NORMAL, ) .await + .map(|_| ()) .log_err("Sending hello message"); }); } @@ -432,7 +433,8 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: HelloMessage, from: NodeID) { + async fn handle(self: &Arc, msg: Req, from: NodeID) -> Resp { + let msg = msg.msg(); debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -441,5 +443,6 @@ impl EndpointHandler for NetApp { h(from, remote_addr, true); } } + Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index d7bc6a8..71dea84 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,15 +468,24 @@ impl Basalt { #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, _pullmsg: PullMessage, _from: NodeID) -> PushMessage { - self.make_push_message() + async fn handle( + self: &Arc, + _pullmsg: Req, + _from: NodeID, + ) -> Resp { + Resp::new(self.make_push_message()) } } #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, pushmsg: PushMessage, _from: NodeID) { - self.handle_peer_list(&pushmsg.peers[..]); + async fn handle( + self: &Arc, + pushmsg: Req, + _from: NodeID, + ) -> Resp { + self.handle_peer_list(&pushmsg.msg().peers[..]); + Resp::new(()) } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index f8348af..9b7b666 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,13 +583,14 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: PingMessage, from: NodeID) -> PingMessage { + async fn handle(self: &Arc, ping: Req, from: NodeID) -> Resp { + let ping = ping.msg(); let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - ping_resp + Resp::new(ping_resp) } } @@ -597,11 +598,12 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: PeerListMessage, + peer_list: Req, _from: NodeID, - ) -> PeerListMessage { + ) -> Resp { + let peer_list = peer_list.msg(); self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - PeerListMessage { list: peer_list } + Resp::new(PeerListMessage { list: peer_list }) } } -- cgit v1.2.3 From c358fe3c92da8a8454e461484737efe2a14dfd73 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 10:55:37 +0200 Subject: Hide streaming versions as much as possible --- src/endpoint.rs | 58 ++++++++++++++++++++++++++++++++++++------------- src/message.rs | 2 +- src/netapp.rs | 6 ++--- src/peering/basalt.rs | 17 ++++----------- src/peering/fullmesh.rs | 12 +++++----- src/util.rs | 2 +- 6 files changed, 56 insertions(+), 41 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index 8ee64a5..ff626d8 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -10,11 +10,13 @@ use crate::netapp::*; use crate::util::*; /// This trait should be implemented by an object of your application -/// that can handle a message of type `M`. +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. /// /// The handler object should be in an Arc, see `Endpoint::set_handler` #[async_trait] -pub trait EndpointHandler: Send + Sync +pub trait StreamingEndpointHandler: Send + Sync where M: Message, { @@ -27,11 +29,34 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: Req, _from: NodeID) -> Resp { + async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } +// ---- + +#[async_trait] +pub trait EndpointHandler: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc, m: &M, from: NodeID) -> ::Response; +} + +#[async_trait] +impl StreamingEndpointHandler for T +where + T: EndpointHandler, + M: Message + 'static, +{ + async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp { + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) + } +} + +// ---- + /// This struct represents an endpoint for message of type `M`. /// /// Creating a new endpoint is done by calling `NetApp::endpoint`. @@ -41,13 +66,13 @@ impl EndpointHandler for () { /// An `Endpoint` is used both to send requests to remote nodes, /// and to specify the handler for such requests on the local node. /// The type `H` represents the type of the handler object for -/// endpoint messages (see `EndpointHandler`). +/// endpoint messages (see `StreamingEndpointHandler`). pub struct Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { - phantom: PhantomData, + _phantom: PhantomData, netapp: Arc, path: String, handler: ArcSwapOption, @@ -56,11 +81,11 @@ where impl Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { pub(crate) fn new(netapp: Arc, path: String) -> Self { Self { - phantom: PhantomData::default(), + _phantom: PhantomData::default(), netapp, path, handler: ArcSwapOption::from(None), @@ -79,8 +104,10 @@ where } /// Call this endpoint on a remote node (or on the local node, - /// for that matter) - pub async fn call_full( + /// for that matter). This function invokes the full version that + /// allows to attach a streaming body to the request and to + /// receive such a body attached to the response. + pub async fn call_streaming( &self, target: &NodeID, req: T, @@ -112,15 +139,16 @@ where } } - /// Call this endpoint on a remote node, without the possibility - /// of adding or receiving a body + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. pub async fn call( &self, target: &NodeID, req: M, prio: RequestPriority, ) -> Result<::Response, Error> { - Ok(self.call_full(target, req, prio).await?.into_msg()) + Ok(self.call_streaming(target, req, prio).await?.into_msg()) } } @@ -144,13 +172,13 @@ pub(crate) trait GenericEndpoint { pub(crate) struct EndpointArc(pub(crate) Arc>) where M: Message, - H: EndpointHandler; + H: StreamingEndpointHandler; #[async_trait] impl GenericEndpoint for EndpointArc where M: Message + 'static, - H: EndpointHandler + 'static, + H: StreamingEndpointHandler + 'static, { async fn handle( &self, diff --git a/src/message.rs b/src/message.rs index d918c29..5721318 100644 --- a/src/message.rs +++ b/src/message.rs @@ -311,7 +311,7 @@ impl Framing { } } - pub async fn from_stream + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + Sync + 'static>( mut stream: S, ) -> Result { let mut packet = stream diff --git a/src/netapp.rs b/src/netapp.rs index 0cebac0..166f560 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -152,7 +152,7 @@ impl NetApp { pub fn endpoint(self: &Arc, path: String) -> Arc> where M: Message + 'static, - H: EndpointHandler + 'static, + H: StreamingEndpointHandler + 'static, { let endpoint = Arc::new(Endpoint::::new(self.clone(), path.clone())); let endpoint_arc = EndpointArc(endpoint.clone()); @@ -433,8 +433,7 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: Req, from: NodeID) -> Resp { - let msg = msg.msg(); + async fn handle(self: &Arc, msg: &HelloMessage, from: NodeID) { debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -443,6 +442,5 @@ impl EndpointHandler for NetApp { h(from, remote_addr, true); } } - Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 71dea84..310077f 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,24 +468,15 @@ impl Basalt { #[async_trait] impl EndpointHandler for Basalt { - async fn handle( - self: &Arc, - _pullmsg: Req, - _from: NodeID, - ) -> Resp { - Resp::new(self.make_push_message()) + async fn handle(self: &Arc, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage { + self.make_push_message() } } #[async_trait] impl EndpointHandler for Basalt { - async fn handle( - self: &Arc, - pushmsg: Req, - _from: NodeID, - ) -> Resp { - self.handle_peer_list(&pushmsg.msg().peers[..]); - Resp::new(()) + async fn handle(self: &Arc, pushmsg: &PushMessage, _from: NodeID) { + self.handle_peer_list(&pushmsg.peers[..]); } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 9b7b666..ccbd0ba 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,14 +583,13 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: Req, from: NodeID) -> Resp { - let ping = ping.msg(); + async fn handle(self: &Arc, ping: &PingMessage, from: NodeID) -> PingMessage { let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - Resp::new(ping_resp) + ping_resp } } @@ -598,12 +597,11 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: Req, + peer_list: &PeerListMessage, _from: NodeID, - ) -> Resp { - let peer_list = peer_list.msg(); + ) -> PeerListMessage { self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - Resp::new(PeerListMessage { list: peer_list }) + PeerListMessage { list: peer_list } } } diff --git a/src/util.rs b/src/util.rs index e7ecea8..01c392c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -24,7 +24,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type ByteStream = Pin + Send>>; +pub type ByteStream = Pin + Send + Sync>>; pub type Packet = Result; -- cgit v1.2.3 From 0b71ca12f910c17eaf2291076438dff3b70dc9cd Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 12:45:38 +0200 Subject: Clean up framing protocol --- src/client.rs | 58 +++------ src/endpoint.rs | 30 +---- src/lib.rs | 1 + src/message.rs | 357 +++++++++++++++++++++++++++++++++----------------------- src/recv.rs | 2 +- src/send.rs | 2 +- src/server.rs | 53 ++++----- src/stream.rs | 176 ++++++++++++++++++++++++++++ src/util.rs | 15 --- 9 files changed, 429 insertions(+), 265 deletions(-) create mode 100644 src/stream.rs (limited to 'src') diff --git a/src/client.rs b/src/client.rs index c878627..42eeaa3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use async_trait::async_trait; +use bytes::Bytes; use log::{debug, error, trace}; use futures::channel::mpsc::{unbounded, UnboundedReceiver}; @@ -28,6 +29,7 @@ use crate::message::*; use crate::netapp::*; use crate::recv::*; use crate::send::*; +use crate::stream::*; use crate::util::*; pub(crate) struct ClientConn { @@ -155,24 +157,16 @@ impl ClientConn { .with_kind(SpanKind::Client) .start(&tracer); let propagator = BinaryPropagator::new(); - let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec()); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); } else { - let telemetry_id: Option> = None; + let telemetry_id: Bytes = Bytes::new(); } }; // Encode request - let body = req.msg_ser.unwrap().clone(); - let stream = req.body.into_stream(); - - let request = QueryMessage { - prio, - path: path.as_bytes(), - telemetry_id, - body: &body[..], - }; - let bytes = request.encode(); - drop(body); + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let req_stream = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -181,17 +175,19 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(unbounded().1).is_err() { - debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); - } + let _ = old_ch.send(unbounded().1); } - trace!("request: query_send {}, {} bytes", id, bytes.len()); + trace!( + "request: query_send {} (serialized message: {} bytes)", + id, + req_msg_len + ); #[cfg(feature = "telemetry")] - span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; + query_send.send((id, prio, req_stream))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -202,28 +198,10 @@ impl ClientConn { let stream = resp_recv.await?; } } - let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); - if resp.is_empty() { - return Err(Error::Message( - "Response is 0 bytes, either a collision or a protocol error".into(), - )); - } - - trace!("request response {}: ", id); - - let code = resp[0]; - if code == 0 { - let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(Resp { - _phantom: Default::default(), - msg: ser_resp, - body: BodyData::Stream(stream), - }) - } else { - let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); - Err(Error::Remote(code, msg)) - } + let resp_enc = RespEnc::decode(Box::pin(stream)).await?; + trace!("request response {}", id); + Resp::from_enc(resp_enc) } } diff --git a/src/endpoint.rs b/src/endpoint.rs index ff626d8..d8dc6c4 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -158,12 +158,7 @@ pub(crate) type DynEndpoint = Box; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle( - &self, - buf: &[u8], - stream: ByteStream, - from: NodeID, - ) -> Result<(Vec, Option), Error>; + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -180,30 +175,13 @@ where M: Message + 'static, H: StreamingEndpointHandler + 'static, { - async fn handle( - &self, - buf: &[u8], - stream: ByteStream, - from: NodeID, - ) -> Result<(Vec, Option), Error> { + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref(buf)?; - let req = Req { - _phantom: Default::default(), - msg: Arc::new(req), - msg_ser: None, - body: BodyData::Stream(stream), - }; + let req = Req::from_enc(req_enc)?; let res = h.handle(req, from).await; - let Resp { - msg, - body, - _phantom, - } = res; - let res_bytes = rmp_to_vec_all_named(&msg)?; - Ok((res_bytes, body.into_stream())) + Ok(res.into_enc()?) } } } diff --git a/src/lib.rs b/src/lib.rs index 1edb919..ce94682 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ //! Also check out the examples to learn how to use this crate. pub mod error; +pub mod stream; pub mod util; pub mod endpoint; diff --git a/src/message.rs b/src/message.rs index 5721318..ba06551 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,12 +2,13 @@ use std::fmt; use std::marker::PhantomData; use std::sync::Arc; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use serde::{Deserialize, Serialize}; -use futures::stream::{Stream, StreamExt}; +use futures::stream::StreamExt; use crate::error::*; +use crate::stream::*; use crate::util::*; /// Priority of a request (click to read more about priorities). @@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; } +// ---- + +/// The Req is a helper object used to create requests and attach them +/// a streaming body. If the body is a fixed Bytes and not a ByteStream, +/// Req is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the body is a ByteStream). +/// +/// Internally, this is also used to encode and decode requests +/// from/to byte streams to be sent over the network. pub struct Req { pub(crate) _phantom: PhantomData, pub(crate) msg: Arc, @@ -52,30 +62,6 @@ pub struct Req { pub(crate) body: BodyData, } -pub struct Resp { - pub(crate) _phantom: PhantomData, - pub(crate) msg: M::Response, - pub(crate) body: BodyData, -} - -pub(crate) enum BodyData { - None, - Fixed(Bytes), - Stream(ByteStream), -} - -impl BodyData { - pub fn into_stream(self) -> Option { - match self { - BodyData::None => None, - BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), - BodyData::Stream(s) => Some(s), - } - } -} - -// ---- - impl Req { pub fn msg(&self) -> &M { &self.msg @@ -94,6 +80,31 @@ impl Req { ..self } } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.body.into_stream(), + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } } pub trait IntoReq { @@ -160,19 +171,14 @@ where } } -impl fmt::Debug for Resp -where - M: Message, - ::Response: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(f, "Resp[{:?}", self.msg)?; - match &self.body { - BodyData::None => write!(f, "]"), - BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), - BodyData::Stream(_) => write!(f, "; body=stream]"), - } - } +// ---- + +/// The Resp represents a full response from a RPC that may have +/// an attached body stream. +pub struct Resp { + pub(crate) _phantom: PhantomData, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, } impl Resp { @@ -205,160 +211,213 @@ impl Resp { pub fn into_msg(self) -> M::Response { self.msg } + + pub(crate) fn into_enc(self) -> Result { + Ok(RespEnc::Success { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.body.into_stream(), + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result { + match enc { + RespEnc::Success { msg, stream } => { + let msg = rmp_serde::decode::from_read_ref(&msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + body: stream.map(BodyData::Stream).unwrap_or(BodyData::None), + }) + } + RespEnc::Error { code, message } => Err(Error::Remote(code, message)), + } + } } -// ---- ---- +impl fmt::Debug for Resp +where + M: Message, + ::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option>, - pub(crate) body: &'a [u8], +// ---- + +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl BodyData { + pub fn into_stream(self) -> Option { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } + } } -/// QueryMessage encoding: +// ---- ---- + +/// Encoding for requests into a ByteStream: /// - priority: u8 /// - path length: u8 /// - path: [u8; path length] /// - telemetry id length: u8 /// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> ByteStream { + let mut buf = BytesMut::with_capacity(64); + + buf.put_u8(self.prio); - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + buf.put_u8(self.path.len() as u8); + buf.put(self.path); - ret.push(self.prio); + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); + buf.put_u32(self.msg.len() as u32); + buf.put(&self.msg[..]); - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); + let header = buf.freeze(); + + if let Some(stream) = self.stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) } else { - ret.push(0u8); + Box::pin(futures::stream::once(async move { Ok(header) })) } + } - ret.extend_from_slice(self.body); - - ret + pub(crate) async fn decode(stream: ByteStream) -> Result { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) } - pub(crate) fn decode(bytes: &'a [u8]) -> Result { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } + pub(crate) async fn decode_aux(stream: ByteStream) -> Result { + let mut reader = ByteStreamReader::new(stream); - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } + let prio = reader.read_u8().await?; - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; - let body = &bytes[3 + path_length + telemetry_id_len..]; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; Ok(Self { - prio: bytes[0], + prio, path, telemetry_id, - body, + msg, + stream: Some(reader.into_stream()), }) } } -// ---- ---- - -pub(crate) struct Framing { - direct: Vec, - stream: Option, +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) enum RespEnc { + Error { + code: u8, + message: String, + }, + Success { + msg: Bytes, + stream: Option, + }, } -impl Framing { - pub fn new(direct: Vec, stream: Option) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } +impl RespEnc { + pub(crate) fn from_err(e: Error) -> Self { + RespEnc::Error { + code: e.code(), + message: format!("{}", e), + } } - pub fn into_stream(self) -> ByteStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; + pub(crate) fn encode(self) -> ByteStream { + match self { + RespEnc::Success { msg, stream } => { + let mut buf = BytesMut::with_capacity(64); - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) }) - .chain(stream::once(async move { Ok(direct.into()) })); + buf.put_u8(0); - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } + buf.put_u32(msg.len() as u32); + buf.put(&msg[..]); - pub async fn from_stream + Unpin + Send + Sync + 'static>( - mut stream: S, - ) -> Result { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); + let header = buf.freeze(); + + if let Some(stream) = stream { + Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) + } else { + Box::pin(futures::stream::once(async move { Ok(header) })) + } + } + RespEnc::Error { code, message } => { + let mut buf = BytesMut::with_capacity(64); + buf.put_u8(1 + message.len() as u8); + buf.put_u8(code); + buf.put(message.as_bytes()); + let header = buf.freeze(); + Box::pin(futures::stream::once(async move { Ok(header) })) + } } + } - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet = packet.slice(4..); + pub(crate) async fn decode(stream: ByteStream) -> Result { + Self::decode_aux(stream).await.map_err(|_| Error::Framing) + } - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + pub(crate) async fn decode_aux(stream: ByteStream) -> Result { + let mut reader = ByteStreamReader::new(stream); - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet = packet.slice(max_cp..); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } + let is_err = reader.read_u8().await?; - let stream: ByteStream = if packet.is_empty() { - Box::pin(stream) + if is_err > 0 { + let code = reader.read_u8().await?; + let message = reader.read_exact(is_err as usize - 1).await?; + let message = String::from_utf8(message.to_vec()).unwrap_or_default(); + Ok(RespEnc::Error { code, message }) } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec, ByteStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + Ok(RespEnc::Success { + msg, + stream: Some(reader.into_stream()), + }) + } } } diff --git a/src/recv.rs b/src/recv.rs index abe7b9a..19288f2 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -10,7 +10,7 @@ use futures::AsyncReadExt; use crate::error::*; use crate::send::*; -use crate::util::Packet; +use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data diff --git a/src/send.rs b/src/send.rs index cc28d7c..59805cf 100644 --- a/src/send.rs +++ b/src/send.rs @@ -14,7 +14,7 @@ use tokio::sync::mpsc; use crate::error::*; use crate::message::*; -use crate::util::{ByteStream, Packet}; +use crate::stream::*; // Messages are sent by chunks // Chunk format: diff --git a/src/server.rs b/src/server.rs index 1f1c22a..ae1196c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -28,6 +28,7 @@ use crate::message::*; use crate::netapp::*; use crate::recv::*; use crate::send::*; +use crate::stream::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -121,17 +122,12 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux( - self: &Arc, - bytes: &[u8], - stream: ByteStream, - ) -> Result<(Vec, Option), Error> { - let msg = QueryMessage::decode(bytes)?; - let path = String::from_utf8(msg.path.to_vec())?; + async fn recv_handler_aux(self: &Arc, req_enc: ReqEnc) -> Result { + let path = String::from_utf8(req_enc.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); - endpoints.get(&path).map(|e| e.clone_endpoint()) + endpoints.get(&path[..]).map(|e| e.clone_endpoint()) }; if let Some(handler) = handler_opt { @@ -139,9 +135,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if let Some(telemetry_id) = msg.telemetry_id { + let mut span = if !req_enc.telemetry_id.is_empty() { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(telemetry_id); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -156,13 +152,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); - handler.handle(msg.body, stream, self.peer_id) + handler.handle(req_enc, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, stream, self.peer_id).await + handler.handle(req_enc, self.peer_id).await } } } else { @@ -181,32 +177,23 @@ impl RecvLoop for ServerConn { let self2 = self.clone(); tokio::spawn(async move { trace!("ServerConn recv_handler {}", id); - let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); - - let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..], stream).await; - - let (resp_bytes, resp_stream) = match resp { - Ok((rb, rs)) => { - let mut resp_bytes = vec![0u8]; - resp_bytes.extend(rb); - (resp_bytes, rs) - } - Err(e) => { - let mut resp_bytes = vec![e.code()]; - resp_bytes.extend(e.to_string().into_bytes()); - (resp_bytes, None) + let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await { + Ok(req_enc) => { + let prio = req_enc.prio; + let resp = self2.recv_handler_aux(req_enc).await; + + (prio, match resp { + Ok(resp_enc) => resp_enc, + Err(e) => RespEnc::from_err(e), + }) } + Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), }; trace!("ServerConn sending response to {}: ", id); resp_send - .send(( - id, - prio, - Framing::new(resp_bytes, resp_stream).into_stream(), - )) + .send((id, prio, resp_enc.encode())) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..6c23f4a --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,176 @@ +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; + +/// A stream of associated data. +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type ByteStream = Pin + Send + Sync>>; + +pub type Packet = Result; + +pub struct ByteStreamReader { + stream: ByteStream, + buf: VecDeque, + buf_len: usize, + eos: bool, + err: Option, +} + +impl ByteStreamReader { + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: VecDeque::with_capacity(8), + buf_len: 0, + eos: false, + err: None, + } + } + + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + pub async fn read_u8(&mut self) -> Result { + Ok(self.read_exact(1).await?[0]) + } + + pub async fn read_u16(&mut self) -> Result { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + pub async fn read_u32(&mut self) -> Result { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + fn try_get(&mut self, read_len: usize) -> Option { + if self.buf_len >= read_len { + let mut slices = Vec::with_capacity(self.buf.len()); + let mut taken = 0; + while taken < read_len { + let front = self.buf.pop_front().unwrap(); + if taken + front.len() <= read_len { + taken += front.len(); + self.buf_len -= front.len(); + slices.push(front); + } else { + let front_take = read_len - taken; + slices.push(front.slice(..front_take)); + self.buf.push_front(front.slice(front_take..)); + self.buf_len -= front_take; + break; + } + } + Some( + slices + .iter() + .map(|x| &x[..]) + .collect::>() + .concat() + .into(), + ) + } else { + None + } + } +} + +pub enum ReadExactError { + UnexpectedEos, + Stream(u8), +} + +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = this.reader.err { + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + let bytes = Bytes::from( + this.reader + .buf + .iter() + .map(|x| &x[..]) + .collect::>() + .concat(), + ); + this.reader.buf.clear(); + this.reader.buf_len = 0; + return Poll::Ready(Ok(bytes)); + } + } + + match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { + Some(Ok(slice)) => { + this.reader.buf_len += slice.len(); + this.reader.buf.push_back(slice); + } + Some(Err(e)) => { + this.reader.err = Some(e); + this.reader.eos = true; + } + None => { + this.reader.eos = true; + } + } + } + } +} diff --git a/src/util.rs b/src/util.rs index 01c392c..13cccb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,12 +1,9 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; -use std::pin::Pin; -use bytes::Bytes; use log::info; use serde::Serialize; -use futures::Stream; use tokio::sync::watch; /// A node's identifier, which is also its public cryptographic key @@ -16,18 +13,6 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; -/// A stream of associated data. -/// -/// The Stream can continue after receiving an error. -/// When sent through Netapp, the Vec may be split in smaller chunk in such a way -/// consecutive Vec may get merged, but Vec and error code may not be reordered -/// -/// Error code 255 means the stream was cut before its end. Other codes have no predefined -/// meaning, it's up to your application to define their semantic. -pub type ByteStream = Pin + Send + Sync>>; - -pub type Packet = Result; - /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. /// -- cgit v1.2.3 From 9cb28c21b4a80aa9f29097f6bb1b8b6c23446ddc Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:01:52 +0200 Subject: Use bounded channels on receive side for backpressure --- src/client.rs | 11 ++++++----- src/error.rs | 4 ++++ src/recv.rs | 38 ++++++++++++++++++-------------------- src/server.rs | 16 +++++++++------- 4 files changed, 37 insertions(+), 32 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 42eeaa3..d51236b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,6 @@ use async_trait::async_trait; use bytes::Bytes; use log::{debug, error, trace}; -use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use futures::io::AsyncReadExt; use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; @@ -39,7 +38,7 @@ pub(crate) struct ClientConn { query_send: ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex>>, } impl ClientConn { @@ -175,7 +174,9 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - let _ = old_ch.send(unbounded().1); + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(Error::IdCollision.code()) + }))); } trace!( @@ -199,7 +200,7 @@ impl ClientConn { } } - let resp_enc = RespEnc::decode(Box::pin(stream)).await?; + let resp_enc = RespEnc::decode(stream).await?; trace!("request response {}", id); Resp::from_enc(resp_enc) } @@ -209,7 +210,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/error.rs b/src/error.rs index 665647c..f374341 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(display = "Framing protocol error")] Framing, + #[error(display = "Request ID collision")] + IdCollision, + #[error(display = "{}", _0)] Message(String), @@ -56,6 +59,7 @@ impl Error { Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, + Self::IdCollision => 22, Self::Handshake(_) => 30, Self::VersionMismatch(_) => 31, Self::Remote(c, _) => *c, diff --git a/src/recv.rs b/src/recv.rs index 19288f2..b2f5530 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -5,8 +5,8 @@ use async_trait::async_trait; use bytes::Bytes; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::AsyncReadExt; +use tokio::sync::mpsc; use crate::error::*; use crate::send::*; @@ -15,33 +15,28 @@ use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data struct Sender { - inner: UnboundedSender, - closed: bool, + inner: Option>, } impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } + fn new(inner: mpsc::Sender) -> Self { + Sender { inner: Some(inner) } } - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); + async fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet).await; } fn end(&mut self) { - self.closed = true; + self.inner = None; } } impl Drop for Sender { fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); + if let Some(inner) = self.inner.take() { + let _ = inner.blocking_send(Err(255)); } - self.inner.close_channel(); } } @@ -54,7 +49,7 @@ impl Drop for Sender { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where @@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); + let (send, recv) = mpsc::channel(4); + self.recv_handler( + id, + Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)), + ); Sender::new(send) }; - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); + // If we get an error, the receiving end is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet).await; if has_cont { streams.insert(id, sender); diff --git a/src/server.rs b/src/server.rs index ae1196c..4b232af 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption; use async_trait::async_trait; use log::{debug, trace}; -use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::{handshake_server, BoxStream}; use tokio::net::TcpStream; @@ -171,21 +170,24 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { trace!("ServerConn recv_handler {}", id); - let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await { + let (prio, resp_enc) = match ReqEnc::decode(stream).await { Ok(req_enc) => { let prio = req_enc.prio; let resp = self2.recv_handler_aux(req_enc).await; - (prio, match resp { - Ok(resp_enc) => resp_enc, - Err(e) => RespEnc::from_err(e), - }) + ( + prio, + match resp { + Ok(resp_enc) => resp_enc, + Err(e) => RespEnc::from_err(e), + }, + ) } Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), }; -- cgit v1.2.3 From 5da59ebec5f3072d0b6c3b1ffc90eb8923c50ad9 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:06:10 +0200 Subject: Move things around and fix error bit --- src/endpoint.rs | 1 - src/lib.rs | 1 - src/netapp.rs | 8 +++++++- src/recv.rs | 2 +- src/util.rs | 7 +------ 5 files changed, 9 insertions(+), 10 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index d8dc6c4..7088879 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -7,7 +7,6 @@ use async_trait::async_trait; use crate::error::Error; use crate::message::*; use crate::netapp::*; -use crate::util::*; /// This trait should be implemented by an object of your application /// that can handle a message of type `M`, if it wishes to handle diff --git a/src/lib.rs b/src/lib.rs index ce94682..bd41048 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,6 @@ pub mod netapp; pub mod peering; pub use crate::netapp::*; -pub use util::{NetworkKey, NodeID, NodeKey}; #[cfg(test)] mod test; diff --git a/src/netapp.rs b/src/netapp.rs index 166f560..29df3b9 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -22,7 +22,13 @@ use crate::endpoint::*; use crate::error::*; use crate::message::*; use crate::server::*; -use crate::util::*; + +/// A node's identifier, which is also its public cryptographic key +pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; +/// A node's secret key +pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; +/// A network key +pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// Tag which is exchanged between client and server upon connection establishment /// to check that they are running compatible versions of Netapp, diff --git a/src/recv.rs b/src/recv.rs index b2f5530..2be8728 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -75,7 +75,7 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; let packet = if is_error { - Err(size as u8) + Err((size & !ERROR_MARKER) as u8) } else { let size = size & !CHUNK_HAS_CONTINUATION; let mut next_slice = vec![0; size as usize]; diff --git a/src/util.rs b/src/util.rs index 13cccb9..425d26f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -6,12 +6,7 @@ use serde::Serialize; use tokio::sync::watch; -/// A node's identifier, which is also its public cryptographic key -pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; -/// A node's secret key -pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; -/// A network key -pub type NetworkKey = sodiumoxide::crypto::auth::Key; +use crate::netapp::*; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- cgit v1.2.3 From f9db9a4b696569bbc56c40b9170320307ebcdd81 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:23:42 +0200 Subject: Simplify send.rs --- src/send.rs | 205 ++++++++++++++-------------------------------------------- src/stream.rs | 29 +++++---- 2 files changed, 68 insertions(+), 166 deletions(-) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index 59805cf..a8cf966 100644 --- a/src/send.rs +++ b/src/send.rs @@ -8,7 +8,6 @@ use bytes::Bytes; use log::trace; use futures::AsyncWriteExt; -use futures::Stream; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -30,152 +29,14 @@ pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +struct SendQueue { + items: VecDeque<(u8, VecDeque)>, +} + struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: DataReader, -} - -#[pin_project::pin_project] -struct DataReader { - #[pin] - reader: ByteStream, - packet: Packet, - pos: usize, - buf: Vec, - eos: bool, -} - -impl From for DataReader { - fn from(data: ByteStream) -> DataReader { - DataReader { - reader: data, - packet: Ok(Bytes::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - } - } -} - -enum DataFrame { - Data { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actual lenght of data - len: usize, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, - }, - /// An error code automatically signals the end of the stream - Error(u8), -} - -impl DataFrame { - fn empty_last() -> Self { - DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - may_have_more: false, - } - } - - fn header(&self) -> [u8; 2] { - let header_u16 = match self { - DataFrame::Data { - len, - may_have_more: false, - .. - } => *len as u16, - DataFrame::Data { - len, - may_have_more: true, - .. - } => *len as u16 | CHUNK_HAS_CONTINUATION, - DataFrame::Error(e) => *e as u16 | ERROR_MARKER, - }; - ChunkLength::to_be_bytes(header_u16) - } - - fn data(&self) -> &[u8] { - match self { - DataFrame::Data { ref data, len, .. } => &data[..*len], - DataFrame::Error(_) => &[], - } - } - - fn may_have_more(&self) -> bool { - match self { - DataFrame::Data { may_have_more, .. } => *may_have_more, - DataFrame::Error(_) => false, - } - } -} - -impl Stream for DataReader { - type Item = DataFrame; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if *this.eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - - loop { - let packet = match this.packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *this.packet = Ok(Bytes::new()); - *this.eos = true; - return Poll::Ready(Some(DataFrame::Error(e))); - } - }; - let packet_left = packet.len() - *this.pos; - let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - this.buf - .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); - *this.pos += to_read; - if this.buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { - *this.packet = p; - *this.pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if this.packet.is_err() && !this.buf.is_empty() { - break; - } - } else { - *this.eos = true; - break; - } - } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = this.buf.len(); - body[..len].copy_from_slice(this.buf); - this.buf.clear(); - Poll::Ready(Some(DataFrame::Data { - data: body, - len, - may_have_more: !*this.eos, - })) - } -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque)>, + data: ByteStreamReader, } impl SendQueue { @@ -232,35 +93,69 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { let mut ready_item = None; for (j, item) in items_at_prio.iter_mut().enumerate() { - match Pin::new(&mut item.data).poll_next(ctx) { + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + match Pin::new(&mut item_reader).poll(ctx) { Poll::Pending => (), Poll::Ready(ready_v) => { - ready_item = Some((j, ready_v)); + ready_item = Some((j, ready_v, item.data.eos())); break; } } } - if let Some((j, ready_v)) = ready_item { + if let Some((j, bytes_or_err, eos)) = ready_item { + let data_frame = match bytes_or_err { + Ok(bytes) => DataFrame::Data(bytes, !eos), + Err(e) => DataFrame::Error(match e { + ReadExactError::Stream(code) => code, + _ => unreachable!(), + }), + }; let item = items_at_prio.remove(j).unwrap(); let id = item.id; - if ready_v - .as_ref() - .map(|data| data.may_have_more()) - .unwrap_or(false) - { + if !eos { items_at_prio.push_back(item); } else if items_at_prio.is_empty() { self.queue.items.remove(i); } - return Poll::Ready((id, ready_v.unwrap_or_else(DataFrame::empty_last))); + return Poll::Ready((id, data_frame)); } } - // TODO what do we do if self.queue is empty? We won't get scheduled again. + // If the queue is empty, this futures is eternally pending. + // This is ok because we use it in a select with another future + // that can interrupt it. Poll::Pending } } +enum DataFrame { + /// a fixed size buffer containing some data + a boolean indicating whether + /// there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + Data(Bytes, bool), + /// An error code automatically signals the end of the stream + Error(u8), +} + +impl DataFrame { + fn header(&self) -> [u8; 2] { + let header_u16 = match self { + DataFrame::Data(data, false) => data.len() as u16, + DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, + DataFrame::Error(e) => *e as u16 | ERROR_MARKER, + }; + ChunkLength::to_be_bytes(header_u16) + } + + fn data(&self) -> &[u8] { + match self { + DataFrame::Data(ref data, _) => &data[..], + DataFrame::Error(_) => &[], + } + } +} + /// The SendLoop trait, which is implemented both by the client and the server /// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` /// that takes a channel of messages to send and an asynchronous writer, @@ -295,7 +190,7 @@ pub(crate) trait SendLoop: Sync { sending.push(SendQueueItem { id, prio, - data: data.into(), + data: ByteStreamReader::new(data), }); } else { should_exit = true; diff --git a/src/stream.rs b/src/stream.rs index 6c23f4a..ae57d62 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -82,6 +82,23 @@ impl ByteStreamReader { } } + pub fn take_buffer(&mut self) -> Bytes { + let bytes = Bytes::from( + self .buf + .iter() + .map(|x| &x[..]) + .collect::>() + .concat(), + ); + self.buf.clear(); + self.buf_len = 0; + bytes + } + + pub fn eos(&self) -> bool { + self.buf.is_empty() && self.eos + } + fn try_get(&mut self, read_len: usize) -> Option { if self.buf_len >= read_len { let mut slices = Vec::with_capacity(self.buf.len()); @@ -144,17 +161,7 @@ impl<'a> Future for ByteStreamReadExact<'a> { if *this.fail_on_eos { return Poll::Ready(Err(ReadExactError::UnexpectedEos)); } else { - let bytes = Bytes::from( - this.reader - .buf - .iter() - .map(|x| &x[..]) - .collect::>() - .concat(), - ); - this.reader.buf.clear(); - this.reader.buf_len = 0; - return Poll::Ready(Ok(bytes)); + return Poll::Ready(Ok(this.reader.take_buffer())); } } -- cgit v1.2.3 From 50627c206043edf5ce8755893944a6f17a54fb85 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:27:56 +0200 Subject: Add comment --- src/endpoint.rs | 3 +++ 1 file changed, 3 insertions(+) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index 7088879..31500aa 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -35,6 +35,9 @@ impl EndpointHandler for () { // ---- +/// This trait should be implemented by an object of your application +/// that can handle a message of type `M`, in the cases where it doesn't +/// care about attached stream in the request nor in the response. #[async_trait] pub trait EndpointHandler: Send + Sync where -- cgit v1.2.3 From b9df442f035e6648e80adf8f9bf86b4943508ae5 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:32:08 +0200 Subject: Small optimization --- src/endpoint.rs | 5 ++++- src/message.rs | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index 31500aa..588f7e3 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -52,7 +52,10 @@ where T: EndpointHandler, M: Message + 'static, { - async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp { + async fn handle(self: &Arc, mut m: Req, from: NodeID) -> Resp { + // Immediately drop stream to avoid backpressure if a stream was sent + // (this will make all data sent to the stream be ignored immediately) + drop(m.take_stream()); Resp::new(EndpointHandler::handle(self, m.msg(), from).await) } } diff --git a/src/message.rs b/src/message.rs index ba06551..0ac4cb8 100644 --- a/src/message.rs +++ b/src/message.rs @@ -81,6 +81,10 @@ impl Req { } } + pub fn take_stream(&mut self) -> Option { + std::mem::replace(&mut self.body, BodyData::None).into_stream() + } + pub(crate) fn into_enc( self, prio: RequestPriority, -- cgit v1.2.3 From 67ea3a48fa1c9c462f1c4912c9658ad002d3336d Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:40:06 +0200 Subject: Add Resp::into_parts --- src/message.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index 0ac4cb8..8e4bc2f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -63,10 +63,6 @@ pub struct Req { } impl Req { - pub fn msg(&self) -> &M { - &self.msg - } - pub fn with_fixed_body(self, b: Bytes) -> Self { Self { body: BodyData::Fixed(b), @@ -81,6 +77,10 @@ impl Req { } } + pub fn msg(&self) -> &M { + &self.msg + } + pub fn take_stream(&mut self) -> Option { std::mem::replace(&mut self.body, BodyData::None).into_stream() } @@ -216,6 +216,10 @@ impl Resp { self.msg } + pub fn into_parts(self) -> (M::Response, Option) { + (self.msg, self.body.into_stream()) + } + pub(crate) fn into_enc(self) -> Result { Ok(RespEnc::Success { msg: rmp_to_vec_all_named(&self.msg)?.into(), -- cgit v1.2.3 From aa1b29d41a680f9ae266ed4bdecec89db58226c1 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:44:48 +0200 Subject: Terminology: don't use the word "body" anymore, talk of "attached stream" --- src/endpoint.rs | 4 +-- src/message.rs | 83 +++++++++++++++++++++++++++------------------------------ 2 files changed, 42 insertions(+), 45 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index 588f7e3..a402fec 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -110,8 +110,8 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter). This function invokes the full version that - /// allows to attach a streaming body to the request and to - /// receive such a body attached to the response. + /// allows to attach a stream to the request and to + /// receive such a stream attached to the response. pub async fn call_streaming( &self, target: &NodeID, diff --git a/src/message.rs b/src/message.rs index 8e4bc2f..2ed5c98 100644 --- a/src/message.rs +++ b/src/message.rs @@ -49,30 +49,27 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { // ---- /// The Req is a helper object used to create requests and attach them -/// a streaming body. If the body is a fixed Bytes and not a ByteStream, +/// a stream of data. If the stream is a fixed Bytes and not a ByteStream, /// Req is cheaply clonable to allow the request to be sent to different -/// peers (Clone will panic if the body is a ByteStream). -/// -/// Internally, this is also used to encode and decode requests -/// from/to byte streams to be sent over the network. +/// peers (Clone will panic if the stream is a ByteStream). pub struct Req { pub(crate) _phantom: PhantomData, pub(crate) msg: Arc, pub(crate) msg_ser: Option, - pub(crate) body: BodyData, + pub(crate) stream: AttachedStream, } impl Req { - pub fn with_fixed_body(self, b: Bytes) -> Self { + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { Self { - body: BodyData::Fixed(b), + stream: AttachedStream::Fixed(b), ..self } } - pub fn with_streaming_body(self, b: ByteStream) -> Self { + pub fn with_stream(self, b: ByteStream) -> Self { Self { - body: BodyData::Stream(b), + stream: AttachedStream::Stream(b), ..self } } @@ -82,7 +79,7 @@ impl Req { } pub fn take_stream(&mut self) -> Option { - std::mem::replace(&mut self.body, BodyData::None).into_stream() + std::mem::replace(&mut self.stream, AttachedStream::None).into_stream() } pub(crate) fn into_enc( @@ -96,7 +93,7 @@ impl Req { path, telemetry_id, msg: self.msg_ser.unwrap(), - stream: self.body.into_stream(), + stream: self.stream.into_stream(), } } @@ -106,7 +103,7 @@ impl Req { _phantom: Default::default(), msg: Arc::new(msg), msg_ser: Some(enc.msg), - body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None), + stream: enc.stream.map(AttachedStream::Stream).unwrap_or(AttachedStream::None), }) } } @@ -123,7 +120,7 @@ impl IntoReq for M { _phantom: Default::default(), msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), - body: BodyData::None, + stream: AttachedStream::None, }) } fn into_req_local(self) -> Req { @@ -131,7 +128,7 @@ impl IntoReq for M { _phantom: Default::default(), msg: Arc::new(self), msg_ser: None, - body: BodyData::None, + stream: AttachedStream::None, } } } @@ -147,16 +144,16 @@ impl IntoReq for Req { impl Clone for Req { fn clone(&self) -> Self { - let body = match &self.body { - BodyData::None => BodyData::None, - BodyData::Fixed(b) => BodyData::Fixed(b.clone()), - BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"), + let stream = match &self.stream { + AttachedStream::None => AttachedStream::None, + AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()), + AttachedStream::Stream(_) => panic!("Cannot clone a Req<_> with a non-buffer attached stream"), }; Self { _phantom: Default::default(), msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), - body, + stream, } } } @@ -167,10 +164,10 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Req[{:?}", self.msg)?; - match &self.body { - BodyData::None => write!(f, "]"), - BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), - BodyData::Stream(_) => write!(f, "; body=stream]"), + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), } } } @@ -178,11 +175,11 @@ where // ---- /// The Resp represents a full response from a RPC that may have -/// an attached body stream. +/// an attached stream. pub struct Resp { pub(crate) _phantom: PhantomData, pub(crate) msg: M::Response, - pub(crate) body: BodyData, + pub(crate) stream: AttachedStream, } impl Resp { @@ -190,20 +187,20 @@ impl Resp { Resp { _phantom: Default::default(), msg: v, - body: BodyData::None, + stream: AttachedStream::None, } } - pub fn with_fixed_body(self, b: Bytes) -> Self { + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { Self { - body: BodyData::Fixed(b), + stream: AttachedStream::Fixed(b), ..self } } - pub fn with_streaming_body(self, b: ByteStream) -> Self { + pub fn with_stream(self, b: ByteStream) -> Self { Self { - body: BodyData::Stream(b), + stream: AttachedStream::Stream(b), ..self } } @@ -217,13 +214,13 @@ impl Resp { } pub fn into_parts(self) -> (M::Response, Option) { - (self.msg, self.body.into_stream()) + (self.msg, self.stream.into_stream()) } pub(crate) fn into_enc(self) -> Result { Ok(RespEnc::Success { msg: rmp_to_vec_all_named(&self.msg)?.into(), - stream: self.body.into_stream(), + stream: self.stream.into_stream(), }) } @@ -234,7 +231,7 @@ impl Resp { Ok(Self { _phantom: Default::default(), msg, - body: stream.map(BodyData::Stream).unwrap_or(BodyData::None), + stream: stream.map(AttachedStream::Stream).unwrap_or(AttachedStream::None), }) } RespEnc::Error { code, message } => Err(Error::Remote(code, message)), @@ -249,28 +246,28 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Resp[{:?}", self.msg)?; - match &self.body { - BodyData::None => write!(f, "]"), - BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), - BodyData::Stream(_) => write!(f, "; body=stream]"), + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), } } } // ---- -pub(crate) enum BodyData { +pub(crate) enum AttachedStream { None, Fixed(Bytes), Stream(ByteStream), } -impl BodyData { +impl AttachedStream { pub fn into_stream(self) -> Option { match self { - BodyData::None => None, - BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), - BodyData::Stream(s) => Some(s), + AttachedStream::None => None, + AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + AttachedStream::Stream(s) => Some(s), } } } -- cgit v1.2.3 From 50358b944ae7ee4b4aa292ede8bc5d185c86df65 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 13:48:43 +0200 Subject: Cargo fmt; better adapt with_capacity_values --- src/message.rs | 21 +++++++++++++++------ src/send.rs | 18 ++---------------- src/stream.rs | 8 +------- 3 files changed, 18 insertions(+), 29 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index 2ed5c98..56e6e8e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -103,7 +103,10 @@ impl Req { _phantom: Default::default(), msg: Arc::new(msg), msg_ser: Some(enc.msg), - stream: enc.stream.map(AttachedStream::Stream).unwrap_or(AttachedStream::None), + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), }) } } @@ -147,7 +150,9 @@ impl Clone for Req { let stream = match &self.stream { AttachedStream::None => AttachedStream::None, AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()), - AttachedStream::Stream(_) => panic!("Cannot clone a Req<_> with a non-buffer attached stream"), + AttachedStream::Stream(_) => { + panic!("Cannot clone a Req<_> with a non-buffer attached stream") + } }; Self { _phantom: Default::default(), @@ -231,7 +236,9 @@ impl Resp { Ok(Self { _phantom: Default::default(), msg, - stream: stream.map(AttachedStream::Stream).unwrap_or(AttachedStream::None), + stream: stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), }) } RespEnc::Error { code, message } => Err(Error::Remote(code, message)), @@ -293,7 +300,9 @@ pub(crate) struct ReqEnc { impl ReqEnc { pub(crate) fn encode(self) -> ByteStream { - let mut buf = BytesMut::with_capacity(64); + let mut buf = BytesMut::with_capacity( + self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, + ); buf.put_u8(self.prio); @@ -375,7 +384,7 @@ impl RespEnc { pub(crate) fn encode(self) -> ByteStream { match self { RespEnc::Success { msg, stream } => { - let mut buf = BytesMut::with_capacity(64); + let mut buf = BytesMut::with_capacity(msg.len() + 8); buf.put_u8(0); @@ -391,7 +400,7 @@ impl RespEnc { } } RespEnc::Error { code, message } => { - let mut buf = BytesMut::with_capacity(64); + let mut buf = BytesMut::with_capacity(message.len() + 8); buf.put_u8(1 + message.len() as u8); buf.put_u8(code); buf.put(message.as_bytes()); diff --git a/src/send.rs b/src/send.rs index a8cf966..f1df6f7 100644 --- a/src/send.rs +++ b/src/send.rs @@ -30,7 +30,7 @@ pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueue { - items: VecDeque<(u8, VecDeque)>, + items: Vec<(u8, VecDeque)>, } struct SendQueueItem { @@ -42,7 +42,7 @@ struct SendQueueItem { impl SendQueue { fn new() -> Self { Self { - items: VecDeque::with_capacity(64), + items: Vec::with_capacity(64), } } fn push(&mut self, item: SendQueueItem) { @@ -56,20 +56,6 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] - fn pop(&mut self) -> Option { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } diff --git a/src/stream.rs b/src/stream.rs index ae57d62..aa7641f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -83,13 +83,7 @@ impl ByteStreamReader { } pub fn take_buffer(&mut self) -> Bytes { - let bytes = Bytes::from( - self .buf - .iter() - .map(|x| &x[..]) - .collect::>() - .concat(), - ); + let bytes = Bytes::from(self.buf.iter().map(|x| &x[..]).collect::>().concat()); self.buf.clear(); self.buf_len = 0; bytes -- cgit v1.2.3 From 482566929385fab18300d205a89d0e23d977855b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 14:38:03 +0200 Subject: Remove copy of serialized thing in encode --- src/message.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index 56e6e8e..629992d 100644 --- a/src/message.rs +++ b/src/message.rs @@ -313,14 +313,13 @@ impl ReqEnc { buf.put(&self.telemetry_id[..]); buf.put_u32(self.msg.len() as u32); - buf.put(&self.msg[..]); let header = buf.freeze(); if let Some(stream) = self.stream { - Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) } else { - Box::pin(futures::stream::once(async move { Ok(header) })) + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) } } @@ -387,16 +386,14 @@ impl RespEnc { let mut buf = BytesMut::with_capacity(msg.len() + 8); buf.put_u8(0); - buf.put_u32(msg.len() as u32); - buf.put(&msg[..]); let header = buf.freeze(); if let Some(stream) = stream { - Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream)) + Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) } else { - Box::pin(futures::stream::once(async move { Ok(header) })) + Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) } } RespEnc::Error { code, message } => { -- cgit v1.2.3 From cbc21e40acfc420a3e452a1fd488c6a96694b0f2 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 14:45:28 +0200 Subject: Impose static lifetime on message and response --- src/endpoint.rs | 6 +++--- src/message.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index a402fec..bb768de 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -27,7 +27,7 @@ where /// use the unit type `()` as the handler type: /// it will panic if it is ever made to handle request. #[async_trait] -impl EndpointHandler for () { +impl EndpointHandler for () { async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } @@ -50,7 +50,7 @@ where impl StreamingEndpointHandler for T where T: EndpointHandler, - M: Message + 'static, + M: Message, { async fn handle(self: &Arc, mut m: Req, from: NodeID) -> Resp { // Immediately drop stream to avoid backpressure if a stream was sent @@ -177,7 +177,7 @@ where #[async_trait] impl GenericEndpoint for EndpointArc where - M: Message + 'static, + M: Message, H: StreamingEndpointHandler + 'static, { async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result { diff --git a/src/message.rs b/src/message.rs index 629992d..ff9861c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -42,8 +42,8 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static; } // ---- -- cgit v1.2.3 From a0dac87e3b8b749afa63b5707eefeb676e23b622 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 15:16:50 +0200 Subject: Add Req::new --- src/message.rs | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index ff9861c..7cf1918 100644 --- a/src/message.rs +++ b/src/message.rs @@ -60,6 +60,10 @@ pub struct Req { } impl Req { + pub fn new(v: M) -> Result { + Ok(v.into_req()?) + } + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { Self { stream: AttachedStream::Fixed(b), -- cgit v1.2.3 From a5e5fd040891c02b1f88bdafdec9e92090094548 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 15:23:45 +0200 Subject: Bump netapp version to 0.5 --- src/netapp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/netapp.rs b/src/netapp.rs index 29df3b9..f1e14ed 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -36,7 +36,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag -pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 +pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 #[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct HelloMessage { -- cgit v1.2.3 From ab80ade4f0034cbdcf15a99c674730f85eb06870 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 22 Jul 2022 16:42:26 +0200 Subject: Conversion between ByteStream and AsyncRead --- src/stream.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/stream.rs b/src/stream.rs index aa7641f..f5607b3 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -5,7 +5,8 @@ use std::task::{Context, Poll}; use bytes::Bytes; use futures::Future; -use futures::{Stream, StreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; +use tokio::io::AsyncRead; /// A stream of associated data. /// @@ -18,6 +19,8 @@ pub type ByteStream = Pin + Send + Sync>>; pub type Packet = Result; +// ---- + pub struct ByteStreamReader { stream: ByteStream, buf: VecDeque, @@ -175,3 +178,49 @@ impl<'a> Future for ByteStreamReadExact<'a> { } } } + +// ---- + +fn u8_to_io_error(v: u8) -> std::io::Error { + use std::io::{Error, ErrorKind}; + let kind = match v { + 101 => ErrorKind::ConnectionAborted, + 102 => ErrorKind::BrokenPipe, + 103 => ErrorKind::WouldBlock, + 104 => ErrorKind::InvalidInput, + 105 => ErrorKind::InvalidData, + 106 => ErrorKind::TimedOut, + 107 => ErrorKind::Interrupted, + 108 => ErrorKind::UnexpectedEof, + 109 => ErrorKind::OutOfMemory, + 110 => ErrorKind::ConnectionReset, + _ => ErrorKind::Other, + }; + Error::new(kind, "(in netapp stream)") +} + +fn io_error_to_u8(e: std::io::Error) -> u8 { + use std::io::{ErrorKind}; + match e.kind() { + ErrorKind::ConnectionAborted => 101, + ErrorKind::BrokenPipe => 102, + ErrorKind::WouldBlock => 103, + ErrorKind::InvalidInput => 104, + ErrorKind::InvalidData => 105, + ErrorKind::TimedOut => 106, + ErrorKind::Interrupted => 107, + ErrorKind::UnexpectedEof => 108, + ErrorKind::OutOfMemory => 109, + ErrorKind::ConnectionReset => 110, + _ => 100, + } +} + +pub fn asyncread_stream(reader: R) -> ByteStream { + Box::pin(tokio_util::io::ReaderStream::new(reader) + .map_err(io_error_to_u8)) +} + +pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { + tokio_util::io::StreamReader::new(stream.map_err(u8_to_io_error)) +} -- cgit v1.2.3 From fed0542313824df295a7e322a9aebe8ba62f97b9 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 25 Jul 2022 10:58:55 +0200 Subject: Remove blocking_send that crashes --- src/recv.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/recv.rs b/src/recv.rs index 2be8728..e748f18 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -35,7 +35,9 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - let _ = inner.blocking_send(Err(255)); + tokio::spawn(async move { + let _ = inner.send(Err(255)).await; + }); } } } -- cgit v1.2.3 From c17a5f84ff078826084c3a990f1890461c817346 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 25 Jul 2022 11:06:51 +0200 Subject: Remove broken test --- src/send.rs | 95 ------------------------------------------------------------- 1 file changed, 95 deletions(-) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index f1df6f7..46c4383 100644 --- a/src/send.rs +++ b/src/send.rs @@ -199,98 +199,3 @@ pub(crate) trait SendLoop: Sync { Ok(()) } } - -#[cfg(test)] -mod test { - use super::*; - - fn empty_data() -> DataReader { - type Item = Packet; - let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::()); - stream.into() - } - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: empty_data(), - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: empty_data(), - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: empty_data(), - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: empty_data(), - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} -- cgit v1.2.3 From 7499721a10d1d9e977024224b9d80d91ce93628b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 25 Jul 2022 11:07:23 +0200 Subject: Cargo fmt --- src/stream.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/stream.rs b/src/stream.rs index f5607b3..beb6b9c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -200,7 +200,7 @@ fn u8_to_io_error(v: u8) -> std::io::Error { } fn io_error_to_u8(e: std::io::Error) -> u8 { - use std::io::{ErrorKind}; + use std::io::ErrorKind; match e.kind() { ErrorKind::ConnectionAborted => 101, ErrorKind::BrokenPipe => 102, @@ -217,8 +217,7 @@ fn io_error_to_u8(e: std::io::Error) -> u8 { } pub fn asyncread_stream(reader: R) -> ByteStream { - Box::pin(tokio_util::io::ReaderStream::new(reader) - .map_err(io_error_to_u8)) + Box::pin(tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8)) } pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { -- cgit v1.2.3 From 74e57016f63b6052cf6d539812859c3a46138eee Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 25 Jul 2022 15:04:52 +0200 Subject: Add some debugging --- src/client.rs | 9 ++++----- src/recv.rs | 19 +++++++++++++++---- src/send.rs | 46 +++++++++++++++++++++++++++++++++++----------- src/server.rs | 7 ++++--- src/stream.rs | 8 +++++--- 5 files changed, 63 insertions(+), 26 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index d51236b..2fccdb8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -179,10 +179,9 @@ impl ClientConn { }))); } - trace!( - "request: query_send {} (serialized message: {} bytes)", - id, - req_msg_len + debug!( + "request: query_send {}, path {}, prio {} (serialized message: {} bytes)", + id, path, prio, req_msg_len ); #[cfg(feature = "telemetry")] @@ -201,7 +200,7 @@ impl ClientConn { } let resp_enc = RespEnc::decode(stream).await?; - trace!("request response {}", id); + debug!("client: got response to request {} (path {})", id, path); Resp::from_enc(resp_enc) } } diff --git a/src/recv.rs b/src/recv.rs index e748f18..cba42cb 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -59,7 +59,6 @@ pub(crate) trait RecvLoop: Sync + 'static { { let mut streams: HashMap = HashMap::new(); loop { - trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; match read.read_exact(&mut header_id[..]).await { Ok(_) => (), @@ -67,22 +66,31 @@ pub(crate) trait RecvLoop: Sync + 'static { Err(e) => return Err(e.into()), }; let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; read.read_exact(&mut header_size[..]).await?; let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; let packet = if is_error { + trace!( + "recv_loop: got id {}, header_size {:04x}, error {}", + id, + size, + size & !ERROR_MARKER + ); Err((size & !ERROR_MARKER) as u8) } else { let size = size & !CHUNK_HAS_CONTINUATION; let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); + trace!( + "recv_loop: got id {}, header_size {:04x}, {} bytes", + id, + size, + next_slice.len() + ); Ok(Bytes::from(next_slice)) }; @@ -90,6 +98,7 @@ pub(crate) trait RecvLoop: Sync + 'static { send } else { let (send, recv) = mpsc::channel(4); + trace!("recv_loop: id {} is new channel", id); self.recv_handler( id, Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)), @@ -102,8 +111,10 @@ pub(crate) trait RecvLoop: Sync + 'static { let _ = sender.send(packet).await; if has_cont { + assert!(!is_error); streams.insert(id, sender); } else { + trace!("recv_loop: close channel id {}", id); sender.end(); } } diff --git a/src/send.rs b/src/send.rs index 46c4383..256fe4c 100644 --- a/src/send.rs +++ b/src/send.rs @@ -74,36 +74,54 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { type Output = (RequestID, DataFrame); fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - for i in 0..self.queue.items.len() { - let (_prio, items_at_prio) = &mut self.queue.items[i]; - + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { let mut ready_item = None; for (j, item) in items_at_prio.iter_mut().enumerate() { let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); match Pin::new(&mut item_reader).poll(ctx) { Poll::Pending => (), Poll::Ready(ready_v) => { - ready_item = Some((j, ready_v, item.data.eos())); + ready_item = Some((j, ready_v)); break; } } } - if let Some((j, bytes_or_err, eos)) = ready_item { + if let Some((j, bytes_or_err)) = ready_item { + let item = items_at_prio.remove(j).unwrap(); + let id = item.id; + let eos = item.data.eos(); + let data_frame = match bytes_or_err { - Ok(bytes) => DataFrame::Data(bytes, !eos), + Ok(bytes) => { + trace!( + "send queue poll next ready: id {} eos {:?} bytes {}", + id, + eos, + bytes.len() + ); + DataFrame::Data(bytes, !eos) + } Err(e) => DataFrame::Error(match e { - ReadExactError::Stream(code) => code, + ReadExactError::Stream(code) => { + trace!( + "send queue poll next ready: id {} eos {:?} ERROR {}", + id, + eos, + code + ); + code + } _ => unreachable!(), }), }; - let item = items_at_prio.remove(j).unwrap(); - let id = item.id; - if !eos { + + if !eos && !matches!(data_frame, DataFrame::Error(_)) { items_at_prio.push_back(item); } else if items_at_prio.is_empty() { self.queue.items.remove(i); } + return Poll::Ready((id, data_frame)); } } @@ -173,6 +191,7 @@ pub(crate) trait SendLoop: Sync { match futures::future::select(recv_fut, send_fut).await { Either::Left((sth, _send_fut)) => { if let Some((id, prio, data)) = sth { + trace!("send_loop: add stream {} to send", id); sending.push(SendQueueItem { id, prio, @@ -183,7 +202,12 @@ pub(crate) trait SendLoop: Sync { }; } Either::Right(((id, data), _recv_fut)) => { - trace!("send_loop: sending bytes for {}", id); + trace!( + "send_loop: id {}, send {} bytes, header_size {}", + id, + data.data().len(), + hex::encode(data.header()) + ); let header_id = RequestID::to_be_bytes(id); write.write_all(&header_id[..]).await?; diff --git a/src/server.rs b/src/server.rs index 4b232af..57062d8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use log::{debug, trace}; +use log::*; use futures::io::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::{handshake_server, BoxStream}; @@ -175,7 +175,8 @@ impl RecvLoop for ServerConn { let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {}", id); + debug!("server: recv_handler got {}", id); + let (prio, resp_enc) = match ReqEnc::decode(stream).await { Ok(req_enc) => { let prio = req_enc.prio; @@ -192,7 +193,7 @@ impl RecvLoop for ServerConn { Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), }; - trace!("ServerConn sending response to {}: ", id); + debug!("server: sending response to {}", id); resp_send .send((id, prio, resp_enc.encode())) diff --git a/src/stream.rs b/src/stream.rs index beb6b9c..5ba2ed4 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -93,7 +93,7 @@ impl ByteStreamReader { } pub fn eos(&self) -> bool { - self.buf.is_empty() && self.eos + self.buf_len == 0 && self.eos } fn try_get(&mut self, read_len: usize) -> Option { @@ -164,8 +164,10 @@ impl<'a> Future for ByteStreamReadExact<'a> { match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { Some(Ok(slice)) => { - this.reader.buf_len += slice.len(); - this.reader.buf.push_back(slice); + if !slice.is_empty() { + this.reader.buf_len += slice.len(); + this.reader.buf.push_back(slice); + } } Some(Err(e)) => { this.reader.err = Some(e); -- cgit v1.2.3 From b55f61c38b01da01314d99ced543aba713dbd2a9 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 26 Jul 2022 12:11:48 +0200 Subject: Fix things going wrong when sending chan is closed --- src/recv.rs | 7 ++++++- src/send.rs | 36 +++++++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/recv.rs b/src/recv.rs index cba42cb..4d1047b 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; -use log::trace; +use log::*; use futures::AsyncReadExt; use tokio::sync::mpsc; @@ -59,6 +59,11 @@ pub(crate) trait RecvLoop: Sync + 'static { { let mut streams: HashMap = HashMap::new(); loop { + debug!( + "Receiving: {:?}", + streams.iter().map(|(id, _)| id).collect::>() + ); + let mut header_id = [0u8; RequestID::BITS as usize / 8]; match read.read_exact(&mut header_id[..]).await { Ok(_) => (), diff --git a/src/send.rs b/src/send.rs index 256fe4c..fd415c6 100644 --- a/src/send.rs +++ b/src/send.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use async_trait::async_trait; use bytes::Bytes; -use log::trace; +use log::*; use futures::AsyncWriteExt; use kuska_handshake::async_std::BoxStreamWrite; @@ -172,24 +172,38 @@ impl DataFrame { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, + msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where W: AsyncWriteExt + Unpin + Send + Sync, { let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - let recv_fut = msg_recv.recv(); - futures::pin_mut!(recv_fut); + let mut msg_recv = Some(msg_recv); + while msg_recv.is_some() || !sending.is_empty() { + debug!( + "Sending: {:?}", + sending + .items + .iter() + .map(|(_, i)| i.iter().map(|x| x.id)) + .flatten() + .collect::>() + ); + + let recv_fut = async { + if let Some(chan) = &mut msg_recv { + chan.recv().await + } else { + futures::future::pending().await + } + }; let send_fut = sending.next_ready(); // recv_fut is cancellation-safe according to tokio doc, // send_fut is cancellation-safe as implemented above? - use futures::future::Either; - match futures::future::select(recv_fut, send_fut).await { - Either::Left((sth, _send_fut)) => { + tokio::select! { + sth = recv_fut => { if let Some((id, prio, data)) = sth { trace!("send_loop: add stream {} to send", id); sending.push(SendQueueItem { @@ -198,10 +212,10 @@ pub(crate) trait SendLoop: Sync { data: ByteStreamReader::new(data), }); } else { - should_exit = true; + msg_recv = None; }; } - Either::Right(((id, data), _recv_fut)) => { + (id, data) = send_fut => { trace!( "send_loop: id {}, send {} bytes, header_size {}", id, -- cgit v1.2.3 From 2c9d595da03ae7a95e962cea78e68afff7410cc5 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 31 Aug 2022 22:19:40 +0200 Subject: Remove useless phantom and pub(crate) --- src/message.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index 7cf1918..61d01d0 100644 --- a/src/message.rs +++ b/src/message.rs @@ -53,7 +53,6 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static /// Req is cheaply clonable to allow the request to be sent to different /// peers (Clone will panic if the stream is a ByteStream). pub struct Req { - pub(crate) _phantom: PhantomData, pub(crate) msg: Arc, pub(crate) msg_ser: Option, pub(crate) stream: AttachedStream, @@ -104,7 +103,6 @@ impl Req { pub(crate) fn from_enc(enc: ReqEnc) -> Result { let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; Ok(Req { - _phantom: Default::default(), msg: Arc::new(msg), msg_ser: Some(enc.msg), stream: enc @@ -124,7 +122,6 @@ impl IntoReq for M { fn into_req(self) -> Result, rmp_serde::encode::Error> { let msg_ser = rmp_to_vec_all_named(&self)?; Ok(Req { - _phantom: Default::default(), msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), stream: AttachedStream::None, @@ -132,7 +129,6 @@ impl IntoReq for M { } fn into_req_local(self) -> Req { Req { - _phantom: Default::default(), msg: Arc::new(self), msg_ser: None, stream: AttachedStream::None, @@ -159,7 +155,6 @@ impl Clone for Req { } }; Self { - _phantom: Default::default(), msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), stream, @@ -331,7 +326,7 @@ impl ReqEnc { Self::decode_aux(stream).await.map_err(|_| Error::Framing) } - pub(crate) async fn decode_aux(stream: ByteStream) -> Result { + async fn decode_aux(stream: ByteStream) -> Result { let mut reader = ByteStreamReader::new(stream); let prio = reader.read_u8().await?; @@ -415,7 +410,7 @@ impl RespEnc { Self::decode_aux(stream).await.map_err(|_| Error::Framing) } - pub(crate) async fn decode_aux(stream: ByteStream) -> Result { + async fn decode_aux(stream: ByteStream) -> Result { let mut reader = ByteStreamReader::new(stream); let is_err = reader.read_u8().await?; -- cgit v1.2.3 From 3fd30c6e280fba41377c8b563352d756e8bc1caf Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 09:45:24 +0200 Subject: recv side: use unbounded channel to remove deadlock --- src/recv.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) (limited to 'src') diff --git a/src/recv.rs b/src/recv.rs index 4d1047b..3bea709 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -15,16 +15,16 @@ use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data struct Sender { - inner: Option>, + inner: Option>, } impl Sender { - fn new(inner: mpsc::Sender) -> Self { + fn new(inner: mpsc::UnboundedSender) -> Self { Sender { inner: Some(inner) } } - async fn send(&self, packet: Packet) { - let _ = self.inner.as_ref().unwrap().send(packet).await; + fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet); } fn end(&mut self) { @@ -35,9 +35,7 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - tokio::spawn(async move { - let _ = inner.send(Err(255)).await; - }); + let _ = inner.send(Err(255)); } } } @@ -102,18 +100,18 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { - let (send, recv) = mpsc::channel(4); + let (send, recv) = mpsc::unbounded_channel(); trace!("recv_loop: id {} is new channel", id); self.recv_handler( id, - Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)), + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)), ); Sender::new(send) }; // If we get an error, the receiving end is disconnected. // We still need to reach eos before dropping this sender - let _ = sender.send(packet).await; + let _ = sender.send(packet); if has_cont { assert!(!is_error); -- cgit v1.2.3 From 263db66fcee65deda39de18baa837228ea38baf1 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 10:29:26 +0200 Subject: Refactor: create a BytesBuf utility crate (will also be usefull in Garage) --- src/bytes_buf.rs | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/stream.rs | 52 +++-------------- 3 files changed, 177 insertions(+), 43 deletions(-) create mode 100644 src/bytes_buf.rs (limited to 'src') diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs new file mode 100644 index 0000000..46c7039 --- /dev/null +++ b/src/bytes_buf.rs @@ -0,0 +1,167 @@ +use std::collections::VecDeque; + +pub use bytes::Bytes; + +/// A circular buffer of bytes, internally represented as a list of Bytes +/// for optimization, but that for all intent and purposes acts just like +/// a big byte slice which can be extended on the right and from which +/// one can take on the left. +pub struct BytesBuf { + buf: VecDeque, + buf_len: usize, +} + +impl BytesBuf { + /// Creates a new empty BytesBuf + pub fn new() -> Self { + Self { + buf: VecDeque::new(), + buf_len: 0, + } + } + + /// Returns the number of bytes stored in the BytesBuf + #[inline] + pub fn len(&self) -> usize { + self.buf_len + } + + /// Returns true iff the BytesBuf contains zero bytes + #[inline] + pub fn is_empty(&self) -> bool { + self.buf_len == 0 + } + + /// Adds some bytes to the right of the buffer + pub fn extend(&mut self, b: Bytes) { + if !b.is_empty() { + self.buf_len += b.len(); + self.buf.push_back(b); + } + } + + /// Takes the whole content of the buffer and returns it as a single Bytes unit + pub fn take_all(&mut self) -> Bytes { + if self.buf.len() == 0 { + Bytes::new() + } else if self.buf.len() == 1 { + self.buf_len = 0; + self.buf.pop_back().unwrap() + } else { + let mut ret = Vec::with_capacity(self.buf_len); + for b in self.buf.iter() { + ret.extend(&b[..]); + } + self.buf.clear(); + self.buf_len = 0; + Bytes::from(ret) + } + } + + /// Takes at most max_len bytes from the left of the buffer + pub fn take_max(&mut self, max_len: usize) -> Bytes { + if self.buf_len <= max_len { + self.take_all() + } else { + self.take_exact_ok(max_len) + } + } + + /// Take exactly len bytes from the left of the buffer, returns None if + /// the BytesBuf doesn't contain enough data + pub fn take_exact(&mut self, len: usize) -> Option { + if self.buf_len < len { + None + } else { + Some(self.take_exact_ok(len)) + } + } + + fn take_exact_ok(&mut self, len: usize) -> Bytes { + assert!(len <= self.buf_len); + let front = self.buf.pop_front().unwrap(); + if front.len() > len { + self.buf.push_front(front.slice(len..)); + self.buf_len -= len; + front.slice(..len) + } else if front.len() == len { + self.buf_len -= len; + front + } else { + let mut ret = Vec::with_capacity(len); + ret.extend(&front[..]); + self.buf_len -= front.len(); + while ret.len() < len { + let front = self.buf.pop_front().unwrap(); + if front.len() > len - ret.len() { + let take = len - ret.len(); + ret.extend(front.slice(..take)); + self.buf.push_front(front.slice(take..)); + self.buf_len -= take; + break; + } else { + ret.extend(&front[..]); + self.buf_len -= front.len(); + } + } + Bytes::from(ret) + } + } + + /// Return the internal sequence of Bytes slices that make up the buffer + pub fn into_slices(self) -> VecDeque { + self.buf + } +} + +impl From for BytesBuf { + fn from(b: Bytes) -> BytesBuf { + let mut ret = BytesBuf::new(); + ret.extend(b); + ret + } +} + +impl From for Bytes { + fn from(mut b: BytesBuf) -> Bytes { + b.take_all() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_bytes_buf() { + let mut buf = BytesBuf::new(); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 13); + assert!(!buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!(buf.take_all(), Bytes::from(b"Hello, world!1234567890".to_vec())); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec())); + assert!(buf.len() == 11); + + assert_eq!(buf.take_exact(12), None); + assert!(buf.len() == 11); + assert_eq!(buf.take_exact(11), Some(Bytes::from(b"llo, world!".to_vec()))); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + } +} diff --git a/src/lib.rs b/src/lib.rs index bd41048..18091c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ pub mod error; pub mod stream; pub mod util; +pub mod bytes_buf; pub mod endpoint; pub mod message; diff --git a/src/stream.rs b/src/stream.rs index 5ba2ed4..cc664ce 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,4 +1,3 @@ -use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll}; @@ -8,6 +7,8 @@ use futures::Future; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::io::AsyncRead; +use crate::bytes_buf::BytesBuf; + /// A stream of associated data. /// /// When sent through Netapp, the Vec may be split in smaller chunk in such a way @@ -23,8 +24,7 @@ pub type Packet = Result; pub struct ByteStreamReader { stream: ByteStream, - buf: VecDeque, - buf_len: usize, + buf: BytesBuf, eos: bool, err: Option, } @@ -33,8 +33,7 @@ impl ByteStreamReader { pub fn new(stream: ByteStream) -> Self { ByteStreamReader { stream, - buf: VecDeque::with_capacity(8), - buf_len: 0, + buf: BytesBuf::new(), eos: false, err: None, } @@ -75,7 +74,7 @@ impl ByteStreamReader { } pub fn into_stream(self) -> ByteStream { - let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok)); + let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok)); if let Some(err) = self.err { Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) } else if self.eos { @@ -86,45 +85,15 @@ impl ByteStreamReader { } pub fn take_buffer(&mut self) -> Bytes { - let bytes = Bytes::from(self.buf.iter().map(|x| &x[..]).collect::>().concat()); - self.buf.clear(); - self.buf_len = 0; - bytes + self.buf.take_all() } pub fn eos(&self) -> bool { - self.buf_len == 0 && self.eos + self.buf.is_empty() && self.eos } fn try_get(&mut self, read_len: usize) -> Option { - if self.buf_len >= read_len { - let mut slices = Vec::with_capacity(self.buf.len()); - let mut taken = 0; - while taken < read_len { - let front = self.buf.pop_front().unwrap(); - if taken + front.len() <= read_len { - taken += front.len(); - self.buf_len -= front.len(); - slices.push(front); - } else { - let front_take = read_len - taken; - slices.push(front.slice(..front_take)); - self.buf.push_front(front.slice(front_take..)); - self.buf_len -= front_take; - break; - } - } - Some( - slices - .iter() - .map(|x| &x[..]) - .collect::>() - .concat() - .into(), - ) - } else { - None - } + self.buf.take_exact(read_len) } } @@ -164,10 +133,7 @@ impl<'a> Future for ByteStreamReadExact<'a> { match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { Some(Ok(slice)) => { - if !slice.is_empty() { - this.reader.buf_len += slice.len(); - this.reader.buf.push_back(slice); - } + this.reader.buf.extend(slice); } Some(Err(e)) => { this.reader.err = Some(e); -- cgit v1.2.3 From 7909a95d3c02a738c9a088c1cb8a5d6f70b06046 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 11:21:24 +0200 Subject: Stream errors are now std::io::Error --- src/client.rs | 2 +- src/recv.rs | 38 ++++++++++++++++++++----------------- src/send.rs | 61 +++++++++++++++++++++++++++++++---------------------------- src/stream.rs | 17 ++++++++++------- 4 files changed, 64 insertions(+), 54 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 2fccdb8..0dcbdf1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -175,7 +175,7 @@ impl ClientConn { "Too many inflight requests! RequestID collision. Interrupting previous request." ); let _ = old_ch.send(Box::pin(futures::stream::once(async move { - Err(Error::IdCollision.code()) + Err(std::io::Error::new(std::io::ErrorKind::Other, "RequestID collision, too many inflight requests")) }))); } diff --git a/src/recv.rs b/src/recv.rs index 3bea709..f8d68da 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -35,7 +35,7 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - let _ = inner.send(Err(255)); + let _ = inner.send(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Netapp connection dropped before end of stream"))); } } } @@ -76,25 +76,26 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; + let size = (size & CHUNK_LENGTH_MASK) as usize; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + let packet = if is_error { - trace!( - "recv_loop: got id {}, header_size {:04x}, error {}", - id, - size, - size & !ERROR_MARKER - ); - Err((size & !ERROR_MARKER) as u8) + let msg = String::from_utf8(next_slice).unwrap_or("".into()); + debug!("recv_loop: got id {}, error: {}", id, msg); + Some(Err(std::io::Error::new(std::io::ErrorKind::Other, msg))) } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; trace!( - "recv_loop: got id {}, header_size {:04x}, {} bytes", + "recv_loop: got id {}, size {}, has_cont {}", id, size, - next_slice.len() + has_cont ); - Ok(Bytes::from(next_slice)) + if !next_slice.is_empty() { + Some(Ok(Bytes::from(next_slice))) + } else { + None + } }; let mut sender = if let Some(send) = streams.remove(&(id)) { @@ -109,9 +110,12 @@ pub(crate) trait RecvLoop: Sync + 'static { Sender::new(send) }; - // If we get an error, the receiving end is disconnected. - // We still need to reach eos before dropping this sender - let _ = sender.send(packet); + if let Some(packet) = packet { + // If we cannot put packet in channel, it means that the + // receiving end of the channel is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet); + } if has_cont { assert!(!is_error); diff --git a/src/send.rs b/src/send.rs index fd415c6..f362962 100644 --- a/src/send.rs +++ b/src/send.rs @@ -18,9 +18,11 @@ use crate::stream::*; // Messages are sent by chunks // Chunk format: // - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data +// - u16 BE: chunk length + flags: +// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream +// ERROR_MARKER if this chunk denotes an error +// (these two flags are exclusive, an error denotes the end of the stream) +// - [u8; chunk_length] chunk data / error message pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; @@ -28,6 +30,7 @@ pub(crate) type ChunkLength = u16; pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; struct SendQueue { items: Vec<(u8, VecDeque)>, @@ -92,29 +95,12 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { let id = item.id; let eos = item.data.eos(); - let data_frame = match bytes_or_err { - Ok(bytes) => { - trace!( - "send queue poll next ready: id {} eos {:?} bytes {}", - id, - eos, - bytes.len() - ); - DataFrame::Data(bytes, !eos) - } - Err(e) => DataFrame::Error(match e { - ReadExactError::Stream(code) => { - trace!( - "send queue poll next ready: id {} eos {:?} ERROR {}", - id, - eos, - code - ); - code - } - _ => unreachable!(), - }), - }; + let packet = bytes_or_err.map_err(|e| match e { + ReadExactError::Stream(err) => err, + _ => unreachable!(), + }); + + let data_frame = DataFrame::from_packet(packet, !eos); if !eos && !matches!(data_frame, DataFrame::Error(_)) { items_at_prio.push_back(item); @@ -139,15 +125,32 @@ enum DataFrame { /// (albeit sub-optimal) to set it to true if there is nothing coming after Data(Bytes, bool), /// An error code automatically signals the end of the stream - Error(u8), + Error(Bytes), } impl DataFrame { + fn from_packet(p: Packet, has_cont: bool) -> Self { + match p { + Ok(bytes) => { + assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize); + Self::Data(bytes, has_cont) + } + Err(e) => { + let msg = format!("{}", e); + let mut msg = Bytes::from(msg.into_bytes()); + if msg.len() > MAX_CHUNK_LENGTH as usize { + msg = msg.slice(..MAX_CHUNK_LENGTH as usize); + } + Self::Error(msg) + } + } + } + fn header(&self) -> [u8; 2] { let header_u16 = match self { DataFrame::Data(data, false) => data.len() as u16, DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, - DataFrame::Error(e) => *e as u16 | ERROR_MARKER, + DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER, }; ChunkLength::to_be_bytes(header_u16) } @@ -155,7 +158,7 @@ impl DataFrame { fn data(&self) -> &[u8] { match self { DataFrame::Data(ref data, _) => &data[..], - DataFrame::Error(_) => &[], + DataFrame::Error(ref msg) => &msg[..], } } } diff --git a/src/stream.rs b/src/stream.rs index cc664ce..3518246 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,7 +4,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; use futures::Future; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt}; use tokio::io::AsyncRead; use crate::bytes_buf::BytesBuf; @@ -18,7 +18,7 @@ use crate::bytes_buf::BytesBuf; /// meaning, it's up to your application to define their semantic. pub type ByteStream = Pin + Send + Sync>>; -pub type Packet = Result; +pub type Packet = Result; // ---- @@ -26,7 +26,7 @@ pub struct ByteStreamReader { stream: ByteStream, buf: BytesBuf, eos: bool, - err: Option, + err: Option, } impl ByteStreamReader { @@ -99,7 +99,7 @@ impl ByteStreamReader { pub enum ReadExactError { UnexpectedEos, - Stream(u8), + Stream(std::io::Error), } #[pin_project::pin_project] @@ -120,7 +120,8 @@ impl<'a> Future for ByteStreamReadExact<'a> { if let Some(bytes) = this.reader.try_get(*this.read_len) { return Poll::Ready(Ok(bytes)); } - if let Some(err) = this.reader.err { + if let Some(err) = &this.reader.err { + let err = std::io::Error::new(err.kind(), format!("{}", err)); return Poll::Ready(Err(ReadExactError::Stream(err))); } if this.reader.eos { @@ -149,6 +150,7 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- +/* fn u8_to_io_error(v: u8) -> std::io::Error { use std::io::{Error, ErrorKind}; let kind = match v { @@ -183,11 +185,12 @@ fn io_error_to_u8(e: std::io::Error) -> u8 { _ => 100, } } +*/ pub fn asyncread_stream(reader: R) -> ByteStream { - Box::pin(tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8)) + Box::pin(tokio_util::io::ReaderStream::new(reader)) } pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { - tokio_util::io::StreamReader::new(stream.map_err(u8_to_io_error)) + tokio_util::io::StreamReader::new(stream) } -- cgit v1.2.3 From 745c78618479c4177647e4d7fed97d5fd2d00d4f Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 11:34:53 +0200 Subject: Also encode errorkind in stream --- src/error.rs | 36 ++++++++++++++++++++++++++++++++++++ src/recv.rs | 7 ++++--- src/send.rs | 23 ++++++++++++++++------- src/stream.rs | 36 ------------------------------------ 4 files changed, 56 insertions(+), 46 deletions(-) (limited to 'src') diff --git a/src/error.rs b/src/error.rs index f374341..2fa4594 100644 --- a/src/error.rs +++ b/src/error.rs @@ -109,3 +109,39 @@ where } } } + +// ---- Helpers for serializing I/O Errors + +pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind { + use std::io::ErrorKind; + match v { + 101 => ErrorKind::ConnectionAborted, + 102 => ErrorKind::BrokenPipe, + 103 => ErrorKind::WouldBlock, + 104 => ErrorKind::InvalidInput, + 105 => ErrorKind::InvalidData, + 106 => ErrorKind::TimedOut, + 107 => ErrorKind::Interrupted, + 108 => ErrorKind::UnexpectedEof, + 109 => ErrorKind::OutOfMemory, + 110 => ErrorKind::ConnectionReset, + _ => ErrorKind::Other, + } +} + +pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 { + use std::io::ErrorKind; + match kind { + ErrorKind::ConnectionAborted => 101, + ErrorKind::BrokenPipe => 102, + ErrorKind::WouldBlock => 103, + ErrorKind::InvalidInput => 104, + ErrorKind::InvalidData => 105, + ErrorKind::TimedOut => 106, + ErrorKind::Interrupted => 107, + ErrorKind::UnexpectedEof => 108, + ErrorKind::OutOfMemory => 109, + ErrorKind::ConnectionReset => 110, + _ => 100, + } +} diff --git a/src/recv.rs b/src/recv.rs index f8d68da..f8606f3 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -81,9 +81,10 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; let packet = if is_error { - let msg = String::from_utf8(next_slice).unwrap_or("".into()); - debug!("recv_loop: got id {}, error: {}", id, msg); - Some(Err(std::io::Error::new(std::io::ErrorKind::Other, msg))) + let kind = u8_to_io_errorkind(next_slice[0]); + let msg = std::str::from_utf8(&next_slice[1..]).unwrap_or(""); + debug!("recv_loop: got id {}, error {:?}: {}", id, kind, msg); + Some(Err(std::io::Error::new(kind, msg.to_string()))) } else { trace!( "recv_loop: got id {}, size {}, has_cont {}", diff --git a/src/send.rs b/src/send.rs index f362962..287fe40 100644 --- a/src/send.rs +++ b/src/send.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use bytes::Bytes; +use bytes::{Bytes, BytesMut, BufMut}; use log::*; use futures::AsyncWriteExt; @@ -22,7 +22,11 @@ use crate::stream::*; // CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream // ERROR_MARKER if this chunk denotes an error // (these two flags are exclusive, an error denotes the end of the stream) -// - [u8; chunk_length] chunk data / error message +// - [u8; chunk_length], either +// - if not error: chunk data +// - if error: +// - u8: error kind, encoded using error::io_errorkind_to_u8 +// - rest: error message pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; @@ -136,12 +140,17 @@ impl DataFrame { Self::Data(bytes, has_cont) } Err(e) => { - let msg = format!("{}", e); - let mut msg = Bytes::from(msg.into_bytes()); - if msg.len() > MAX_CHUNK_LENGTH as usize { - msg = msg.slice(..MAX_CHUNK_LENGTH as usize); + let mut buf = BytesMut::new(); + buf.put_u8(io_errorkind_to_u8(e.kind())); + + let msg = format!("{}", e).into_bytes(); + if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize { + buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]); + } else { + buf.put(&msg[..]); } - Self::Error(msg) + + Self::Error(buf.freeze()) } } } diff --git a/src/stream.rs b/src/stream.rs index 3518246..6e00e5f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -150,42 +150,6 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- -/* -fn u8_to_io_error(v: u8) -> std::io::Error { - use std::io::{Error, ErrorKind}; - let kind = match v { - 101 => ErrorKind::ConnectionAborted, - 102 => ErrorKind::BrokenPipe, - 103 => ErrorKind::WouldBlock, - 104 => ErrorKind::InvalidInput, - 105 => ErrorKind::InvalidData, - 106 => ErrorKind::TimedOut, - 107 => ErrorKind::Interrupted, - 108 => ErrorKind::UnexpectedEof, - 109 => ErrorKind::OutOfMemory, - 110 => ErrorKind::ConnectionReset, - _ => ErrorKind::Other, - }; - Error::new(kind, "(in netapp stream)") -} - -fn io_error_to_u8(e: std::io::Error) -> u8 { - use std::io::ErrorKind; - match e.kind() { - ErrorKind::ConnectionAborted => 101, - ErrorKind::BrokenPipe => 102, - ErrorKind::WouldBlock => 103, - ErrorKind::InvalidInput => 104, - ErrorKind::InvalidData => 105, - ErrorKind::TimedOut => 106, - ErrorKind::Interrupted => 107, - ErrorKind::UnexpectedEof => 108, - ErrorKind::OutOfMemory => 109, - ErrorKind::ConnectionReset => 110, - _ => 100, - } -} -*/ pub fn asyncread_stream(reader: R) -> ByteStream { Box::pin(tokio_util::io::ReaderStream::new(reader)) -- cgit v1.2.3 From cd203f5708907c2bf172a3c5b7c5b40e2557b2f4 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 12:15:50 +0200 Subject: Add OrderTag to Req and Resp, refactor errors --- src/bytes_buf.rs | 10 +++- src/client.rs | 9 ++- src/error.rs | 27 +-------- src/lib.rs | 2 +- src/message.rs | 167 +++++++++++++++++++++++++++++++++---------------------- src/recv.rs | 8 ++- src/send.rs | 6 +- src/server.rs | 22 ++------ src/stream.rs | 1 - 9 files changed, 136 insertions(+), 116 deletions(-) (limited to 'src') diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs index 46c7039..857be9d 100644 --- a/src/bytes_buf.rs +++ b/src/bytes_buf.rs @@ -146,7 +146,10 @@ mod test { assert!(buf.len() == 23); assert!(!buf.is_empty()); - assert_eq!(buf.take_all(), Bytes::from(b"Hello, world!1234567890".to_vec())); + assert_eq!( + buf.take_all(), + Bytes::from(b"Hello, world!1234567890".to_vec()) + ); assert!(buf.len() == 0); assert!(buf.is_empty()); @@ -160,7 +163,10 @@ mod test { assert_eq!(buf.take_exact(12), None); assert!(buf.len() == 11); - assert_eq!(buf.take_exact(11), Some(Bytes::from(b"llo, world!".to_vec()))); + assert_eq!( + buf.take_exact(11), + Some(Bytes::from(b"llo, world!".to_vec())) + ); assert!(buf.len() == 0); assert!(buf.is_empty()); } diff --git a/src/client.rs b/src/client.rs index 0dcbdf1..aef7bbb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -35,7 +35,7 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>, @@ -165,7 +165,7 @@ impl ClientConn { // Encode request let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); let req_msg_len = req_enc.msg.len(); - let req_stream = req_enc.encode(); + let (req_stream, req_order) = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -175,7 +175,10 @@ impl ClientConn { "Too many inflight requests! RequestID collision. Interrupting previous request." ); let _ = old_ch.send(Box::pin(futures::stream::once(async move { - Err(std::io::Error::new(std::io::ErrorKind::Other, "RequestID collision, too many inflight requests")) + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) }))); } diff --git a/src/error.rs b/src/error.rs index 2fa4594..c0aeeac 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(display = "Framing protocol error")] Framing, + #[error(display = "Remote error ({:?}): {}", _0, _1)] + Remote(io::ErrorKind, String), + #[error(display = "Request ID collision")] IdCollision, @@ -42,30 +45,6 @@ pub enum Error { #[error(display = "Version mismatch: {}", _0)] VersionMismatch(String), - - #[error(display = "Remote error {}: {}", _0, _1)] - Remote(u8, String), -} - -impl Error { - pub fn code(&self) -> u8 { - match self { - Self::Io(_) => 100, - Self::TokioJoin(_) => 110, - Self::OneshotRecv(_) => 111, - Self::RMPEncode(_) => 10, - Self::RMPDecode(_) => 11, - Self::UTF8(_) => 12, - Self::Framing => 13, - Self::NoHandler => 20, - Self::ConnectionClosed => 21, - Self::IdCollision => 22, - Self::Handshake(_) => 30, - Self::VersionMismatch(_) => 31, - Self::Remote(c, _) => *c, - Self::Message(_) => 99, - } - } } impl From> for Error { diff --git a/src/lib.rs b/src/lib.rs index 18091c8..8e30e40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,10 +13,10 @@ //! about message priorization. //! Also check out the examples to learn how to use this crate. +pub mod bytes_buf; pub mod error; pub mod stream; pub mod util; -pub mod bytes_buf; pub mod endpoint; pub mod message; diff --git a/src/message.rs b/src/message.rs index 61d01d0..ca68cac 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; use serde::{Deserialize, Serialize}; use futures::stream::StreamExt; @@ -40,6 +41,24 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; // ---- +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(u64, u64); + +impl OrderTag { + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + /// This trait should be implemented by all messages your application /// wants to handle pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { @@ -56,6 +75,7 @@ pub struct Req { pub(crate) msg: Arc, pub(crate) msg_ser: Option, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, } impl Req { @@ -77,6 +97,13 @@ impl Req { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M { &self.msg } @@ -97,6 +124,7 @@ impl Req { telemetry_id, msg: self.msg_ser.unwrap(), stream: self.stream.into_stream(), + order_tag: self.order_tag, } } @@ -109,6 +137,7 @@ impl Req { .stream .map(AttachedStream::Stream) .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, }) } } @@ -125,6 +154,7 @@ impl IntoReq for M { msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), stream: AttachedStream::None, + order_tag: None, }) } fn into_req_local(self) -> Req { @@ -132,6 +162,7 @@ impl IntoReq for M { msg: Arc::new(self), msg_ser: None, stream: AttachedStream::None, + order_tag: None, } } } @@ -158,6 +189,7 @@ impl Clone for Req { msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), stream, + order_tag: self.order_tag, } } } @@ -184,6 +216,7 @@ pub struct Resp { pub(crate) _phantom: PhantomData, pub(crate) msg: M::Response, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, } impl Resp { @@ -192,6 +225,7 @@ impl Resp { _phantom: Default::default(), msg: v, stream: AttachedStream::None, + order_tag: None, } } @@ -209,6 +243,13 @@ impl Resp { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M::Response { &self.msg } @@ -222,26 +263,24 @@ impl Resp { } pub(crate) fn into_enc(self) -> Result { - Ok(RespEnc::Success { + Ok(RespEnc { msg: rmp_to_vec_all_named(&self.msg)?.into(), stream: self.stream.into_stream(), + order_tag: self.order_tag, }) } pub(crate) fn from_enc(enc: RespEnc) -> Result { - match enc { - RespEnc::Success { msg, stream } => { - let msg = rmp_serde::decode::from_read_ref(&msg)?; - Ok(Self { - _phantom: Default::default(), - msg, - stream: stream - .map(AttachedStream::Stream) - .unwrap_or(AttachedStream::None), - }) - } - RespEnc::Error { code, message } => Err(Error::Remote(code, message)), - } + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) } } @@ -295,10 +334,11 @@ pub(crate) struct ReqEnc { pub(crate) telemetry_id: Bytes, pub(crate) msg: Bytes, pub(crate) stream: Option, + pub(crate) order_tag: Option, } impl ReqEnc { - pub(crate) fn encode(self) -> ByteStream { + pub(crate) fn encode(self) -> (ByteStream, Option) { let mut buf = BytesMut::with_capacity( self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, ); @@ -315,15 +355,18 @@ impl ReqEnc { let header = buf.freeze(); - if let Some(stream) = self.stream { + let res_stream: ByteStream = if let Some(stream) = self.stream { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) - } + }; + (res_stream, self.order_tag) } pub(crate) async fn decode(stream: ByteStream) -> Result { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result { @@ -346,6 +389,7 @@ impl ReqEnc { telemetry_id, msg, stream: Some(reader.into_stream()), + order_tag: None, }) } } @@ -360,74 +404,67 @@ impl ReqEnc { /// - message length + 1: u8 /// - error code: u8 /// - message: [u8; message_length] -pub(crate) enum RespEnc { - Error { - code: u8, - message: String, - }, - Success { - msg: Bytes, - stream: Option, - }, +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option, + order_tag: Option, } impl RespEnc { - pub(crate) fn from_err(e: Error) -> Self { - RespEnc::Error { - code: e.code(), - message: format!("{}", e), - } - } - - pub(crate) fn encode(self) -> ByteStream { - match self { - RespEnc::Success { msg, stream } => { - let mut buf = BytesMut::with_capacity(msg.len() + 8); - - buf.put_u8(0); + pub(crate) fn encode(resp: Result) -> (ByteStream, Option) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); buf.put_u32(msg.len() as u32); - let header = buf.freeze(); - if let Some(stream) = stream { + let res_stream: ByteStream = if let Some(stream) = stream { Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) - } + }; + (res_stream, order_tag) } - RespEnc::Error { code, message } => { - let mut buf = BytesMut::with_capacity(message.len() + 8); - buf.put_u8(1 + message.len() as u8); - buf.put_u8(code); - buf.put(message.as_bytes()); - let header = buf.freeze(); - Box::pin(futures::stream::once(async move { Ok(header) })) + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) } } } pub(crate) async fn decode(stream: ByteStream) -> Result { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result { let mut reader = ByteStreamReader::new(stream); - let is_err = reader.read_u8().await?; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - if is_err > 0 { - let code = reader.read_u8().await?; - let message = reader.read_exact(is_err as usize - 1).await?; - let message = String::from_utf8(message.to_vec()).unwrap_or_default(); - Ok(RespEnc::Error { code, message }) - } else { - let msg_len = reader.read_u32().await?; - let msg = reader.read_exact(msg_len as usize).await?; + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} - Ok(RespEnc::Success { - msg, - stream: Some(reader.into_stream()), - }) - } +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, } } diff --git a/src/recv.rs b/src/recv.rs index f8606f3..b5289fb 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -35,7 +35,10 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - let _ = inner.send(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Netapp connection dropped before end of stream"))); + let _ = inner.send(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Netapp connection dropped before end of stream", + ))); } } } @@ -82,7 +85,8 @@ pub(crate) trait RecvLoop: Sync + 'static { let packet = if is_error { let kind = u8_to_io_errorkind(next_slice[0]); - let msg = std::str::from_utf8(&next_slice[1..]).unwrap_or(""); + let msg = + std::str::from_utf8(&next_slice[1..]).unwrap_or(""); debug!("recv_loop: got id {}, error {:?}: {}", id, kind, msg); Some(Err(std::io::Error::new(kind, msg.to_string()))) } else { diff --git a/src/send.rs b/src/send.rs index 287fe40..c40787f 100644 --- a/src/send.rs +++ b/src/send.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use bytes::{Bytes, BytesMut, BufMut}; +use bytes::{BufMut, Bytes, BytesMut}; use log::*; use futures::AsyncWriteExt; @@ -36,6 +36,8 @@ pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; +pub(crate) type SendStream = (RequestID, RequestPriority, ByteStream); + struct SendQueue { items: Vec<(u8, VecDeque)>, } @@ -184,7 +186,7 @@ impl DataFrame { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, + msg_recv: mpsc::UnboundedReceiver, mut write: BoxStreamWrite, ) -> Result<(), Error> where diff --git a/src/server.rs b/src/server.rs index 57062d8..c23c9e4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -53,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -177,26 +177,16 @@ impl RecvLoop for ServerConn { tokio::spawn(async move { debug!("server: recv_handler got {}", id); - let (prio, resp_enc) = match ReqEnc::decode(stream).await { - Ok(req_enc) => { - let prio = req_enc.prio; - let resp = self2.recv_handler_aux(req_enc).await; - - ( - prio, - match resp { - Ok(resp_enc) => resp_enc, - Err(e) => RespEnc::from_err(e), - }, - ) - } - Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), }; debug!("server: sending response to {}", id); + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_enc.encode())) + .send((id, prio, resp_stream)) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); diff --git a/src/stream.rs b/src/stream.rs index 6e00e5f..efa0ebc 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -150,7 +150,6 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- - pub fn asyncread_stream(reader: R) -> ByteStream { Box::pin(tokio_util::io::ReaderStream::new(reader)) } -- cgit v1.2.3 From 4a59b73d7bfd0f136f654e874afb5d2a9bf4df2e Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 12:46:33 +0200 Subject: Add actual support for order tag --- src/client.rs | 2 +- src/message.rs | 2 +- src/send.rs | 112 +++++++++++++++++++++++++++++++++++++++------------------ src/server.rs | 2 +- 4 files changed, 81 insertions(+), 37 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index aef7bbb..df54810 100644 --- a/src/client.rs +++ b/src/client.rs @@ -190,7 +190,7 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, req_stream))?; + query_send.send((id, prio, req_order, req_stream))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { diff --git a/src/message.rs b/src/message.rs index ca68cac..1834f28 100644 --- a/src/message.rs +++ b/src/message.rs @@ -44,7 +44,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; #[derive(Clone, Copy)] pub struct OrderTagStream(u64); #[derive(Clone, Copy, Serialize, Deserialize, Debug)] -pub struct OrderTag(u64, u64); +pub struct OrderTag(pub(crate) u64, pub(crate) u64); impl OrderTag { pub fn stream() -> OrderTagStream { diff --git a/src/send.rs b/src/send.rs index c40787f..ea6cf9f 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -7,7 +7,7 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use log::*; -use futures::AsyncWriteExt; +use futures::{AsyncWriteExt, Future}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -36,15 +36,21 @@ pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; -pub(crate) type SendStream = (RequestID, RequestPriority, ByteStream); +pub(crate) type SendStream = (RequestID, RequestPriority, Option, ByteStream); struct SendQueue { - items: Vec<(u8, VecDeque)>, + items: Vec<(u8, SendQueuePriority)>, +} + +struct SendQueuePriority { + items: VecDeque, + order: HashMap>, } struct SendQueueItem { id: RequestID, prio: RequestPriority, + order_tag: Option, data: ByteStreamReader, } @@ -59,11 +65,11 @@ impl SendQueue { let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { Ok(i) => i, Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); + self.items.insert(i, (prio, SendQueuePriority::new())); i } }; - self.items[pos_prio].1.push_back(item); + self.items[pos_prio].1.push(item); } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) @@ -75,29 +81,34 @@ impl SendQueue { } } -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataFrame); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { - let mut ready_item = None; - for (j, item) in items_at_prio.iter_mut().enumerate() { - let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); - match Pin::new(&mut item_reader).poll(ctx) { - Poll::Pending => (), - Poll::Ready(ready_v) => { - ready_item = Some((j, ready_v)); - break; - } +impl SendQueuePriority { + fn new() -> Self { + Self { + items: VecDeque::new(), + order: HashMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.entry(stream).or_default(); + let i = order_vec.iter().take_while(|o2| **o2 < order).count(); + order_vec.insert(i, order); + } + self.items.push_back(item); + } + fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> { + for (j, item) in self.items.iter_mut().enumerate() { + if let Some(OrderTag(stream, order)) = item.order_tag { + if order > *self.order.get(&stream).unwrap().front().unwrap() { + continue; } } - if let Some((j, bytes_or_err)) = ready_item { - let item = items_at_prio.remove(j).unwrap(); + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) { let id = item.id; let eos = item.data.eos(); @@ -106,15 +117,47 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { _ => unreachable!(), }); + if eos || packet.is_err() { + if let Some(OrderTag(stream, order)) = item.order_tag { + assert_eq!( + self.order.get_mut(&stream).unwrap().pop_front(), + Some(order) + ) + } + self.items.remove(j); + } + let data_frame = DataFrame::from_packet(packet, !eos); - if !eos && !matches!(data_frame, DataFrame::Error(_)) { - items_at_prio.push_back(item); - } else if items_at_prio.is_empty() { + return Poll::Ready((id, data_frame)); + } + } + + Poll::Pending + } + fn dump(&self, prio: u8) -> String { + self.items + .iter() + .map(|i| format!("[{} {} {:?}]", prio, i.id, i.order_tag)) + .collect::>() + .join(" ") + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataFrame); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { + if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) { + if items_at_prio.is_empty() { self.queue.items.remove(i); } - - return Poll::Ready((id, data_frame)); + return Poll::Ready(res); } } // If the queue is empty, this futures is eternally pending. @@ -200,8 +243,7 @@ pub(crate) trait SendLoop: Sync { sending .items .iter() - .map(|(_, i)| i.iter().map(|x| x.id)) - .flatten() + .map(|(prio, i)| i.dump(*prio)) .collect::>() ); @@ -217,12 +259,14 @@ pub(crate) trait SendLoop: Sync { // recv_fut is cancellation-safe according to tokio doc, // send_fut is cancellation-safe as implemented above? tokio::select! { + biased; // always read incomming channel first if it has data sth = recv_fut => { - if let Some((id, prio, data)) = sth { + if let Some((id, prio, order_tag, data)) = sth { trace!("send_loop: add stream {} to send", id); sending.push(SendQueueItem { id, prio, + order_tag, data: ByteStreamReader::new(data), }); } else { diff --git a/src/server.rs b/src/server.rs index c23c9e4..f8c3f98 100644 --- a/src/server.rs +++ b/src/server.rs @@ -186,7 +186,7 @@ impl RecvLoop for ServerConn { let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_stream)) + .send((id, prio, resp_order, resp_stream)) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); -- cgit v1.2.3 From 32925667385db9e1d9e56ebae67d03d8096f7c46 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 14:43:27 +0200 Subject: fix trace message --- src/send.rs | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index 3b01cb5..d927d98 100644 --- a/src/send.rs +++ b/src/send.rs @@ -247,6 +247,7 @@ pub(crate) trait SendLoop: Sync { .iter() .map(|(prio, i)| i.dump(*prio)) .collect::>() + .join(" ; ") ); let recv_fut = async { -- cgit v1.2.3 From 522f420e2bf30d5ef6f50dccb88adf86882ac7c6 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 15:54:11 +0200 Subject: Implement request cancellation --- src/client.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- src/recv.rs | 18 +++++++++++++++- src/send.rs | 57 ++++++++++++++++++++++++++++++++++++++----------- src/server.rs | 37 ++++++++++++++++++++++++++------ 4 files changed, 159 insertions(+), 21 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 9726125..d82c91e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use std::task::Poll; use arc_swap::ArcSwapOption; use async_trait::async_trait; @@ -9,6 +11,7 @@ use bytes::Bytes; use log::{debug, error, trace}; use futures::io::AsyncReadExt; +use futures::Stream; use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; @@ -35,7 +38,7 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>, @@ -193,7 +196,9 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, req_order, req_stream))?; + query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; + + let canceller = CancelOnDrop::new(id, query_send.as_ref().clone()); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -205,6 +210,8 @@ impl ClientConn { } } + let stream = Box::pin(canceller.for_stream(stream)); + let resp_enc = RespEnc::decode(stream).await?; debug!("client: got response to request {} (path {})", id, path); Resp::from_enc(resp_enc) @@ -223,6 +230,63 @@ impl RecvLoop for ClientConn { if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } + } else { + debug!("Got unexpected response to request {}, dropping it", id); + } + } +} + +// ---- + +struct CancelOnDrop { + id: RequestID, + query_send: mpsc::UnboundedSender, +} + +impl CancelOnDrop { + fn new(id: RequestID, query_send: mpsc::UnboundedSender) -> Self { + Self { id, query_send } + } + fn for_stream(self, stream: ByteStream) -> CancelOnDropStream { + CancelOnDropStream { + cancel: Some(self), + stream: stream, + } + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + trace!("cancelling request {}", self.id); + let _ = self.query_send.send(SendItem::Cancel(self.id)); + } +} + +#[pin_project::pin_project] +struct CancelOnDropStream { + cancel: Option, + #[pin] + stream: ByteStream, +} + +impl Stream for CancelOnDropStream { + type Item = Packet; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, Poll::Ready(None)) { + if let Some(c) = this.cancel.take() { + std::mem::forget(c) + } } + res + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() } } diff --git a/src/recv.rs b/src/recv.rs index ac93c4b..8909190 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -53,6 +53,7 @@ impl Drop for Sender { #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream); + fn cancel_handler(self: &Arc, _id: RequestID) {} async fn recv_loop(self: Arc, mut read: R, debug_name: String) -> Result<(), Error> where @@ -78,6 +79,18 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut header_size[..]).await?; let size = ChunkLength::from_be_bytes(header_size); + if size == CANCEL_REQUEST { + if let Some(mut stream) = streams.remove(&id) { + let _ = stream.send(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "netapp: cancel requested", + ))); + stream.end(); + } + self.cancel_handler(id); + continue; + } + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; let size = (size & CHUNK_LENGTH_MASK) as usize; @@ -88,7 +101,10 @@ pub(crate) trait RecvLoop: Sync + 'static { let kind = u8_to_io_errorkind(next_slice[0]); let msg = std::str::from_utf8(&next_slice[1..]).unwrap_or(""); - debug!("recv_loop({}): got id {}, error {:?}: {}", debug_name, id, kind, msg); + debug!( + "recv_loop({}): got id {}, error {:?}: {}", + debug_name, id, kind, msg + ); Some(Err(std::io::Error::new(kind, msg.to_string()))) } else { trace!( diff --git a/src/send.rs b/src/send.rs index d927d98..780bbcf 100644 --- a/src/send.rs +++ b/src/send.rs @@ -22,6 +22,7 @@ use crate::stream::*; // CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream // ERROR_MARKER if this chunk denotes an error // (these two flags are exclusive, an error denotes the end of the stream) +// **special value** 0xFFFF indicates a CANCEL message // - [u8; chunk_length], either // - if not error: chunk data // - if error: @@ -35,8 +36,14 @@ pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; +pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF; -pub(crate) type SendStream = (RequestID, RequestPriority, Option, ByteStream); +pub(crate) enum SendItem { + Stream(RequestID, RequestPriority, Option, ByteStream), + Cancel(RequestID), +} + +// ---- struct SendQueue { items: Vec<(u8, SendQueuePriority)>, @@ -71,6 +78,11 @@ impl SendQueue { }; self.items[pos_prio].1.push(item); } + fn remove(&mut self, id: RequestID) { + for (_, prioq) in self.items.iter_mut() { + prioq.remove(id); + } + } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } @@ -96,6 +108,16 @@ impl SendQueuePriority { } self.items.push_back(item); } + fn remove(&mut self, id: RequestID) { + if let Some(i) = self.items.iter().position(|x| x.id == id) { + let item = self.items.remove(i).unwrap(); + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.get_mut(&stream).unwrap(); + let j = order_vec.iter().position(|x| *x == order).unwrap(); + order_vec.remove(j).unwrap(); + } + } + } fn is_empty(&self) -> bool { self.items.is_empty() } @@ -229,7 +251,7 @@ impl DataFrame { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - msg_recv: mpsc::UnboundedReceiver, + msg_recv: mpsc::UnboundedReceiver, mut write: BoxStreamWrite, debug_name: String, ) -> Result<(), Error> @@ -264,16 +286,27 @@ pub(crate) trait SendLoop: Sync { tokio::select! { biased; // always read incomming channel first if it has data sth = recv_fut => { - if let Some((id, prio, order_tag, data)) = sth { - trace!("send_loop({}): add stream {} to send", debug_name, id); - sending.push(SendQueueItem { - id, - prio, - order_tag, - data: ByteStreamReader::new(data), - }); - } else { - msg_recv = None; + match sth { + Some(SendItem::Stream(id, prio, order_tag, data)) => { + trace!("send_loop({}): add stream {} to send", debug_name, id); + sending.push(SendQueueItem { + id, + prio, + order_tag, + data: ByteStreamReader::new(data), + }) + } + Some(SendItem::Cancel(id)) => { + trace!("send_loop({}): cancelling {}", debug_name, id); + sending.remove(id); + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?; + write.flush().await?; + } + None => { + msg_recv = None; + } }; } (id, data) = send_fut => { diff --git a/src/server.rs b/src/server.rs index 2c12d9d..f9eb121 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use async_trait::async_trait; @@ -53,7 +54,8 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, + running_handlers: Mutex>>, } impl ServerConn { @@ -99,6 +101,7 @@ impl ServerConn { remote_addr, peer_id, resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), + running_handlers: Mutex::new(HashMap::new()), }); netapp.connected_as_server(peer_id, conn.clone()); @@ -174,10 +177,15 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { - let resp_send = self.resp_send.load_full().unwrap(); + let resp_send = match self.resp_send.load_full() { + Some(c) => c, + None => return, + }; + + let mut rh = self.running_handlers.lock().unwrap(); let self2 = self.clone(); - tokio::spawn(async move { + let jh = tokio::spawn(async move { debug!("server: recv_handler got {}", id); let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { @@ -189,9 +197,26 @@ impl RecvLoop for ServerConn { let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_order, resp_stream)) + .send(SendItem::Stream(id, prio, resp_order, resp_stream)) .log_err("ServerConn recv_handler send resp bytes"); - Ok::<_, Error>(()) + + self2.running_handlers.lock().unwrap().remove(&id); }); + + rh.insert(id, jh); + } + + fn cancel_handler(self: &Arc, id: RequestID) { + trace!("received cancel for request {}", id); + + // If the handler is still running, abort it now + if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) { + jh.abort(); + } + + // Inform the response sender that we don't need to send the response + if let Some(resp_send) = self.resp_send.load_full() { + let _ = resp_send.send(SendItem::Cancel(id)); + } } } -- cgit v1.2.3 From b931d0d1cfb39d5feae1d4e0a7a49cdebd45761b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 16:01:56 +0200 Subject: try debug --- src/client.rs | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index d82c91e..7dffa36 100644 --- a/src/client.rs +++ b/src/client.rs @@ -280,6 +280,7 @@ impl Stream for CancelOnDropStream { let res = this.stream.poll_next(cx); if matches!(res, Poll::Ready(None)) { if let Some(c) = this.cancel.take() { + trace!("defusing cancel request {}", c.id); std::mem::forget(c) } } -- cgit v1.2.3 From b82ad70dd5d5e7ce9102f63fec37396dbda8de08 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 16:10:38 +0200 Subject: Correctly defuse cancellation on simple requests --- src/message.rs | 2 ++ src/stream.rs | 34 ++++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index 1834f28..ec9433a 100644 --- a/src/message.rs +++ b/src/message.rs @@ -454,6 +454,8 @@ impl RespEnc { let msg_len = reader.read_u32().await?; let msg = reader.read_exact(msg_len as usize).await?; + reader.fill_buffer().await; + Ok(Self { msg, stream: Some(reader.into_stream()), diff --git a/src/stream.rs b/src/stream.rs index efa0ebc..05ee051 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -95,6 +95,26 @@ impl ByteStreamReader { fn try_get(&mut self, read_len: usize) -> Option { self.buf.take_exact(read_len) } + + fn add_stream_next(&mut self, packet: Option) { + match packet { + Some(Ok(slice)) => { + self.buf.extend(slice); + } + Some(Err(e)) => { + self.err = Some(e); + self.eos = true; + } + None => { + self.eos = true; + } + } + } + + pub async fn fill_buffer(&mut self) { + let packet = self.stream.next().await; + self.add_stream_next(packet); + } } pub enum ReadExactError { @@ -132,18 +152,8 @@ impl<'a> Future for ByteStreamReadExact<'a> { } } - match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { - Some(Ok(slice)) => { - this.reader.buf.extend(slice); - } - Some(Err(e)) => { - this.reader.err = Some(e); - this.reader.eos = true; - } - None => { - this.reader.eos = true; - } - } + let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx)); + this.reader.add_stream_next(next_packet); } } } -- cgit v1.2.3 From f6ad1d0fab340e77fbfcb3488a98c342d334838e Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 1 Sep 2022 16:13:43 +0200 Subject: less verbosity --- src/client.rs | 1 - 1 file changed, 1 deletion(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 7dffa36..d82c91e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -280,7 +280,6 @@ impl Stream for CancelOnDropStream { let res = this.stream.poll_next(cx); if matches!(res, Poll::Ready(None)) { if let Some(c) = this.cancel.take() { - trace!("defusing cancel request {}", c.id); std::mem::forget(c) } } -- cgit v1.2.3 From 0f799a7768997c37e3e1b6861c097c4cd934acde Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 6 Sep 2022 19:42:49 +0200 Subject: Implement Least Attained First scheduling of streams --- src/send.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index 780bbcf..2606434 100644 --- a/src/send.rs +++ b/src/send.rs @@ -59,6 +59,7 @@ struct SendQueueItem { prio: RequestPriority, order_tag: Option, data: ByteStreamReader, + sent: usize, } impl SendQueue { @@ -106,7 +107,7 @@ impl SendQueuePriority { let i = order_vec.iter().take_while(|o2| **o2 < order).count(); order_vec.insert(i, order); } - self.items.push_back(item); + self.items.push_front(item); } fn remove(&mut self, id: RequestID) { if let Some(i) = self.items.iter().position(|x| x.id == id) { @@ -139,7 +140,11 @@ impl SendQueuePriority { _ => unreachable!(), }); - if eos || packet.is_err() { + let is_err = packet.is_err(); + let data_frame = DataFrame::from_packet(packet, !eos); + item.sent += data_frame.data().len(); + + if eos || is_err { if let Some(OrderTag(stream, order)) = item.order_tag { assert_eq!( self.order.get_mut(&stream).unwrap().pop_front(), @@ -147,10 +152,16 @@ impl SendQueuePriority { ) } self.items.remove(j); + } else { + for k in j..self.items.len() - 1 { + if self.items[k].sent >= self.items[k + 1].sent { + self.items.swap(k, k + 1); + } else { + break; + } + } } - let data_frame = DataFrame::from_packet(packet, !eos); - return Poll::Ready((id, data_frame)); } } @@ -160,7 +171,7 @@ impl SendQueuePriority { fn dump(&self, prio: u8) -> String { self.items .iter() - .map(|i| format!("[{} {} {:?}]", prio, i.id, i.order_tag)) + .map(|i| format!("[{} {} {:?} @{}]", prio, i.id, i.order_tag, i.sent)) .collect::>() .join(" ") } @@ -294,6 +305,7 @@ pub(crate) trait SendLoop: Sync { prio, order_tag, data: ByteStreamReader::new(data), + sent: 0, }) } Some(SendItem::Cancel(id)) => { -- cgit v1.2.3 From 8a7aca98375ff20effaab3d7c95124bd4cbc925c Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 12 Sep 2022 17:20:45 +0200 Subject: reword doc comment --- src/bytes_buf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs index 857be9d..05b7edd 100644 --- a/src/bytes_buf.rs +++ b/src/bytes_buf.rs @@ -5,7 +5,7 @@ pub use bytes::Bytes; /// A circular buffer of bytes, internally represented as a list of Bytes /// for optimization, but that for all intent and purposes acts just like /// a big byte slice which can be extended on the right and from which -/// one can take on the left. +/// stuff can be taken on the left. pub struct BytesBuf { buf: VecDeque, buf_len: usize, -- cgit v1.2.3 From f022a77f97c169807ae098e101a29301c0d19fbd Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 12 Sep 2022 17:43:10 +0200 Subject: Add documentation --- src/message.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++--- src/peering/fullmesh.rs | 1 + src/stream.rs | 47 ++++++++++++++++++++++++++++++++++++-------- 3 files changed, 89 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index ec9433a..2b2b75f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -41,17 +41,31 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; // ---- -#[derive(Clone, Copy)] -pub struct OrderTagStream(u64); +/// An order tag can be added to a message or a response to indicate +/// whether it should be sent after or before other messages with order tags +/// referencing a same stream #[derive(Clone, Copy, Serialize, Deserialize, Debug)] pub struct OrderTag(pub(crate) u64, pub(crate) u64); +/// A stream is an opaque identifier that defines a set of messages +/// or responses that are ordered wrt one another using to order tags. +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); + + impl OrderTag { + /// Create a new stream from which to generate order tags. Example: + /// ``` + /// let stream = OrderTag.stream(); + /// let tag_1 = stream.order(1); + /// let tag_2 = stream.order(2); + /// ``` pub fn stream() -> OrderTagStream { OrderTagStream(thread_rng().gen()) } } impl OrderTagStream { + /// Create the order tag for message `order` in this stream pub fn order(&self, order: u64) -> OrderTag { OrderTag(self.0, order) } @@ -60,8 +74,10 @@ impl OrderTagStream { // ---- /// This trait should be implemented by all messages your application -/// wants to handle +/// wants to handle. It specifies which data type should be sent +/// as a response to this message in the RPC protocol. pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { + /// The type of the response that is sent in response to this message type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static; } @@ -79,10 +95,13 @@ pub struct Req { } impl Req { + /// Creates a new request from a base message `M` pub fn new(v: M) -> Result { Ok(v.into_req()?) } + /// Attach a stream to message in request, where the stream is streamed + /// from a fixed `Bytes` buffer pub fn with_stream_from_buffer(self, b: Bytes) -> Self { Self { stream: AttachedStream::Fixed(b), @@ -90,6 +109,10 @@ impl Req { } } + /// Attach a stream to message in request, where the stream is + /// an instance of `ByteStream`. Note than when a `Req` has an attached + /// stream which is a `ByteStream` instance, it can no longer be cloned + /// to be sent to different nodes (`.clone()` will panic) pub fn with_stream(self, b: ByteStream) -> Self { Self { stream: AttachedStream::Stream(b), @@ -97,6 +120,8 @@ impl Req { } } + /// Add an order tag to this request to indicate in which order it should + /// be sent. pub fn with_order_tag(self, order_tag: OrderTag) -> Self { Self { order_tag: Some(order_tag), @@ -104,10 +129,12 @@ impl Req { } } + /// Get a reference to the message `M` contained in this request pub fn msg(&self) -> &M { &self.msg } + /// Takes out the stream attached to this request, if any pub fn take_stream(&mut self) -> Option { std::mem::replace(&mut self.stream, AttachedStream::None).into_stream() } @@ -142,8 +169,14 @@ impl Req { } } +/// `IntoReq` represents any object that can be transformed into `Req` pub trait IntoReq { + /// Transform the object into a `Req`, serializing the message M + /// to be sent to remote nodes fn into_req(self) -> Result, rmp_serde::encode::Error>; + /// Transform the object into a `Req`, skipping the serialization + /// of message M, in the case we are not sending this RPC message to + /// a remote node fn into_req_local(self) -> Req; } @@ -220,6 +253,7 @@ pub struct Resp { } impl Resp { + /// Creates a new response from a base response message pub fn new(v: M::Response) -> Self { Resp { _phantom: Default::default(), @@ -229,6 +263,8 @@ impl Resp { } } + /// Attach a stream to message in response, where the stream is streamed + /// from a fixed `Bytes` buffer pub fn with_stream_from_buffer(self, b: Bytes) -> Self { Self { stream: AttachedStream::Fixed(b), @@ -236,6 +272,8 @@ impl Resp { } } + /// Attach a stream to message in response, where the stream is + /// an instance of `ByteStream`. pub fn with_stream(self, b: ByteStream) -> Self { Self { stream: AttachedStream::Stream(b), @@ -243,6 +281,8 @@ impl Resp { } } + /// Add an order tag to this response to indicate in which order it should + /// be sent. pub fn with_order_tag(self, order_tag: OrderTag) -> Self { Self { order_tag: Some(order_tag), @@ -250,14 +290,20 @@ impl Resp { } } + /// Get a reference to the response message contained in this request pub fn msg(&self) -> &M::Response { &self.msg } + /// Transforms the `Resp` into the response message it contains, + /// dropping everything else (including attached data stream) pub fn into_msg(self) -> M::Response { self.msg } + /// Transforms the `Resp` into, on the one side, the response message + /// it contains, and on the other side, the associated data stream + /// if it exists pub fn into_parts(self) -> (M::Response, Option) { (self.msg, self.stream.into_stream()) } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 7f1c065..2f3330e 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -81,6 +81,7 @@ impl PeerInfoInternal { } } +/// Information that the full mesh peering strategy can return about the peers it knows of #[derive(Copy, Clone, Debug)] pub struct PeerInfo { /// The node's identifier (its public key) diff --git a/src/stream.rs b/src/stream.rs index 05ee051..82f7be3 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -9,19 +9,23 @@ use tokio::io::AsyncRead; use crate::bytes_buf::BytesBuf; -/// A stream of associated data. +/// A stream of bytes (click to read more). /// /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// -/// Error code 255 means the stream was cut before its end. Other codes have no predefined -/// meaning, it's up to your application to define their semantic. +/// Items sent in the ByteStream may be errors of type `std::io::Error`. +/// An error indicates the end of the ByteStream: a reader should no longer read +/// after recieving an error, and a writer should stop writing after sending an error. pub type ByteStream = Pin + Send + Sync>>; +/// A packet sent in a ByteStream, which may contain either +/// a Bytes object or an error pub type Packet = Result; // ---- +/// A helper struct to read defined lengths of data from a BytesStream pub struct ByteStreamReader { stream: ByteStream, buf: BytesBuf, @@ -30,6 +34,7 @@ pub struct ByteStreamReader { } impl ByteStreamReader { + /// Creates a new `ByteStreamReader` from a `ByteStream` pub fn new(stream: ByteStream) -> Self { ByteStreamReader { stream, @@ -39,6 +44,8 @@ impl ByteStreamReader { } } + /// Read exactly `read_len` bytes from the underlying stream + /// (returns a future) pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { ByteStreamReadExact { reader: self, @@ -47,6 +54,8 @@ impl ByteStreamReader { } } + /// Read at most `read_len` bytes from the underlying stream, or less + /// if the end of the stream is reached (returns a future) pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { ByteStreamReadExact { reader: self, @@ -55,10 +64,14 @@ impl ByteStreamReader { } } + /// Read exactly one byte from the underlying stream and returns it + /// as an u8 pub async fn read_u8(&mut self) -> Result { Ok(self.read_exact(1).await?[0]) } + /// Read exactly two bytes from the underlying stream and returns them as an u16 (using + /// big-endian decoding) pub async fn read_u16(&mut self) -> Result { let bytes = self.read_exact(2).await?; let mut b = [0u8; 2]; @@ -66,6 +79,8 @@ impl ByteStreamReader { Ok(u16::from_be_bytes(b)) } + /// Read exactly four bytes from the underlying stream and returns them as an u32 (using + /// big-endian decoding) pub async fn read_u32(&mut self) -> Result { let bytes = self.read_exact(4).await?; let mut b = [0u8; 4]; @@ -73,6 +88,8 @@ impl ByteStreamReader { Ok(u32::from_be_bytes(b)) } + /// Transforms the stream reader back into the underlying stream (starting + /// after everything that the reader has read) pub fn into_stream(self) -> ByteStream { let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok)); if let Some(err) = self.err { @@ -84,10 +101,21 @@ impl ByteStreamReader { } } + /// Tries to fill the internal read buffer from the underlying stream. + /// Calling this might be necessary to ensure that `.eos()` returns a correct + /// result, otherwise the reader might not be aware that the underlying + /// stream has nothing left to return. + pub async fn fill_buffer(&mut self) { + let packet = self.stream.next().await; + self.add_stream_next(packet); + } + + /// Clears the internal read buffer and returns its content pub fn take_buffer(&mut self) -> Bytes { self.buf.take_all() } + /// Returns true if the end of the underlying stream has been reached pub fn eos(&self) -> bool { self.buf.is_empty() && self.eos } @@ -110,18 +138,19 @@ impl ByteStreamReader { } } } - - pub async fn fill_buffer(&mut self) { - let packet = self.stream.next().await; - self.add_stream_next(packet); - } } +/// The error kind that can be returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` pub enum ReadExactError { + /// The end of the stream was reached before the requested number of bytes could be read UnexpectedEos, + /// The underlying data stream returned an IO error when trying to read Stream(std::io::Error), } +/// The future returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` #[pin_project::pin_project] pub struct ByteStreamReadExact<'a> { #[pin] @@ -160,10 +189,12 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- +/// Turns a `tokio::io::AsyncRead` asynchronous reader into a `ByteStream` pub fn asyncread_stream(reader: R) -> ByteStream { Box::pin(tokio_util::io::ReaderStream::new(reader)) } +/// Turns a `ByteStream` into a `tokio::io::AsyncRead` asynchronous reader pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { tokio_util::io::StreamReader::new(stream) } -- cgit v1.2.3 From 2305c2cf03919f074ec92d98cb6593c4ead50c4b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 11:31:19 +0200 Subject: Use BytesMut instead of Vec in bytes_buf (extend is probably faster) --- src/bytes_buf.rs | 18 ++++++++++-------- src/message.rs | 1 - 2 files changed, 10 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs index 05b7edd..931be82 100644 --- a/src/bytes_buf.rs +++ b/src/bytes_buf.rs @@ -1,5 +1,7 @@ use std::collections::VecDeque; +use bytes::BytesMut; + pub use bytes::Bytes; /// A circular buffer of bytes, internally represented as a list of Bytes @@ -48,13 +50,13 @@ impl BytesBuf { self.buf_len = 0; self.buf.pop_back().unwrap() } else { - let mut ret = Vec::with_capacity(self.buf_len); + let mut ret = BytesMut::with_capacity(self.buf_len); for b in self.buf.iter() { - ret.extend(&b[..]); + ret.extend_from_slice(&b[..]); } self.buf.clear(); self.buf_len = 0; - Bytes::from(ret) + ret.freeze() } } @@ -88,23 +90,23 @@ impl BytesBuf { self.buf_len -= len; front } else { - let mut ret = Vec::with_capacity(len); - ret.extend(&front[..]); + let mut ret = BytesMut::with_capacity(len); + ret.extend_from_slice(&front[..]); self.buf_len -= front.len(); while ret.len() < len { let front = self.buf.pop_front().unwrap(); if front.len() > len - ret.len() { let take = len - ret.len(); - ret.extend(front.slice(..take)); + ret.extend_from_slice(&front[..take]); self.buf.push_front(front.slice(take..)); self.buf_len -= take; break; } else { - ret.extend(&front[..]); + ret.extend_from_slice(&front[..]); self.buf_len -= front.len(); } } - Bytes::from(ret) + ret.freeze() } } diff --git a/src/message.rs b/src/message.rs index 2b2b75f..cc816c6 100644 --- a/src/message.rs +++ b/src/message.rs @@ -52,7 +52,6 @@ pub struct OrderTag(pub(crate) u64, pub(crate) u64); #[derive(Clone, Copy)] pub struct OrderTagStream(u64); - impl OrderTag { /// Create a new stream from which to generate order tags. Example: /// ``` -- cgit v1.2.3 From bf0e82047f0a54d71ce048ed8f5e1584485b542c Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 11:51:03 +0200 Subject: try make more like before --- src/endpoint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index bb768de..015000b 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -43,7 +43,7 @@ pub trait EndpointHandler: Send + Sync where M: Message, { - async fn handle(self: &Arc, m: &M, from: NodeID) -> ::Response; + async fn handle(self: &Arc, m: &M, from: NodeID) -> M::Response; } #[async_trait] -- cgit v1.2.3 From add2b54743a76bf805b0dc5ab7a1d8d326445314 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 11:52:35 +0200 Subject: fix comment --- src/endpoint.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/endpoint.rs b/src/endpoint.rs index 015000b..3cafafe 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -53,8 +53,8 @@ where M: Message, { async fn handle(self: &Arc, mut m: Req, from: NodeID) -> Resp { - // Immediately drop stream to avoid backpressure if a stream was sent - // (this will make all data sent to the stream be ignored immediately) + // Immediately drop stream to ignore all data that comes in, + // instead of buffering it indefinitely drop(m.take_stream()); Resp::new(EndpointHandler::handle(self, m.msg(), from).await) } -- cgit v1.2.3 From db96af2609a75284c5608cf592c3d4ce4b28ae0b Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:05:42 +0200 Subject: Add comment on cancellation --- src/message.rs | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index cc816c6..b12da32 100644 --- a/src/message.rs +++ b/src/message.rs @@ -499,6 +499,11 @@ impl RespEnc { let msg_len = reader.read_u32().await?; let msg = reader.read_exact(msg_len as usize).await?; + // Check whether the response stream still has data or not. + // If no more data is coming, this will defuse the request canceller. + // If we didn't do this, and the client doesn't try to read from the stream, + // the request canceller doesn't know that we read everything and + // sends a cancellation message to the server (which they don't care about). reader.fill_buffer().await; Ok(Self { -- cgit v1.2.3 From 9362d268904de694328f20e9e3a31b569955787c Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:08:33 +0200 Subject: fill_buffer do something only if buf is empty --- src/stream.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/stream.rs b/src/stream.rs index 82f7be3..88c3fed 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -101,13 +101,15 @@ impl ByteStreamReader { } } - /// Tries to fill the internal read buffer from the underlying stream. + /// Tries to fill the internal read buffer from the underlying stream if it is empty. /// Calling this might be necessary to ensure that `.eos()` returns a correct /// result, otherwise the reader might not be aware that the underlying /// stream has nothing left to return. pub async fn fill_buffer(&mut self) { - let packet = self.stream.next().await; - self.add_stream_next(packet); + if self.buf.is_empty() { + let packet = self.stream.next().await; + self.add_stream_next(packet); + } } /// Clears the internal read buffer and returns its content -- cgit v1.2.3 From 8ab6256c3b5a2cde7144b3a5e1ef488b7bce6227 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:12:55 +0200 Subject: No longer need to derive Clone on message types --- src/message.rs | 2 +- src/netapp.rs | 2 +- src/peering/fullmesh.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/message.rs b/src/message.rs index b12da32..9cc1a3e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -54,7 +54,7 @@ pub struct OrderTagStream(u64); impl OrderTag { /// Create a new stream from which to generate order tags. Example: - /// ``` + /// ```ignore /// let stream = OrderTag.stream(); /// let tag_1 = stream.order(1); /// let tag_2 = stream.order(2); diff --git a/src/netapp.rs b/src/netapp.rs index f1e14ed..b1ad9db 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -38,7 +38,7 @@ pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug)] pub(crate) struct HelloMessage { pub server_addr: Option, pub server_port: u16, diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 2f3330e..fb2e3d1 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -30,7 +30,7 @@ const FAILED_PING_THRESHOLD: usize = 4; // -- Protocol messages -- -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize)] struct PingMessage { pub id: u64, pub peer_list_hash: hash::Digest, @@ -40,7 +40,7 @@ impl Message for PingMessage { type Response = PingMessage; } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, } -- cgit v1.2.3 From 18d5abc981faf2d76ced42bad5cb69aa83128832 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:20:49 +0200 Subject: add precision to protocol description --- src/send.rs | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index 2606434..0ca62fd 100644 --- a/src/send.rs +++ b/src/send.rs @@ -28,6 +28,7 @@ use crate::stream::*; // - if error: // - u8: error kind, encoded using error::io_errorkind_to_u8 // - rest: error message +// - absent for cancel message pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; -- cgit v1.2.3 From c00676feba3819883b2888799d5f743c4ca9bca0 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:25:37 +0200 Subject: Uniformize flag naming --- src/recv.rs | 4 ++-- src/send.rs | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/recv.rs b/src/recv.rs index 8909190..0de7bef 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -91,8 +91,8 @@ pub(crate) trait RecvLoop: Sync + 'static { continue; } - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; + let has_cont = (size & CHUNK_FLAG_HAS_CONTINUATION) != 0; + let is_error = (size & CHUNK_FLAG_ERROR) != 0; let size = (size & CHUNK_LENGTH_MASK) as usize; let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; diff --git a/src/send.rs b/src/send.rs index 0ca62fd..af5f00c 100644 --- a/src/send.rs +++ b/src/send.rs @@ -19,8 +19,8 @@ use crate::stream::*; // Chunk format: // - u32 BE: request id (same for request and response) // - u16 BE: chunk length + flags: -// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream -// ERROR_MARKER if this chunk denotes an error +// CHUNK_FLAG_HAS_CONTINUATION when this is not the last chunk of the stream +// CHUNK_FLAG_ERROR if this chunk denotes an error // (these two flags are exclusive, an error denotes the end of the stream) // **special value** 0xFFFF indicates a CANCEL message // - [u8; chunk_length], either @@ -28,14 +28,14 @@ use crate::stream::*; // - if error: // - u8: error kind, encoded using error::io_errorkind_to_u8 // - rest: error message -// - absent for cancel message +// - absent for cancel messag pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; -pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_FLAG_ERROR: ChunkLength = 0x4000; +pub(crate) const CHUNK_FLAG_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF; @@ -237,8 +237,8 @@ impl DataFrame { fn header(&self) -> [u8; 2] { let header_u16 = match self { DataFrame::Data(data, false) => data.len() as u16, - DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, - DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER, + DataFrame::Data(data, true) => data.len() as u16 | CHUNK_FLAG_HAS_CONTINUATION, + DataFrame::Error(msg) => msg.len() as u16 | CHUNK_FLAG_ERROR, }; ChunkLength::to_be_bytes(header_u16) } -- cgit v1.2.3 From b509e6057f850971e3339404cfd2240193871402 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:28:01 +0200 Subject: Missing cleanup --- src/send.rs | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index af5f00c..4e16179 100644 --- a/src/send.rs +++ b/src/send.rs @@ -84,6 +84,7 @@ impl SendQueue { for (_, prioq) in self.items.iter_mut() { prioq.remove(id); } + self.items.retain(|(_prio, q)| !q.is_empty()); } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) -- cgit v1.2.3 From 395f942fc745f5947005cad3a0e2ac15403fdbc9 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:34:03 +0200 Subject: Fix potential memory leak --- src/send.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/send.rs b/src/send.rs index 4e16179..0db0ba7 100644 --- a/src/send.rs +++ b/src/send.rs @@ -118,6 +118,9 @@ impl SendQueuePriority { let order_vec = self.order.get_mut(&stream).unwrap(); let j = order_vec.iter().position(|x| *x == order).unwrap(); order_vec.remove(j).unwrap(); + if order_vec.is_empty() { + self.order.remove(&stream); + } } } } @@ -147,14 +150,19 @@ impl SendQueuePriority { item.sent += data_frame.data().len(); if eos || is_err { + // If item had an order tag, remove it from the corresponding ordering list if let Some(OrderTag(stream, order)) = item.order_tag { - assert_eq!( - self.order.get_mut(&stream).unwrap().pop_front(), - Some(order) - ) + let order_stream = self.order.get_mut(&stream).unwrap(); + assert_eq!(order_stream.pop_front(), Some(order)); + if order_stream.is_empty() { + self.order.remove(&stream); + } } + // Remove item from sending queue self.items.remove(j); } else { + // Move item later in send queue to implement LAS scheduling + // (LAS = Least Attained Service) for k in j..self.items.len() - 1 { if self.items[k].sent >= self.items[k + 1].sent { self.items.swap(k, k + 1); -- cgit v1.2.3 From 298e956a199711b65ce3820931ca943108b78225 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 13 Sep 2022 12:48:54 +0200 Subject: undo needless change --- src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/server.rs b/src/server.rs index f9eb121..cd367c4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -132,7 +132,7 @@ impl ServerConn { let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); - endpoints.get(&path[..]).map(|e| e.clone_endpoint()) + endpoints.get(&path).map(|e| e.clone_endpoint()) }; if let Some(handler) = handler_opt { -- cgit v1.2.3