aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs159
1 files changed, 94 insertions, 65 deletions
diff --git a/src/proto.rs b/src/proto.rs
index b45ff13..ca1a3d2 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -53,7 +53,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
pub(crate) type RequestID = u32;
type ChunkLength = u16;
-pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
+const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem {
@@ -77,6 +77,10 @@ enum DataReader {
Streaming {
#[pin]
reader: AssociatedStream,
+ packet: Vec<u8>,
+ pos: usize,
+ buf: Vec<u8>,
+ eos: bool,
},
}
@@ -84,7 +88,13 @@ impl From<Data> for DataReader {
fn from(data: Data) -> DataReader {
match data {
Data::Full(data) => DataReader::Full { data, pos: 0 },
- Data::Streaming(reader) => DataReader::Streaming { reader },
+ Data::Streaming(reader) => DataReader::Streaming {
+ reader,
+ packet: Vec::new(),
+ pos: 0,
+ buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
+ eos: false,
+ },
}
}
}
@@ -107,16 +117,43 @@ impl Stream for DataReader {
Poll::Ready(Some((body, len)))
}
}
- DataReaderProj::Streaming { reader } => {
- reader.poll_next(cx).map(|opt| {
- opt.map(|v| {
- let mut body = [0; MAX_CHUNK_LENGTH as usize];
- let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len());
- // TODO this can throw away long vec, they should be splited instead
- body[..len].copy_from_slice(&v[..len]);
- (body, len)
- })
- })
+ DataReaderProj::Streaming {
+ mut reader,
+ packet,
+ pos,
+ buf,
+ eos,
+ } => {
+ if *eos {
+ // eos was reached at previous call to poll_next, where a partial packet
+ // was returned. Now return None
+ return Poll::Ready(None);
+ }
+ loop {
+ let packet_left = packet.len() - *pos;
+ let buf_left = MAX_CHUNK_LENGTH as usize - buf.len();
+ let to_read = std::cmp::min(buf_left, packet_left);
+ buf.extend_from_slice(&packet[*pos..*pos + to_read]);
+ *pos += to_read;
+ if buf.len() == MAX_CHUNK_LENGTH as usize {
+ // we have a full buf, ready to send
+ break;
+ }
+
+ // we don't have a full buf, packet is empty; try receive more
+ if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) {
+ *packet = p;
+ *pos = 0;
+ } else {
+ *eos = true;
+ break;
+ }
+ }
+
+ let mut body = [0; MAX_CHUNK_LENGTH as usize];
+ body[..buf.len()].copy_from_slice(&buf);
+ buf.clear();
+ Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize)))
}
}
}
@@ -196,10 +233,7 @@ pub(crate) trait SendLoop: Sync {
data: data.into(),
});
} else if let Some(mut item) = sending.pop() {
- trace!(
- "send_loop: sending bytes for {}",
- item.id,
- );
+ trace!("send_loop: sending bytes for {}", item.id,);
let data = futures::select! {
data = item.data.next().fuse() => data,
@@ -210,7 +244,6 @@ pub(crate) trait SendLoop: Sync {
// TODO if every SendQueueItem is waiting on data, use select_all to await
// something to do
- // TODO find some way to not require sending empty last chunk
}
};
@@ -222,7 +255,7 @@ pub(crate) trait SendLoop: Sync {
None => &[],
};
- if !data.is_empty() {
+ if data.len() == MAX_CHUNK_LENGTH as usize {
let size_header =
ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION);
write.write_all(&size_header[..]).await?;
@@ -231,7 +264,6 @@ pub(crate) trait SendLoop: Sync {
sending.push(item);
} else {
- // this is always zero for now, but may be more when above TODO get fixed
let size_header = ChunkLength::to_be_bytes(data.len() as u16);
write.write_all(&size_header[..]).await?;
@@ -267,38 +299,38 @@ pub(crate) trait SendLoop: Sync {
}
struct ChannelPair {
- receiver: Option<UnboundedReceiver<Vec<u8>>>,
- sender: Option<UnboundedSender<Vec<u8>>>,
+ receiver: Option<UnboundedReceiver<Vec<u8>>>,
+ sender: Option<UnboundedSender<Vec<u8>>>,
}
impl ChannelPair {
- fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
- self.receiver.take()
- }
-
- fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
- self.sender.take()
- }
-
- fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
- self.sender.as_ref().take()
- }
-
- fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
- if self.receiver.is_some() || self.sender.is_some() {
- map.insert(index, self);
- }
- }
+ fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
+ self.receiver.take()
+ }
+
+ fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
+ self.sender.take()
+ }
+
+ fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
+ self.sender.as_ref().take()
+ }
+
+ fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
+ if self.receiver.is_some() || self.sender.is_some() {
+ map.insert(index, self);
+ }
+ }
}
impl Default for ChannelPair {
- fn default() -> Self {
- let (send, recv) = unbounded();
- ChannelPair {
- receiver: Some(recv),
- sender: Some(send),
- }
- }
+ fn default() -> Self {
+ let (send, recv) = unbounded();
+ ChannelPair {
+ receiver: Some(recv),
+ sender: Some(send),
+ }
+ }
}
/// The RecvLoop trait, which is implemented both by the client and the server
@@ -317,10 +349,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
R: AsyncReadExt + Unpin + Send + Sync,
{
let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new();
- let mut streams: HashMap<
- RequestID,
- ChannelPair,
- > = HashMap::new();
+ let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -345,7 +374,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
trace!("recv_loop: read {} bytes", next_slice.len());
if id & 1 == 0 {
- // main stream
+ // main stream
let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
msg_bytes.extend_from_slice(&next_slice[..]);
@@ -357,30 +386,30 @@ pub(crate) trait RecvLoop: Sync + 'static {
if let Some(receiver) = channel_pair.take_receiver() {
self.recv_handler(id, msg_bytes, Box::pin(receiver));
} else {
- warn!("Couldn't take receiver part of stream")
- }
+ warn!("Couldn't take receiver part of stream")
+ }
- channel_pair.insert_into(&mut streams, id | 1);
+ channel_pair.insert_into(&mut streams, id | 1);
}
} else {
- // associated stream
- let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
+ // associated stream
+ let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
// if we get an error, the receiving end is disconnected. We still need to
// reach eos before dropping this sender
- if !next_slice.is_empty() {
- if let Some(sender) = channel_pair.ref_sender() {
- let _ = sender.unbounded_send(next_slice);
- } else {
- warn!("Couldn't take sending part of stream")
- }
- }
+ if !next_slice.is_empty() {
+ if let Some(sender) = channel_pair.ref_sender() {
+ let _ = sender.unbounded_send(next_slice);
+ } else {
+ warn!("Couldn't take sending part of stream")
+ }
+ }
if !has_cont {
- channel_pair.take_sender();
- }
+ channel_pair.take_sender();
+ }
- channel_pair.insert_into(&mut streams, id);
+ channel_pair.insert_into(&mut streams, id);
}
}
Ok(())