From f35fa7d18d9e0f51bed311355ec1310b1d311ab3 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 17:34:53 +0200 Subject: Move things around --- src/client.rs | 18 +- src/endpoint.rs | 78 +----- src/error.rs | 2 +- src/lib.rs | 5 +- src/message.rs | 255 ++++++++++++++++++++ src/netapp.rs | 2 +- src/peering/basalt.rs | 3 +- src/peering/fullmesh.rs | 3 +- src/proto.rs | 617 ------------------------------------------------ src/proto2.rs | 75 ------ src/recv.rs | 114 +++++++++ src/send.rs | 410 ++++++++++++++++++++++++++++++++ src/server.rs | 32 ++- src/util.rs | 12 +- 14 files changed, 820 insertions(+), 806 deletions(-) create mode 100644 src/message.rs delete mode 100644 src/proto.rs delete mode 100644 src/proto2.rs create mode 100644 src/recv.rs create mode 100644 src/send.rs (limited to 'src') diff --git a/src/client.rs b/src/client.rs index 6d49f5c..663a3e4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,9 +5,12 @@ use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, error, trace}; use futures::channel::mpsc::{unbounded, UnboundedReceiver}; +use futures::io::AsyncReadExt; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -21,25 +24,18 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: - ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>>, diff --git a/src/endpoint.rs b/src/endpoint.rs index f31141d..e6b2236 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,79 +5,11 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - use crate::error::Error; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; use crate::util::*; -/// This trait should be implemented by all messages your application -/// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; -} - -/// A trait for de/serializing messages, with possible associated stream. -#[async_trait] -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option); - - // TODO should return Result - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; -} - -pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} - -#[async_trait] -impl SerializeMessage for T -where - T: AutoSerialize, -{ - type SerializableSelf = Self; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - (self.clone(), None) - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { - // TODO verify no stream - ser_self - } -} - -impl AutoSerialize for () {} - -#[async_trait] -impl SerializeMessage for Result -where - T: SerializeMessage + Send, - E: SerializeMessage + Send, -{ - type SerializableSelf = Result; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.serialize_msg(); - (Ok(msg), stream) - } - Err(err) => { - let (msg, stream) = err.serialize_msg(); - (Err(msg), stream) - } - } - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), - Err(err) => Err(E::deserialize_msg(err, stream).await), - } - } -} - /// This trait should be implemented by an object of your application /// that can handle a message of type `M`. /// @@ -191,9 +123,9 @@ pub(crate) trait GenericEndpoint { async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error>; + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -213,9 +145,9 @@ where async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error> { + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { diff --git a/src/error.rs b/src/error.rs index 7911c29..665647c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ -use err_derive::Error; use std::io; +use err_derive::Error; use log::error; #[derive(Debug, Error)] diff --git a/src/lib.rs b/src/lib.rs index cb24337..1edb919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,10 +17,11 @@ pub mod error; pub mod util; pub mod endpoint; -pub mod proto; +pub mod message; mod client; -mod proto2; +mod recv; +mod send; mod server; pub mod netapp; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..dbcc857 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,255 @@ +use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::error::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +} + +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +#[async_trait] +impl SerializeMessage for T +where + T: AutoSerialize, +{ + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), + } + } +} + +// ---- + +pub(crate) struct QueryMessage<'a> { + pub(crate) prio: RequestPriority, + pub(crate) path: &'a [u8], + pub(crate) telemetry_id: Option>, + pub(crate) body: &'a [u8], +} + +/// QueryMessage encoding: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - body [u8; ..] +impl<'a> QueryMessage<'a> { + pub(crate) fn encode(self) -> Vec { + let tel_len = match &self.telemetry_id { + Some(t) => t.len(), + None => 0, + }; + + let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + + ret.push(self.prio); + + ret.push(self.path.len() as u8); + ret.extend_from_slice(self.path); + + if let Some(t) = self.telemetry_id { + ret.push(t.len() as u8); + ret.extend(t); + } else { + ret.push(0u8); + } + + ret.extend_from_slice(self.body); + + ret + } + + pub(crate) fn decode(bytes: &'a [u8]) -> Result { + if bytes.len() < 3 { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path_length = bytes[1] as usize; + if bytes.len() < 3 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let telemetry_id_len = bytes[2 + path_length] as usize; + if bytes.len() < 3 + path_length + telemetry_id_len { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let telemetry_id = if telemetry_id_len > 0 { + Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) + } else { + None + }; + + let body = &bytes[3 + path_length + telemetry_id_len..]; + + Ok(Self { + prio: bytes[0], + path, + telemetry_id, + body, + }) + } +} + +pub(crate) struct Framing { + direct: Vec, + stream: Option, +} + +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } + } + + pub fn into_stream(self) -> ByteStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; + + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); + + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) + } + } + + pub async fn from_stream + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + } + + let stream: ByteStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, ByteStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + } +} diff --git a/src/netapp.rs b/src/netapp.rs index 27f17e6..dd22d90 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -20,7 +20,7 @@ use tokio::sync::{mpsc, watch}; use crate::client::*; use crate::endpoint::*; use crate::error::*; -use crate::proto::*; +use crate::message::*; use crate::server::*; use crate::util::*; diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 7f77995..98977a3 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -14,8 +14,9 @@ use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; +use crate::send::*; use crate::NodeID; // -- Protocol messages -- diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 7dfc5c4..5b489ae 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash; use crate::endpoint::*; use crate::error::*; use crate::netapp::*; -use crate::proto::*; + +use crate::message::*; use crate::NodeID; const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 92d8d80..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use log::trace; - -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures::{Stream, StreamExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; -use crate::util::{AssociatedStream, Packet}; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -const ERROR_MARKER: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: DataReader, -} - -#[pin_project::pin_project] -struct DataReader { - #[pin] - reader: AssociatedStream, - packet: Packet, - pos: usize, - buf: Vec, - eos: bool, -} - -impl From for DataReader { - fn from(data: AssociatedStream) -> DataReader { - DataReader { - reader: data, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - } - } -} - -enum DataFrame { - Data { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actual lenght of data - len: usize, - }, - Error(u8), -} - -struct DataReaderItem { - data: DataFrame, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, -} - -impl DataReaderItem { - fn empty_last() -> Self { - DataReaderItem { - data: DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - }, - may_have_more: false, - } - } - - fn header(&self) -> [u8; 2] { - let continuation = if self.may_have_more { - CHUNK_HAS_CONTINUATION - } else { - 0 - }; - let len = match self.data { - DataFrame::Data { len, .. } => len as u16, - DataFrame::Error(e) => e as u16 | ERROR_MARKER, - }; - - ChunkLength::to_be_bytes(len | continuation) - } - - fn data(&self) -> &[u8] { - match self.data { - DataFrame::Data { ref data, len } => &data[..len], - DataFrame::Error(_) => &[], - } - } -} - -impl Stream for DataReader { - type Item = DataReaderItem; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if *this.eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - - loop { - let packet = match this.packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *this.packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *this.pos; - let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - this.buf - .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); - *this.pos += to_read; - if this.buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { - *this.packet = p; - *this.pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if this.packet.is_err() && !this.buf.is_empty() { - break; - } - } else { - *this.eos = true; - break; - } - } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = this.buf.len(); - body[..len].copy_from_slice(this.buf); - this.buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*this.eos, - })) - } -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] - fn pop(&mut self) -> Option { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - - // this is like an async fn, but hand implemented - fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { - SendQueuePollNextReady { queue: self } - } -} - -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataReaderItem); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - for i in 0..self.queue.items.len() { - let (_prio, items_at_prio) = &mut self.queue.items[i]; - - for _ in 0..items_at_prio.len() { - let mut item = items_at_prio.pop_front().unwrap(); - - match Pin::new(&mut item.data).poll_next(ctx) { - Poll::Pending => items_at_prio.push_back(item), - Poll::Ready(Some(data)) => { - let id = item.id; - if data.may_have_more { - self.queue.push(item); - } else { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - } - return Poll::Ready((id, data)); - } - Poll::Ready(None) => { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - return Poll::Ready((item.id, DataReaderItem::empty_last())); - } - } - } - } - // TODO what do we do if self.queue is empty? We won't get scheduled again. - Poll::Pending - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop( - self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, - mut write: BoxStreamWrite, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - let recv_fut = msg_recv.recv(); - futures::pin_mut!(recv_fut); - let send_fut = sending.next_ready(); - - // recv_fut is cancellation-safe according to tokio doc, - // send_fut is cancellation-safe as implemented above? - use futures::future::Either; - match futures::future::select(recv_fut, send_fut).await { - Either::Left((sth, _send_fut)) => { - if let Some((id, prio, data)) = sth { - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; - }; - } - Either::Right(((id, data), _recv_fut)) => { - trace!("send_loop: sending bytes for {}", id); - - let header_id = RequestID::to_be_bytes(id); - write.write_all(&header_id[..]).await?; - - write.write_all(&data.header()).await?; - write.write_all(data.data()).await?; - write.flush().await?; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -pub(crate) struct Framing { - direct: Vec, - stream: Option, -} - -impl Framing { - pub fn new(direct: Vec, stream: Option) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } - } - - pub fn into_stream(self) -> AssociatedStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; - - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); - - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } - - pub async fn from_stream + Unpin + Send + 'static>( - mut stream: S, - ) -> Result { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); - } - - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet.drain(..4); - - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); - - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet.drain(..max_cp); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } - - let stream: AssociatedStream = if packet.is_empty() { - Box::pin(stream) - } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; - - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec, AssociatedStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) - } -} - -/// Structure to warn when the sender is dropped before end of stream was reached, like when -/// connection to some remote drops while transmitting data -struct Sender { - inner: UnboundedSender, - closed: bool, -} - -impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } - } - - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); - } - - fn end(&mut self) { - self.closed = true; - } -} - -impl Drop for Sender { - fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); - } - self.inner.close_channel(); - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); - - async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut streams: HashMap = HashMap::new(); - loop { - trace!("recv_loop: reading packet"); - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; - let packet = if is_error { - Err(size as u8) - } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) - }; - - let mut sender = if let Some(send) = streams.remove(&(id)) { - send - } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); - Sender::new(send) - }; - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); - - if has_cont { - streams.insert(id, sender); - } else { - sender.end(); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn empty_data() -> DataReader { - type Item = Packet; - let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::()); - stream.into() - } - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: empty_data(), - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: empty_data(), - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: empty_data(), - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: empty_data(), - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} diff --git a/src/proto2.rs b/src/proto2.rs deleted file mode 100644 index 7210781..0000000 --- a/src/proto2.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::error::*; -use crate::proto::*; - -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option>, - pub(crate) body: &'a [u8], -} - -/// QueryMessage encoding: -/// - priority: u8 -/// - path length: u8 -/// - path: [u8; path length] -/// - telemetry id length: u8 -/// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; - - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); - - ret.push(self.prio); - - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); - - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); - } else { - ret.push(0u8); - } - - ret.extend_from_slice(self.body); - - ret - } - - pub(crate) fn decode(bytes: &'a [u8]) -> Result { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; - - let body = &bytes[3 + path_length + telemetry_id_len..]; - - Ok(Self { - prio: bytes[0], - path, - telemetry_id, - body, - }) - } -} diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..628612b --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use std::sync::Arc; + +use log::trace; + +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::AsyncReadExt; + +use async_trait::async_trait; + +use crate::error::*; + +use crate::send::*; +use crate::util::Packet; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); + + async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap = HashMap::new(); + loop { + trace!("recv_loop: reading packet"); + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + trace!("recv_loop: got header id: {:04x}", id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + trace!("recv_loop: got header size: {:04x}", size); + + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let is_error = (size & ERROR_MARKER) != 0; + let packet = if is_error { + Err(size as u8) + } else { + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = unbounded(); + self.recv_handler(id, recv); + Sender::new(send) + }; + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + sender.send(packet); + + if has_cont { + streams.insert(id, sender); + } else { + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/send.rs b/src/send.rs new file mode 100644 index 0000000..330d41d --- /dev/null +++ b/src/send.rs @@ -0,0 +1,410 @@ +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use log::trace; + +use futures::AsyncWriteExt; +use futures::Stream; +use kuska_handshake::async_std::BoxStreamWrite; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::message::*; +use crate::util::{ByteStream, Packet}; + +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag +// when this is not the last chunk of the message +// - [u8; chunk_length] chunk data + +pub(crate) type RequestID = u32; +pub(crate) type ChunkLength = u16; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; +pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + data: DataReader, +} + +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: ByteStream, + packet: Packet, + pos: usize, + buf: Vec, + eos: bool, +} + +impl From for DataReader { + fn from(data: ByteStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + } + } +} + +enum DataFrame { + Data { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actual lenght of data + len: usize, + }, + Error(u8), +} + +struct DataReaderItem { + data: DataFrame, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + }, + may_have_more: false, + } + } + + fn header(&self) -> [u8; 2] { + let continuation = if self.may_have_more { + CHUNK_HAS_CONTINUATION + } else { + 0 + }; + let len = match self.data { + DataFrame::Data { len, .. } => len as u16, + DataFrame::Error(e) => e as u16 | ERROR_MARKER, + }; + + ChunkLength::to_be_bytes(len | continuation) + } + + fn data(&self) -> &[u8] { + match self.data { + DataFrame::Data { ref data, len } => &data[..len], + DataFrame::Error(_) => &[], + } + } +} + +impl Stream for DataReader { + type Item = DataReaderItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); + } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; + } + } else { + *this.eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) + } +} + +struct SendQueue { + items: VecDeque<(u8, VecDeque)>, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: VecDeque::with_capacity(64), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, VecDeque::new())); + i + } + }; + self.items[pos_prio].1.push_back(item); + } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] + fn pop(&mut self) -> Option { + match self.items.pop_front() { + None => None, + Some((prio, mut items_at_prio)) => { + let ret = items_at_prio.pop_front(); + if !items_at_prio.is_empty() { + self.items.push_front((prio, items_at_prio)); + } + ret.or_else(|| self.pop()) + } + } + } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } +} + +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop( + self: Arc, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, + mut write: BoxStreamWrite, + ) -> Result<(), Error> + where + W: AsyncWriteExt + Unpin + Send + Sync, + { + let mut sending = SendQueue::new(); + let mut should_exit = false; + while !should_exit || !sending.is_empty() { + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; + } + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; + } + } + } + + let _ = write.goodbye().await; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn empty_data() -> DataReader { + type Item = Packet; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::()); + stream.into() + } + + #[test] + fn test_priority_queue() { + let i1 = SendQueueItem { + id: 1, + prio: PRIO_NORMAL, + data: empty_data(), + }; + let i2 = SendQueueItem { + id: 2, + prio: PRIO_HIGH, + data: empty_data(), + }; + let i2bis = SendQueueItem { + id: 20, + prio: PRIO_HIGH, + data: empty_data(), + }; + let i3 = SendQueueItem { + id: 3, + prio: PRIO_HIGH | PRIO_SECONDARY, + data: empty_data(), + }; + let i4 = SendQueueItem { + id: 4, + prio: PRIO_BACKGROUND | PRIO_SECONDARY, + data: empty_data(), + }; + let i5 = SendQueueItem { + id: 5, + prio: PRIO_BACKGROUND | PRIO_PRIMARY, + data: empty_data(), + }; + + let mut q = SendQueue::new(); + + q.push(i1); // 1 + let a = q.pop().unwrap(); // empty -> 1 + assert_eq!(a.id, 1); + assert!(q.pop().is_none()); + + q.push(a); // 1 + q.push(i2); // 2 1 + q.push(i2bis); // [2 20] 1 + let a = q.pop().unwrap(); // 20 1 -> 2 + assert_eq!(a.id, 2); + let b = q.pop().unwrap(); // 1 -> 20 + assert_eq!(b.id, 20); + let c = q.pop().unwrap(); // empty -> 1 + assert_eq!(c.id, 1); + assert!(q.pop().is_none()); + + q.push(a); // 2 + q.push(b); // [2 20] + q.push(c); // [2 20] 1 + q.push(i3); // [2 20] 3 1 + q.push(i4); // [2 20] 3 1 4 + q.push(i5); // [2 20] 3 1 5 4 + + let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 + assert_eq!(a.id, 2); + q.push(a); // [20 2] 3 1 5 4 + + let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 + assert_eq!(a.id, 20); + let b = q.pop().unwrap(); // 3 1 5 4 -> 2 + assert_eq!(b.id, 2); + q.push(b); // 2 3 1 5 4 + let b = q.pop().unwrap(); // 3 1 5 4 -> 2 + assert_eq!(b.id, 2); + let c = q.pop().unwrap(); // 1 5 4 -> 3 + assert_eq!(c.id, 3); + q.push(b); // 2 1 5 4 + let b = q.pop().unwrap(); // 1 5 4 -> 2 + assert_eq!(b.id, 2); + let e = q.pop().unwrap(); // 5 4 -> 1 + assert_eq!(e.id, 1); + let f = q.pop().unwrap(); // 4 -> 5 + assert_eq!(f.id, 5); + let g = q.pop().unwrap(); // empty -> 4 + assert_eq!(g.id, 4); + assert!(q.pop().is_none()); + } +} diff --git a/src/server.rs b/src/server.rs index 8075484..1f1c22a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,17 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, trace}; +use futures::channel::mpsc::UnboundedReceiver; +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; + #[cfg(feature = "telemetry")] use opentelemetry::{ trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer}, @@ -14,22 +23,11 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::channel::mpsc::UnboundedReceiver; -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -126,8 +124,8 @@ impl ServerConn { async fn recv_handler_aux( self: &Arc, bytes: &[u8], - stream: AssociatedStream, - ) -> Result<(Vec, Option), Error> { + stream: ByteStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; diff --git a/src/util.rs b/src/util.rs index 186678d..6fbafe6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,17 +1,15 @@ -use crate::endpoint::SerializeMessage; - use std::net::SocketAddr; use std::net::ToSocketAddrs; use std::pin::Pin; -use futures::Stream; - use log::info; - use serde::Serialize; +use futures::Stream; use tokio::sync::watch; +use crate::message::SerializeMessage; + /// A node's identifier, which is also its public cryptographic key pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; /// A node's secret key @@ -27,7 +25,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type AssociatedStream = Pin + Send>>; +pub type ByteStream = Pin + Send>>; pub type Packet = Result, u8>; @@ -38,7 +36,7 @@ pub type Packet = Result, u8>; /// This is used internally by the netapp communication protocol. pub fn rmp_to_vec_all_named( val: &T, -) -> Result<(Vec, Option), rmp_serde::encode::Error> +) -> Result<(Vec, Option), rmp_serde::encode::Error> where T: SerializeMessage + ?Sized, { -- cgit v1.2.3