diff options
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 159 |
1 files changed, 94 insertions, 65 deletions
diff --git a/src/proto.rs b/src/proto.rs index b45ff13..ca1a3d2 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -53,7 +53,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -77,6 +77,10 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, + packet: Vec<u8>, + pos: usize, + buf: Vec<u8>, + eos: bool, }, } @@ -84,7 +88,13 @@ impl From<Data> for DataReader { fn from(data: Data) -> DataReader { match data { Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { reader }, + Data::Streaming(reader) => DataReader::Streaming { + reader, + packet: Vec::new(), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + }, } } } @@ -107,16 +117,43 @@ impl Stream for DataReader { Poll::Ready(Some((body, len))) } } - DataReaderProj::Streaming { reader } => { - reader.poll_next(cx).map(|opt| { - opt.map(|v| { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); - // TODO this can throw away long vec, they should be splited instead - body[..len].copy_from_slice(&v[..len]); - (body, len) - }) - }) + DataReaderProj::Streaming { + mut reader, + packet, + pos, + buf, + eos, + } => { + if *eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + loop { + let packet_left = packet.len() - *pos; + let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + buf.extend_from_slice(&packet[*pos..*pos + to_read]); + *pos += to_read; + if buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { + *packet = p; + *pos = 0; + } else { + *eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..buf.len()].copy_from_slice(&buf); + buf.clear(); + Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) } } } @@ -196,10 +233,7 @@ pub(crate) trait SendLoop: Sync { data: data.into(), }); } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop: sending bytes for {}", - item.id, - ); + trace!("send_loop: sending bytes for {}", item.id,); let data = futures::select! { data = item.data.next().fuse() => data, @@ -210,7 +244,6 @@ pub(crate) trait SendLoop: Sync { // TODO if every SendQueueItem is waiting on data, use select_all to await // something to do - // TODO find some way to not require sending empty last chunk } }; @@ -222,7 +255,7 @@ pub(crate) trait SendLoop: Sync { None => &[], }; - if !data.is_empty() { + if data.len() == MAX_CHUNK_LENGTH as usize { let size_header = ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; @@ -231,7 +264,6 @@ pub(crate) trait SendLoop: Sync { sending.push(item); } else { - // this is always zero for now, but may be more when above TODO get fixed let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; @@ -267,38 +299,38 @@ pub(crate) trait SendLoop: Sync { } struct ChannelPair { - receiver: Option<UnboundedReceiver<Vec<u8>>>, - sender: Option<UnboundedSender<Vec<u8>>>, + receiver: Option<UnboundedReceiver<Vec<u8>>>, + sender: Option<UnboundedSender<Vec<u8>>>, } impl ChannelPair { - fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { - self.receiver.take() - } - - fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { - self.sender.take() - } - - fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { - self.sender.as_ref().take() - } - - fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); - } - } + fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } } impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), - } - } + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } } /// The RecvLoop trait, which is implemented both by the client and the server @@ -317,10 +349,7 @@ pub(crate) trait RecvLoop: Sync + 'static { R: AsyncReadExt + Unpin + Send + Sync, { let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new(); - let mut streams: HashMap< - RequestID, - ChannelPair, - > = HashMap::new(); + let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -345,7 +374,7 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: read {} bytes", next_slice.len()); if id & 1 == 0 { - // main stream + // main stream let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); msg_bytes.extend_from_slice(&next_slice[..]); @@ -357,30 +386,30 @@ pub(crate) trait RecvLoop: Sync + 'static { if let Some(receiver) = channel_pair.take_receiver() { self.recv_handler(id, msg_bytes, Box::pin(receiver)); } else { - warn!("Couldn't take receiver part of stream") - } + warn!("Couldn't take receiver part of stream") + } - channel_pair.insert_into(&mut streams, id | 1); + channel_pair.insert_into(&mut streams, id | 1); } } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } if !has_cont { - channel_pair.take_sender(); - } + channel_pair.take_sender(); + } - channel_pair.insert_into(&mut streams, id); + channel_pair.insert_into(&mut streams, id); } } Ok(()) |