From 368ba908794901bc793c6a087c02241be046bdf2 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 15:33:43 +0200 Subject: initial work on associated stream still require testing, and fixing a few kinks: - sending packets > 16k truncate them - send one more packet than it could at eos - probably update documentation /!\ contains breaking changes --- src/proto.rs | 260 +++++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 216 insertions(+), 44 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index e843bff..b45ff13 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use log::trace; +use log::{trace, warn}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::Stream; +use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -11,6 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; +use crate::util::AssociatedStream; /// Priority of a request (click to read more about priorities). /// @@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: Vec, - cursor: usize, + data: DataReader, +} + +pub(crate) enum Data { + Full(Vec), + Streaming(AssociatedStream), +} + +#[pin_project::pin_project(project = DataReaderProj)] +enum DataReader { + Full { + #[pin] + data: Vec, + pos: usize, + }, + Streaming { + #[pin] + reader: AssociatedStream, + }, +} + +impl From for DataReader { + fn from(data: Data) -> DataReader { + match data { + Data::Full(data) => DataReader::Full { data, pos: 0 }, + Data::Streaming(reader) => DataReader::Streaming { reader }, + } + } +} + +impl Stream for DataReader { + type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + DataReaderProj::Full { data, pos } => { + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); + let end = *pos + len; + + if len == 0 { + Poll::Ready(None) + } else { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..len].copy_from_slice(&data[*pos..end]); + *pos = end; + Poll::Ready(Some((body, len))) + } + } + DataReaderProj::Streaming { reader } => { + reader.poll_next(cx).map(|opt| { + opt.map(|v| { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); + // TODO this can throw away long vec, they should be splited instead + body[..len].copy_from_slice(&v[..len]); + (body, len) + }) + }) + } + } + } } struct SendQueue { @@ -108,7 +172,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync { let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else if let Some(mut item) = sending.pop() { trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor + "send_loop: sending bytes for {}", + item.id, ); + + let data = futures::select! { + data = item.data.next().fuse() => data, + default => { + // nothing to send yet; re-schedule and find something else to do + sending.push(item); + continue; + + // TODO if every SendQueueItem is waiting on data, use select_all to await + // something to do + // TODO find some way to not require sending empty last chunk + } + }; + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let data = match data.as_ref() { + Some((data, len)) => &data[..*len], + None => &[], + }; + + if !data.is_empty() { let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); + ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; + write.write_all(data).await?; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); + // this is always zero for now, but may be more when above TODO get fixed + let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - write.write_all(&item.data[item.cursor..]).await?; + write.write_all(data).await?; } + write.flush().await?; } else { let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else { should_exit = true; @@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync { } } +struct ChannelPair { + receiver: Option>>, + sender: Option>>, +} + +impl ChannelPair { + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } +} + +impl Default for ChannelPair { + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec); + fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut receiving: HashMap> = HashMap::new(); + let mut streams: HashMap< + RequestID, + ChannelPair, + > = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + if id & 1 == 0 { + // main stream + let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); + msg_bytes.extend_from_slice(&next_slice[..]); - if has_cont { - receiving.insert(id, msg_bytes); + if has_cont { + receiving.insert(id, msg_bytes); + } else { + let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); + + if let Some(receiver) = channel_pair.take_receiver() { + self.recv_handler(id, msg_bytes, Box::pin(receiver)); + } else { + warn!("Couldn't take receiver part of stream") + } + + channel_pair.insert_into(&mut streams, id | 1); + } } else { - self.recv_handler(id, msg_bytes); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } + + if !has_cont { + channel_pair.take_sender(); + } + + channel_pair.insert_into(&mut streams, id); } } Ok(()) @@ -236,38 +396,50 @@ mod test { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let mut q = SendQueue::new(); -- cgit v1.2.3 From fb5462ecdb6b5731a63a902519d3ec9b1061b8dd Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 16:47:29 +0200 Subject: rechunk stream --- src/proto.rs | 159 +++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 94 insertions(+), 65 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index b45ff13..ca1a3d2 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -53,7 +53,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -77,6 +77,10 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, + packet: Vec, + pos: usize, + buf: Vec, + eos: bool, }, } @@ -84,7 +88,13 @@ impl From for DataReader { fn from(data: Data) -> DataReader { match data { Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { reader }, + Data::Streaming(reader) => DataReader::Streaming { + reader, + packet: Vec::new(), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + }, } } } @@ -107,16 +117,43 @@ impl Stream for DataReader { Poll::Ready(Some((body, len))) } } - DataReaderProj::Streaming { reader } => { - reader.poll_next(cx).map(|opt| { - opt.map(|v| { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); - // TODO this can throw away long vec, they should be splited instead - body[..len].copy_from_slice(&v[..len]); - (body, len) - }) - }) + DataReaderProj::Streaming { + mut reader, + packet, + pos, + buf, + eos, + } => { + if *eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + loop { + let packet_left = packet.len() - *pos; + let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + buf.extend_from_slice(&packet[*pos..*pos + to_read]); + *pos += to_read; + if buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { + *packet = p; + *pos = 0; + } else { + *eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..buf.len()].copy_from_slice(&buf); + buf.clear(); + Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) } } } @@ -196,10 +233,7 @@ pub(crate) trait SendLoop: Sync { data: data.into(), }); } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop: sending bytes for {}", - item.id, - ); + trace!("send_loop: sending bytes for {}", item.id,); let data = futures::select! { data = item.data.next().fuse() => data, @@ -210,7 +244,6 @@ pub(crate) trait SendLoop: Sync { // TODO if every SendQueueItem is waiting on data, use select_all to await // something to do - // TODO find some way to not require sending empty last chunk } }; @@ -222,7 +255,7 @@ pub(crate) trait SendLoop: Sync { None => &[], }; - if !data.is_empty() { + if data.len() == MAX_CHUNK_LENGTH as usize { let size_header = ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; @@ -231,7 +264,6 @@ pub(crate) trait SendLoop: Sync { sending.push(item); } else { - // this is always zero for now, but may be more when above TODO get fixed let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; @@ -267,38 +299,38 @@ pub(crate) trait SendLoop: Sync { } struct ChannelPair { - receiver: Option>>, - sender: Option>>, + receiver: Option>>, + sender: Option>>, } impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() - } - - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } - - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } - - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); - } - } + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } } impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), - } - } + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } } /// The RecvLoop trait, which is implemented both by the client and the server @@ -317,10 +349,7 @@ pub(crate) trait RecvLoop: Sync + 'static { R: AsyncReadExt + Unpin + Send + Sync, { let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap< - RequestID, - ChannelPair, - > = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -345,7 +374,7 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: read {} bytes", next_slice.len()); if id & 1 == 0 { - // main stream + // main stream let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); msg_bytes.extend_from_slice(&next_slice[..]); @@ -357,30 +386,30 @@ pub(crate) trait RecvLoop: Sync + 'static { if let Some(receiver) = channel_pair.take_receiver() { self.recv_handler(id, msg_bytes, Box::pin(receiver)); } else { - warn!("Couldn't take receiver part of stream") - } + warn!("Couldn't take receiver part of stream") + } - channel_pair.insert_into(&mut streams, id | 1); + channel_pair.insert_into(&mut streams, id | 1); } } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } if !has_cont { - channel_pair.take_sender(); - } + channel_pair.take_sender(); + } - channel_pair.insert_into(&mut streams, id); + channel_pair.insert_into(&mut streams, id); } } Ok(()) -- cgit v1.2.3 From 4745e7c4ba5665d3303ae567087781778cec9c34 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Wed, 8 Jun 2022 00:30:56 +0200 Subject: further work on streams most changes still required are related to error handling --- src/proto.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index ca1a3d2..073a317 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -151,9 +151,10 @@ impl Stream for DataReader { } let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..buf.len()].copy_from_slice(&buf); + let len = buf.len(); + body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) + Poll::Ready(Some((body, len))) } } } -- cgit v1.2.3 From 5d7541e13a4c3640f0dc8aead595b51775fc0ac8 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 17:44:07 +0200 Subject: wait for any ready stream instead of the highest priority one --- src/proto.rs | 185 +++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 115 insertions(+), 70 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index 073a317..417b508 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,7 +7,7 @@ use log::{trace, warn}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::Stream; -use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; +use futures::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -53,7 +53,8 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +const ERROR_MARKER: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -99,8 +100,29 @@ impl From for DataReader { } } +struct DataReaderItem { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actuall lenght of data + len: usize, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + may_have_more: false, + } + } +} + impl Stream for DataReader { - type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -114,7 +136,11 @@ impl Stream for DataReader { let mut body = [0; MAX_CHUNK_LENGTH as usize]; body[..len].copy_from_slice(&data[*pos..end]); *pos = end; - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: end < data.len(), + })) } } DataReaderProj::Streaming { @@ -154,7 +180,11 @@ impl Stream for DataReader { let len = buf.len(); body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: !*eos, + })) } } } @@ -181,6 +211,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -196,6 +228,54 @@ impl SendQueue { fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } } /// The SendLoop trait, which is implemented both by the client and the server @@ -219,77 +299,42 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else if let Some(mut item) = sending.pop() { - trace!("send_loop: sending bytes for {}", item.id,); - - let data = futures::select! { - data = item.data.next().fuse() => data, - default => { - // nothing to send yet; re-schedule and find something else to do - sending.push(item); - continue; - - // TODO if every SendQueueItem is waiting on data, use select_all to await - // something to do - } - }; - - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); - let data = match data.as_ref() { - Some((data, len)) => &data[..*len], - None => &[], - }; + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; - if data.len() == MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); - write.write_all(&size_header[..]).await?; + let body = &data.data[..data.len]; - write.write_all(data).await?; + let size_header = if data.may_have_more { + ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) + } else { + ChunkLength::to_be_bytes(data.len as u16) + }; - sending.push(item); - } else { - let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - } - - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } - } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; + write.write_all(body).await?; + write.flush().await?; } } } -- cgit v1.2.3 From 0fec85b47a1bc679d2684994bfae6ef0fe7d4911 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 18:42:27 +0200 Subject: start supporting sending error on stream --- src/proto.rs | 99 +++++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 28 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index 417b508..e3f9be8 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -78,7 +78,7 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, - packet: Vec, + packet: Result, u8>, pos: usize, buf: Vec, eos: bool, @@ -91,7 +91,7 @@ impl From for DataReader { Data::Full(data) => DataReader::Full { data, pos: 0 }, Data::Streaming(reader) => DataReader::Streaming { reader, - packet: Vec::new(), + packet: Ok(Vec::new()), pos: 0, buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), eos: false, @@ -100,11 +100,18 @@ impl From for DataReader { } } +enum DataFrame { + Data { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actual lenght of data + len: usize, + }, + Error(u8), +} + struct DataReaderItem { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actuall lenght of data - len: usize, + data: DataFrame, /// whethere there may be more data comming from this stream. Can be used for some /// optimization. It's an error to set it to false if there is more data, but it is correct /// (albeit sub-optimal) to set it to true if there is nothing coming after @@ -114,11 +121,34 @@ struct DataReaderItem { impl DataReaderItem { fn empty_last() -> Self { DataReaderItem { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, + data: DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + }, may_have_more: false, } } + + fn header(&self) -> [u8; 2] { + let continuation = if self.may_have_more { + CHUNK_HAS_CONTINUATION + } else { + 0 + }; + let len = match self.data { + DataFrame::Data { len, .. } => len as u16, + DataFrame::Error(e) => e as u16 | ERROR_MARKER, + }; + + ChunkLength::to_be_bytes(len | continuation) + } + + fn data(&self) -> &[u8] { + match self.data { + DataFrame::Data { ref data, len } => &data[..len], + DataFrame::Error(_) => &[], + } + } } impl Stream for DataReader { @@ -137,15 +167,14 @@ impl Stream for DataReader { body[..len].copy_from_slice(&data[*pos..end]); *pos = end; Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: end < data.len(), })) } } DataReaderProj::Streaming { mut reader, - packet, + packet: res_packet, pos, buf, eos, @@ -156,6 +185,17 @@ impl Stream for DataReader { return Poll::Ready(None); } loop { + let packet = match res_packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *res_packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); + } + }; let packet_left = packet.len() - *pos; let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); let to_read = std::cmp::min(buf_left, packet_left); @@ -168,8 +208,13 @@ impl Stream for DataReader { // we don't have a full buf, packet is empty; try receive more if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *packet = p; + *res_packet = p; *pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if res_packet.is_err() && !buf.is_empty() { + break; + } } else { *eos = true; break; @@ -181,8 +226,7 @@ impl Stream for DataReader { body[..len].copy_from_slice(buf); buf.clear(); Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: !*eos, })) } @@ -211,8 +255,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -324,16 +368,8 @@ pub(crate) trait SendLoop: Sync { let header_id = RequestID::to_be_bytes(id); write.write_all(&header_id[..]).await?; - let body = &data.data[..data.len]; - - let size_header = if data.may_have_more { - ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) - } else { - ChunkLength::to_be_bytes(data.len as u16) - }; - - write.write_all(&size_header[..]).await?; - write.write_all(body).await?; + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; write.flush().await?; } } @@ -413,7 +449,13 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; + let is_error = (size & ERROR_MARKER) != 0; + let size = if !is_error { + size & !CHUNK_HAS_CONTINUATION + } else { + 0 + }; + // TODO propagate errors let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; @@ -430,7 +472,8 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); if let Some(receiver) = channel_pair.take_receiver() { - self.recv_handler(id, msg_bytes, Box::pin(receiver)); + use futures::StreamExt; + self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); } else { warn!("Couldn't take receiver part of stream") } -- cgit v1.2.3 From d3d18b8e8bde5fee81022fd050d5f4c114262fcf Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 20 Jun 2022 23:40:31 +0200 Subject: use a framing protocol instead of even/odd channel --- src/proto.rs | 364 +++++++++++++++++++++++++++-------------------------------- 1 file changed, 167 insertions(+), 197 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index e3f9be8..d6dc35a 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,11 +3,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::{trace, warn}; +use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::Stream; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -63,39 +63,24 @@ struct SendQueueItem { data: DataReader, } -pub(crate) enum Data { - Full(Vec), - Streaming(AssociatedStream), +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: AssociatedStream, + packet: Result, u8>, + pos: usize, + buf: Vec, + eos: bool, } -#[pin_project::pin_project(project = DataReaderProj)] -enum DataReader { - Full { - #[pin] - data: Vec, - pos: usize, - }, - Streaming { - #[pin] - reader: AssociatedStream, - packet: Result, u8>, - pos: usize, - buf: Vec, - eos: bool, - }, -} - -impl From for DataReader { - fn from(data: Data) -> DataReader { - match data { - Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { - reader, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - }, +impl From for DataReader { + fn from(data: AssociatedStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, } } } @@ -155,82 +140,60 @@ impl Stream for DataReader { type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - DataReaderProj::Full { data, pos } => { - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); - let end = *pos + len; - - if len == 0 { - Poll::Ready(None) - } else { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..len].copy_from_slice(&data[*pos..end]); - *pos = end; - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: end < data.len(), - })) + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; } - DataReaderProj::Streaming { - mut reader, - packet: res_packet, - pos, - buf, - eos, - } => { - if *eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - loop { - let packet = match res_packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *res_packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *pos; - let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - buf.extend_from_slice(&packet[*pos..*pos + to_read]); - *pos += to_read; - if buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *res_packet = p; - *pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if res_packet.is_err() && !buf.is_empty() { - break; - } - } else { - *eos = true; - break; - } + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = buf.len(); - body[..len].copy_from_slice(buf); - buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*eos, - })) + } else { + *this.eos = true; + break; } } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) } } @@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync { } } -struct ChannelPair { - receiver: Option>>, - sender: Option>>, +pub(crate) struct Framing { + direct: Vec, + stream: Option, } -impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } } - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) } } -} -impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), + pub async fn from_stream, u8>> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } } @@ -424,14 +431,13 @@ impl Default for ChannelPair { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap = HashMap::new(); + let mut streams: HashMap, u8>>> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; - let size = if !is_error { - size & !CHUNK_HAS_CONTINUATION + let packet = if is_error { + Err(size as u8) } else { - 0 + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) }; - // TODO propagate errors - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - - if id & 1 == 0 { - // main stream - let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); - - if let Some(receiver) = channel_pair.take_receiver() { - use futures::StreamExt; - self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); - } else { - warn!("Couldn't take receiver part of stream") - } - channel_pair.insert_into(&mut streams, id | 1); - } + let sender = if let Some(send) = streams.remove(&(id)) { + send } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + let (send, recv) = unbounded(); + self.recv_handler(id, Box::pin(recv)); + send + }; - if !has_cont { - channel_pair.take_sender(); - } + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + let _ = sender.unbounded_send(packet); - channel_pair.insert_into(&mut streams, id); + if has_cont { + streams.insert(id, sender); } } Ok(()) @@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Result, u8>; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::, u8>>()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let mut q = SendQueue::new(); -- cgit v1.2.3 From cdff8ae1beab44a22d0eb0eb00c624e49971b6ca Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 18 Jul 2022 15:21:13 +0200 Subject: add detection of premature eos --- src/proto.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 12 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result, u8>, + packet: Packet, pos: usize, buf: Vec, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream, u8>> + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + 'static>( mut stream: S, ) -> Result { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap, u8>>> = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result, u8>; + type Item = Packet; let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::, u8>>()); + Box::pin(futures::stream::empty::()); stream.into() } -- cgit v1.2.3 From f35fa7d18d9e0f51bed311355ec1310b1d311ab3 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 17:34:53 +0200 Subject: Move things around --- src/proto.rs | 617 ----------------------------------------------------------- 1 file changed, 617 deletions(-) delete mode 100644 src/proto.rs (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 92d8d80..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use log::trace; - -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures::{Stream, StreamExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; -use crate::util::{AssociatedStream, Packet}; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -const ERROR_MARKER: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: DataReader, -} - -#[pin_project::pin_project] -struct DataReader { - #[pin] - reader: AssociatedStream, - packet: Packet, - pos: usize, - buf: Vec, - eos: bool, -} - -impl From for DataReader { - fn from(data: AssociatedStream) -> DataReader { - DataReader { - reader: data, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - } - } -} - -enum DataFrame { - Data { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actual lenght of data - len: usize, - }, - Error(u8), -} - -struct DataReaderItem { - data: DataFrame, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, -} - -impl DataReaderItem { - fn empty_last() -> Self { - DataReaderItem { - data: DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - }, - may_have_more: false, - } - } - - fn header(&self) -> [u8; 2] { - let continuation = if self.may_have_more { - CHUNK_HAS_CONTINUATION - } else { - 0 - }; - let len = match self.data { - DataFrame::Data { len, .. } => len as u16, - DataFrame::Error(e) => e as u16 | ERROR_MARKER, - }; - - ChunkLength::to_be_bytes(len | continuation) - } - - fn data(&self) -> &[u8] { - match self.data { - DataFrame::Data { ref data, len } => &data[..len], - DataFrame::Error(_) => &[], - } - } -} - -impl Stream for DataReader { - type Item = DataReaderItem; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if *this.eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - - loop { - let packet = match this.packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *this.packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *this.pos; - let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - this.buf - .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); - *this.pos += to_read; - if this.buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { - *this.packet = p; - *this.pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if this.packet.is_err() && !this.buf.is_empty() { - break; - } - } else { - *this.eos = true; - break; - } - } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = this.buf.len(); - body[..len].copy_from_slice(this.buf); - this.buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*this.eos, - })) - } -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] - fn pop(&mut self) -> Option { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - - // this is like an async fn, but hand implemented - fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { - SendQueuePollNextReady { queue: self } - } -} - -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataReaderItem); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - for i in 0..self.queue.items.len() { - let (_prio, items_at_prio) = &mut self.queue.items[i]; - - for _ in 0..items_at_prio.len() { - let mut item = items_at_prio.pop_front().unwrap(); - - match Pin::new(&mut item.data).poll_next(ctx) { - Poll::Pending => items_at_prio.push_back(item), - Poll::Ready(Some(data)) => { - let id = item.id; - if data.may_have_more { - self.queue.push(item); - } else { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - } - return Poll::Ready((id, data)); - } - Poll::Ready(None) => { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - return Poll::Ready((item.id, DataReaderItem::empty_last())); - } - } - } - } - // TODO what do we do if self.queue is empty? We won't get scheduled again. - Poll::Pending - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop( - self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, - mut write: BoxStreamWrite, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - let recv_fut = msg_recv.recv(); - futures::pin_mut!(recv_fut); - let send_fut = sending.next_ready(); - - // recv_fut is cancellation-safe according to tokio doc, - // send_fut is cancellation-safe as implemented above? - use futures::future::Either; - match futures::future::select(recv_fut, send_fut).await { - Either::Left((sth, _send_fut)) => { - if let Some((id, prio, data)) = sth { - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; - }; - } - Either::Right(((id, data), _recv_fut)) => { - trace!("send_loop: sending bytes for {}", id); - - let header_id = RequestID::to_be_bytes(id); - write.write_all(&header_id[..]).await?; - - write.write_all(&data.header()).await?; - write.write_all(data.data()).await?; - write.flush().await?; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -pub(crate) struct Framing { - direct: Vec, - stream: Option, -} - -impl Framing { - pub fn new(direct: Vec, stream: Option) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } - } - - pub fn into_stream(self) -> AssociatedStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; - - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); - - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } - - pub async fn from_stream + Unpin + Send + 'static>( - mut stream: S, - ) -> Result { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); - } - - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet.drain(..4); - - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); - - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet.drain(..max_cp); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } - - let stream: AssociatedStream = if packet.is_empty() { - Box::pin(stream) - } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; - - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec, AssociatedStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) - } -} - -/// Structure to warn when the sender is dropped before end of stream was reached, like when -/// connection to some remote drops while transmitting data -struct Sender { - inner: UnboundedSender, - closed: bool, -} - -impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } - } - - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); - } - - fn end(&mut self) { - self.closed = true; - } -} - -impl Drop for Sender { - fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); - } - self.inner.close_channel(); - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); - - async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut streams: HashMap = HashMap::new(); - loop { - trace!("recv_loop: reading packet"); - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; - let packet = if is_error { - Err(size as u8) - } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) - }; - - let mut sender = if let Some(send) = streams.remove(&(id)) { - send - } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); - Sender::new(send) - }; - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); - - if has_cont { - streams.insert(id, sender); - } else { - sender.end(); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn empty_data() -> DataReader { - type Item = Packet; - let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::()); - stream.into() - } - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: empty_data(), - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: empty_data(), - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: empty_data(), - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: empty_data(), - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} -- cgit v1.2.3