aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs364
1 files changed, 167 insertions, 197 deletions
diff --git a/src/proto.rs b/src/proto.rs
index e3f9be8..d6dc35a 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -3,11 +3,11 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use log::{trace, warn};
+use log::trace;
-use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
-use futures::Stream;
+use futures::channel::mpsc::{unbounded, UnboundedSender};
use futures::{AsyncReadExt, AsyncWriteExt};
+use futures::{Stream, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite;
use tokio::sync::mpsc;
@@ -63,39 +63,24 @@ struct SendQueueItem {
data: DataReader,
}
-pub(crate) enum Data {
- Full(Vec<u8>),
- Streaming(AssociatedStream),
+#[pin_project::pin_project]
+struct DataReader {
+ #[pin]
+ reader: AssociatedStream,
+ packet: Result<Vec<u8>, u8>,
+ pos: usize,
+ buf: Vec<u8>,
+ eos: bool,
}
-#[pin_project::pin_project(project = DataReaderProj)]
-enum DataReader {
- Full {
- #[pin]
- data: Vec<u8>,
- pos: usize,
- },
- Streaming {
- #[pin]
- reader: AssociatedStream,
- packet: Result<Vec<u8>, u8>,
- pos: usize,
- buf: Vec<u8>,
- eos: bool,
- },
-}
-
-impl From<Data> for DataReader {
- fn from(data: Data) -> DataReader {
- match data {
- Data::Full(data) => DataReader::Full { data, pos: 0 },
- Data::Streaming(reader) => DataReader::Streaming {
- reader,
- packet: Ok(Vec::new()),
- pos: 0,
- buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
- eos: false,
- },
+impl From<AssociatedStream> for DataReader {
+ fn from(data: AssociatedStream) -> DataReader {
+ DataReader {
+ reader: data,
+ packet: Ok(Vec::new()),
+ pos: 0,
+ buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
+ eos: false,
}
}
}
@@ -155,82 +140,60 @@ impl Stream for DataReader {
type Item = DataReaderItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match self.project() {
- DataReaderProj::Full { data, pos } => {
- let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos);
- let end = *pos + len;
-
- if len == 0 {
- Poll::Ready(None)
- } else {
- let mut body = [0; MAX_CHUNK_LENGTH as usize];
- body[..len].copy_from_slice(&data[*pos..end]);
- *pos = end;
- Poll::Ready(Some(DataReaderItem {
- data: DataFrame::Data { data: body, len },
- may_have_more: end < data.len(),
- }))
+ let mut this = self.project();
+
+ if *this.eos {
+ // eos was reached at previous call to poll_next, where a partial packet
+ // was returned. Now return None
+ return Poll::Ready(None);
+ }
+
+ loop {
+ let packet = match this.packet {
+ Ok(v) => v,
+ Err(e) => {
+ let e = *e;
+ *this.packet = Ok(Vec::new());
+ return Poll::Ready(Some(DataReaderItem {
+ data: DataFrame::Error(e),
+ may_have_more: true,
+ }));
}
+ };
+ let packet_left = packet.len() - *this.pos;
+ let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
+ let to_read = std::cmp::min(buf_left, packet_left);
+ this.buf
+ .extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
+ *this.pos += to_read;
+ if this.buf.len() == MAX_CHUNK_LENGTH as usize {
+ // we have a full buf, ready to send
+ break;
}
- DataReaderProj::Streaming {
- mut reader,
- packet: res_packet,
- pos,
- buf,
- eos,
- } => {
- if *eos {
- // eos was reached at previous call to poll_next, where a partial packet
- // was returned. Now return None
- return Poll::Ready(None);
- }
- loop {
- let packet = match res_packet {
- Ok(v) => v,
- Err(e) => {
- let e = *e;
- *res_packet = Ok(Vec::new());
- return Poll::Ready(Some(DataReaderItem {
- data: DataFrame::Error(e),
- may_have_more: true,
- }));
- }
- };
- let packet_left = packet.len() - *pos;
- let buf_left = MAX_CHUNK_LENGTH as usize - buf.len();
- let to_read = std::cmp::min(buf_left, packet_left);
- buf.extend_from_slice(&packet[*pos..*pos + to_read]);
- *pos += to_read;
- if buf.len() == MAX_CHUNK_LENGTH as usize {
- // we have a full buf, ready to send
- break;
- }
- // we don't have a full buf, packet is empty; try receive more
- if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) {
- *res_packet = p;
- *pos = 0;
- // if buf is empty, we will loop and return the error directly. If buf
- // isn't empty, send it before by breaking.
- if res_packet.is_err() && !buf.is_empty() {
- break;
- }
- } else {
- *eos = true;
- break;
- }
+ // we don't have a full buf, packet is empty; try receive more
+ if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
+ *this.packet = p;
+ *this.pos = 0;
+ // if buf is empty, we will loop and return the error directly. If buf
+ // isn't empty, send it before by breaking.
+ if this.packet.is_err() && !this.buf.is_empty() {
+ break;
}
-
- let mut body = [0; MAX_CHUNK_LENGTH as usize];
- let len = buf.len();
- body[..len].copy_from_slice(buf);
- buf.clear();
- Poll::Ready(Some(DataReaderItem {
- data: DataFrame::Data { data: body, len },
- may_have_more: !*eos,
- }))
+ } else {
+ *this.eos = true;
+ break;
}
}
+
+ let mut body = [0; MAX_CHUNK_LENGTH as usize];
+ let len = this.buf.len();
+ body[..len].copy_from_slice(this.buf);
+ this.buf.clear();
+ Poll::Ready(Some(DataReaderItem {
+ data: DataFrame::Data { data: body, len },
+ may_have_more: !*this.eos,
+ }))
}
}
@@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> {
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>,
+ mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
mut write: BoxStreamWrite<W>,
) -> Result<(), Error>
where
@@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync {
}
}
-struct ChannelPair {
- receiver: Option<UnboundedReceiver<Vec<u8>>>,
- sender: Option<UnboundedSender<Vec<u8>>>,
+pub(crate) struct Framing {
+ direct: Vec<u8>,
+ stream: Option<AssociatedStream>,
}
-impl ChannelPair {
- fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
- self.receiver.take()
+impl Framing {
+ pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
+ assert!(direct.len() <= u32::MAX as usize);
+ Framing { direct, stream }
}
- fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
- self.sender.take()
- }
+ pub fn into_stream(self) -> AssociatedStream {
+ use futures::stream;
+ let len = self.direct.len() as u32;
+ // required because otherwise the borrow-checker complains
+ let Framing { direct, stream } = self;
- fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
- self.sender.as_ref().take()
- }
+ let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
+ .chain(stream::once(async move { Ok(direct) }));
- fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
- if self.receiver.is_some() || self.sender.is_some() {
- map.insert(index, self);
+ if let Some(stream) = stream {
+ Box::pin(res.chain(stream))
+ } else {
+ Box::pin(res)
}
}
-}
-impl Default for ChannelPair {
- fn default() -> Self {
- let (send, recv) = unbounded();
- ChannelPair {
- receiver: Some(recv),
- sender: Some(send),
+ pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>(
+ mut stream: S,
+ ) -> Result<Self, Error> {
+ let mut packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
+ if packet.len() < 4 {
+ return Err(Error::Framing);
+ }
+
+ let mut len = [0; 4];
+ len.copy_from_slice(&packet[..4]);
+ let len = u32::from_be_bytes(len);
+ packet.drain(..4);
+
+ let mut buffer = Vec::new();
+ let len = len as usize;
+ loop {
+ let max_cp = std::cmp::min(len - buffer.len(), packet.len());
+
+ buffer.extend_from_slice(&packet[..max_cp]);
+ if buffer.len() == len {
+ packet.drain(..max_cp);
+ break;
+ }
+ packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
}
+
+ let stream: AssociatedStream = if packet.is_empty() {
+ Box::pin(stream)
+ } else {
+ Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
+ };
+
+ Ok(Framing {
+ direct: buffer,
+ stream: Some(stream),
+ })
+ }
+
+ pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
+ let Framing { direct, stream } = self;
+ (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
}
}
@@ -424,14 +431,13 @@ impl Default for ChannelPair {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream);
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
- let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new();
- let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new();
+ let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static {
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
let is_error = (size & ERROR_MARKER) != 0;
- let size = if !is_error {
- size & !CHUNK_HAS_CONTINUATION
+ let packet = if is_error {
+ Err(size as u8)
} else {
- 0
+ let size = size & !CHUNK_HAS_CONTINUATION;
+ let mut next_slice = vec![0; size as usize];
+ read.read_exact(&mut next_slice[..]).await?;
+ trace!("recv_loop: read {} bytes", next_slice.len());
+ Ok(next_slice)
};
- // TODO propagate errors
-
- let mut next_slice = vec![0; size as usize];
- read.read_exact(&mut next_slice[..]).await?;
- trace!("recv_loop: read {} bytes", next_slice.len());
-
- if id & 1 == 0 {
- // main stream
- let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
- msg_bytes.extend_from_slice(&next_slice[..]);
-
- if has_cont {
- receiving.insert(id, msg_bytes);
- } else {
- let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default();
-
- if let Some(receiver) = channel_pair.take_receiver() {
- use futures::StreamExt;
- self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v))));
- } else {
- warn!("Couldn't take receiver part of stream")
- }
- channel_pair.insert_into(&mut streams, id | 1);
- }
+ let sender = if let Some(send) = streams.remove(&(id)) {
+ send
} else {
- // associated stream
- let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
-
- // if we get an error, the receiving end is disconnected. We still need to
- // reach eos before dropping this sender
- if !next_slice.is_empty() {
- if let Some(sender) = channel_pair.ref_sender() {
- let _ = sender.unbounded_send(next_slice);
- } else {
- warn!("Couldn't take sending part of stream")
- }
- }
+ let (send, recv) = unbounded();
+ self.recv_handler(id, Box::pin(recv));
+ send
+ };
- if !has_cont {
- channel_pair.take_sender();
- }
+ // if we get an error, the receiving end is disconnected. We still need to
+ // reach eos before dropping this sender
+ let _ = sender.unbounded_send(packet);
- channel_pair.insert_into(&mut streams, id);
+ if has_cont {
+ streams.insert(id, sender);
}
}
Ok(())
@@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static {
mod test {
use super::*;
+ fn empty_data() -> DataReader {
+ type Item = Result<Vec<u8>, u8>;
+ let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
+ Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>());
+ stream.into()
+ }
+
#[test]
fn test_priority_queue() {
let i1 = SendQueueItem {
id: 1,
prio: PRIO_NORMAL,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let i2 = SendQueueItem {
id: 2,
prio: PRIO_HIGH,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let i2bis = SendQueueItem {
id: 20,
prio: PRIO_HIGH,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let i3 = SendQueueItem {
id: 3,
prio: PRIO_HIGH | PRIO_SECONDARY,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let i4 = SendQueueItem {
id: 4,
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let i5 = SendQueueItem {
id: 5,
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
- data: DataReader::Full {
- data: vec![],
- pos: 0,
- },
+ data: empty_data(),
};
let mut q = SendQueue::new();