diff options
-rw-r--r-- | src/client.rs | 32 | ||||
-rw-r--r-- | src/endpoint.rs | 1 | ||||
-rw-r--r-- | src/error.rs | 4 | ||||
-rw-r--r-- | src/proto.rs | 364 | ||||
-rw-r--r-- | src/server.rs | 26 |
5 files changed, 193 insertions, 234 deletions
diff --git a/src/client.rs b/src/client.rs index bc16fb1..a630f87 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,11 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>, + query_send: + ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>, next_query_number: AtomicU32, - inflight: Mutex<HashMap<RequestID, oneshot::Sender<(Vec<u8>, AssociatedStream)>>>, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<AssociatedStream>>>, } impl ClientConn { @@ -148,11 +149,9 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; - // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(2, atomic::Ordering::Relaxed); - let stream_id = id + 1; + .fetch_add(1, atomic::Ordering::Relaxed); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -187,10 +186,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch - .send((vec![], Box::pin(futures::stream::empty()))) - .is_err() - { + if old_ch.send(Box::pin(futures::stream::empty())).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -200,22 +196,18 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, Data::Full(bytes)))?; - if let Some(stream) = stream { - query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; - } else { - query_send.send((stream_id, prio, Data::Full(Vec::new())))?; - } + query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let (resp, stream) = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let (resp, stream) = resp_recv.await?; + let stream = resp_recv.await?; } } + let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); if resp.is_empty() { return Err(Error::Message( @@ -240,12 +232,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send((msg, stream)).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index c430d4e..f31141d 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -23,7 +23,6 @@ pub trait Message: SerializeMessage + Send + Sync { pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - // TODO should return Result fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>); // TODO should return Result diff --git a/src/error.rs b/src/error.rs index 99acdd1..7911c29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,9 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + #[error(display = "{}", _0)] Message(String), @@ -50,6 +53,7 @@ impl Error { Self::RMPEncode(_) => 10, Self::RMPDecode(_) => 11, Self::UTF8(_) => 12, + Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, Self::Handshake(_) => 30, diff --git a/src/proto.rs b/src/proto.rs index e3f9be8..d6dc35a 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,11 +3,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::{trace, warn}; +use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::Stream; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -63,39 +63,24 @@ struct SendQueueItem { data: DataReader, } -pub(crate) enum Data { - Full(Vec<u8>), - Streaming(AssociatedStream), +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: AssociatedStream, + packet: Result<Vec<u8>, u8>, + pos: usize, + buf: Vec<u8>, + eos: bool, } -#[pin_project::pin_project(project = DataReaderProj)] -enum DataReader { - Full { - #[pin] - data: Vec<u8>, - pos: usize, - }, - Streaming { - #[pin] - reader: AssociatedStream, - packet: Result<Vec<u8>, u8>, - pos: usize, - buf: Vec<u8>, - eos: bool, - }, -} - -impl From<Data> for DataReader { - fn from(data: Data) -> DataReader { - match data { - Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { - reader, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - }, +impl From<AssociatedStream> for DataReader { + fn from(data: AssociatedStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, } } } @@ -155,82 +140,60 @@ impl Stream for DataReader { type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match self.project() { - DataReaderProj::Full { data, pos } => { - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); - let end = *pos + len; - - if len == 0 { - Poll::Ready(None) - } else { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..len].copy_from_slice(&data[*pos..end]); - *pos = end; - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: end < data.len(), - })) + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; } - DataReaderProj::Streaming { - mut reader, - packet: res_packet, - pos, - buf, - eos, - } => { - if *eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - loop { - let packet = match res_packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *res_packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *pos; - let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - buf.extend_from_slice(&packet[*pos..*pos + to_read]); - *pos += to_read; - if buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *res_packet = p; - *pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if res_packet.is_err() && !buf.is_empty() { - break; - } - } else { - *eos = true; - break; - } + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = buf.len(); - body[..len].copy_from_slice(buf); - buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*eos, - })) + } else { + *this.eos = true; + break; } } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) } } @@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop<W>( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite<W>, ) -> Result<(), Error> where @@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync { } } -struct ChannelPair { - receiver: Option<UnboundedReceiver<Vec<u8>>>, - sender: Option<UnboundedSender<Vec<u8>>>, +pub(crate) struct Framing { + direct: Vec<u8>, + stream: Option<AssociatedStream>, } -impl ChannelPair { - fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { - self.receiver.take() +impl Framing { + pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } } - fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { - self.sender.take() - } + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; - fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { - self.sender.as_ref().take() - } + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); - fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) } } -} -impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), + pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result<Self, Error> { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } } @@ -424,14 +431,13 @@ impl Default for ChannelPair { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new(); - let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new(); + let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; - let size = if !is_error { - size & !CHUNK_HAS_CONTINUATION + let packet = if is_error { + Err(size as u8) } else { - 0 + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) }; - // TODO propagate errors - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - - if id & 1 == 0 { - // main stream - let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); - - if let Some(receiver) = channel_pair.take_receiver() { - use futures::StreamExt; - self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); - } else { - warn!("Couldn't take receiver part of stream") - } - channel_pair.insert_into(&mut streams, id | 1); - } + let sender = if let Some(send) = streams.remove(&(id)) { + send } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + let (send, recv) = unbounded(); + self.recv_handler(id, Box::pin(recv)); + send + }; - if !has_cont { - channel_pair.take_sender(); - } + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + let _ = sender.unbounded_send(packet); - channel_pair.insert_into(&mut streams, id); + if has_cont { + streams.insert(id, sender); } } Ok(()) @@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Result<Vec<u8>, u8>; + let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = + Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 6cd4056..86e5156 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,6 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; -use bytes::Bytes; use log::{debug, trace}; #[cfg(feature = "telemetry")] @@ -55,7 +54,7 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>, } impl ServerConn { @@ -177,13 +176,13 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>, stream: AssociatedStream) { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + trace!("ServerConn recv_handler {}", id); + let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; let resp = self2.recv_handler_aux(&bytes[..], stream).await; @@ -204,18 +203,13 @@ impl RecvLoop for ServerConn { trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, Data::Full(resp_bytes))) + .send(( + id, + prio, + Framing::new(resp_bytes, resp_stream).into_stream(), + )) .log_err("ServerConn recv_handler send resp bytes"); - - if let Some(resp_stream) = resp_stream { - resp_send - .send((id + 1, prio, Data::Streaming(resp_stream))) - .log_err("ServerConn recv_handler send resp stream"); - } else { - resp_send - .send((id + 1, prio, Data::Full(Vec::new()))) - .log_err("ServerConn recv_handler send resp stream"); - } + Ok::<_, Error>(()) }); } } |