diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 37 | ||||
-rw-r--r-- | src/endpoint.rs | 66 | ||||
-rw-r--r-- | src/proto.rs | 260 | ||||
-rw-r--r-- | src/server.rs | 38 | ||||
-rw-r--r-- | src/test.rs | 1 | ||||
-rw-r--r-- | src/util.rs | 17 |
6 files changed, 338 insertions, 81 deletions
diff --git a/src/client.rs b/src/client.rs index 8227e8f..bce7aca 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,10 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>, next_query_number: AtomicU32, - inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<(Vec<u8>, AssociatedStream)>>>, } impl ClientConn { @@ -148,9 +148,11 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; + // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(1, atomic::Ordering::Relaxed); + .fetch_add(2, atomic::Ordering::Relaxed); + let stream_id = id + 1; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -166,7 +168,7 @@ impl ClientConn { }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; + let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; drop(rq); let request = QueryMessage { @@ -185,7 +187,10 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { + if old_ch + .send((vec![], Box::pin(futures::stream::empty()))) + .is_err() + { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -195,15 +200,20 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, bytes))?; + query_send.send((id, prio, Data::Full(bytes)))?; + if let Some(stream) = stream { + query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; + } else { + query_send.send((stream_id, prio, Data::Full(Vec::new())))?; + } cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let (resp, stream) = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let (resp, stream) = resp_recv.await?; } } @@ -217,10 +227,9 @@ impl ClientConn { let code = resp[0]; if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - <T as Message>::Response, - >(&resp[1..])?) + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); + let res = T::Response::deserialize_msg(&mut deser, stream).await?; + Ok(res) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) @@ -232,12 +241,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) { + fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream) { trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send((msg, stream)).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..81ed036 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use serde::de::Error as DeError; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::error::Error; use crate::netapp::*; @@ -14,8 +15,50 @@ use crate::util::*; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + fn serialize_msg<S: Serializer>( + &self, + serializer: S, + ) -> Result<(S::Ok, Option<AssociatedStream>), S::Error>; + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + stream: AssociatedStream, + ) -> Result<Self, D::Error>; +} + +#[async_trait] +impl<T> SerializeMessage for T +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + fn serialize_msg<S: Serializer>( + &self, + serializer: S, + ) -> Result<(S::Ok, Option<AssociatedStream>), S::Error> { + self.serialize(serializer).map(|r| (r, None)) + } + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + mut stream: AssociatedStream, + ) -> Result<Self, D::Error> { + use futures::StreamExt; + + let res = Self::deserialize(deserializer)?; + if stream.next().await.is_some() { + return Err(D::Error::custom( + "failed to deserialize: found associated stream when none expected", + )); + } + Ok(res) + } } /// This trait should be implemented by an object of your application @@ -128,7 +171,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>; + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -145,11 +193,17 @@ where M: Message + 'static, H: EndpointHandler<M> + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> { + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); + let req = M::deserialize_msg(&mut deser, stream).await?; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/proto.rs b/src/proto.rs index e843bff..b45ff13 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use log::trace; +use log::{trace, warn}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::Stream; +use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -11,6 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; +use crate::util::AssociatedStream; /// Priority of a request (click to read more about priorities). /// @@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: Vec<u8>, - cursor: usize, + data: DataReader, +} + +pub(crate) enum Data { + Full(Vec<u8>), + Streaming(AssociatedStream), +} + +#[pin_project::pin_project(project = DataReaderProj)] +enum DataReader { + Full { + #[pin] + data: Vec<u8>, + pos: usize, + }, + Streaming { + #[pin] + reader: AssociatedStream, + }, +} + +impl From<Data> for DataReader { + fn from(data: Data) -> DataReader { + match data { + Data::Full(data) => DataReader::Full { data, pos: 0 }, + Data::Streaming(reader) => DataReader::Streaming { reader }, + } + } +} + +impl Stream for DataReader { + type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match self.project() { + DataReaderProj::Full { data, pos } => { + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); + let end = *pos + len; + + if len == 0 { + Poll::Ready(None) + } else { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..len].copy_from_slice(&data[*pos..end]); + *pos = end; + Poll::Ready(Some((body, len))) + } + } + DataReaderProj::Streaming { reader } => { + reader.poll_next(cx).map(|opt| { + opt.map(|v| { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); + // TODO this can throw away long vec, they should be splited instead + body[..len].copy_from_slice(&v[..len]); + (body, len) + }) + }) + } + } + } } struct SendQueue { @@ -108,7 +172,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop<W>( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, mut write: BoxStreamWrite<W>, ) -> Result<(), Error> where @@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync { let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else if let Some(mut item) = sending.pop() { trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor + "send_loop: sending bytes for {}", + item.id, ); + + let data = futures::select! { + data = item.data.next().fuse() => data, + default => { + // nothing to send yet; re-schedule and find something else to do + sending.push(item); + continue; + + // TODO if every SendQueueItem is waiting on data, use select_all to await + // something to do + // TODO find some way to not require sending empty last chunk + } + }; + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let data = match data.as_ref() { + Some((data, len)) => &data[..*len], + None => &[], + }; + + if !data.is_empty() { let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); + ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; + write.write_all(data).await?; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); + // this is always zero for now, but may be more when above TODO get fixed + let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - write.write_all(&item.data[item.cursor..]).await?; + write.write_all(data).await?; } + write.flush().await?; } else { let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else { should_exit = true; @@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync { } } +struct ChannelPair { + receiver: Option<UnboundedReceiver<Vec<u8>>>, + sender: Option<UnboundedSender<Vec<u8>>>, +} + +impl ChannelPair { + fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } +} + +impl Default for ChannelPair { + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>); + fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new(); + let mut streams: HashMap< + RequestID, + ChannelPair, + > = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + if id & 1 == 0 { + // main stream + let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); + msg_bytes.extend_from_slice(&next_slice[..]); - if has_cont { - receiving.insert(id, msg_bytes); + if has_cont { + receiving.insert(id, msg_bytes); + } else { + let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); + + if let Some(receiver) = channel_pair.take_receiver() { + self.recv_handler(id, msg_bytes, Box::pin(receiver)); + } else { + warn!("Couldn't take receiver part of stream") + } + + channel_pair.insert_into(&mut streams, id | 1); + } } else { - self.recv_handler(id, msg_bytes); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } + + if !has_cont { + channel_pair.take_sender(); + } + + channel_pair.insert_into(&mut streams, id); } } Ok(()) @@ -236,38 +396,50 @@ mod test { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 5465307..6cd4056 100644 --- a/src/server.rs +++ b/src/server.rs @@ -55,7 +55,7 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>, } impl ServerConn { @@ -123,7 +123,11 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { + async fn recv_handler_aux( + self: &Arc<Self>, + bytes: &[u8], + stream: AssociatedStream, + ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; @@ -156,11 +160,11 @@ impl ServerConn { span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(msg.body, stream, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(msg.body, stream, self.peer_id).await } } } else { @@ -173,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) { + fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); @@ -182,26 +186,36 @@ impl RecvLoop for ServerConn { let bytes: Bytes = bytes.into(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; + let resp = self2.recv_handler_aux(&bytes[..], stream).await; - let resp_bytes = match resp { - Ok(rb) => { + let (resp_bytes, resp_stream) = match resp { + Ok((rb, rs)) => { let mut resp_bytes = vec![0u8]; resp_bytes.extend(rb); - resp_bytes + (resp_bytes, rs) } Err(e) => { let mut resp_bytes = vec![e.code()]; resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes + (resp_bytes, None) } }; trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send((id, prio, Data::Full(resp_bytes))) + .log_err("ServerConn recv_handler send resp bytes"); + + if let Some(resp_stream) = resp_stream { + resp_send + .send((id + 1, prio, Data::Streaming(resp_stream))) + .log_err("ServerConn recv_handler send resp stream"); + } else { + resp_send + .send((id + 1, prio, Data::Full(Vec::new()))) + .log_err("ServerConn recv_handler send resp stream"); + } }); } } diff --git a/src/test.rs b/src/test.rs index 82c7ba6..ecd5450 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use crate::NodeID; #[tokio::test(flavor = "current_thread")] async fn test_with_basic_scheduler() { + pretty_env_logger::init(); run_test().await } diff --git a/src/util.rs b/src/util.rs index f4dfac7..4333080 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,10 @@ +use crate::endpoint::SerializeMessage; + use std::net::SocketAddr; use std::net::ToSocketAddrs; +use std::pin::Pin; -use serde::Serialize; +use futures::Stream; use log::info; @@ -14,21 +17,25 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +pub type AssociatedStream = Pin<Box<dyn Stream<Item = Vec<u8>> + Send>>; + /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. /// /// Field names and variant names are included in the serialization. /// This is used internally by the netapp communication protocol. -pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> +pub fn rmp_to_vec_all_named<T>( + val: &T, +) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error> where - T: Serialize + ?Sized, + T: SerializeMessage + ?Sized, { let mut wr = Vec::with_capacity(128); let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - val.serialize(&mut se)?; - Ok(wr) + let (_, stream) = val.serialize_msg(&mut se)?; + Ok((wr, stream)) } /// This async function returns only when a true signal was received |