diff options
-rw-r--r-- | src/client.rs | 2 | ||||
-rw-r--r-- | src/message.rs | 2 | ||||
-rw-r--r-- | src/send.rs | 112 | ||||
-rw-r--r-- | src/server.rs | 2 |
4 files changed, 81 insertions, 37 deletions
diff --git a/src/client.rs b/src/client.rs index aef7bbb..df54810 100644 --- a/src/client.rs +++ b/src/client.rs @@ -190,7 +190,7 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, req_stream))?; + query_send.send((id, prio, req_order, req_stream))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { diff --git a/src/message.rs b/src/message.rs index ca68cac..1834f28 100644 --- a/src/message.rs +++ b/src/message.rs @@ -44,7 +44,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; #[derive(Clone, Copy)] pub struct OrderTagStream(u64); #[derive(Clone, Copy, Serialize, Deserialize, Debug)] -pub struct OrderTag(u64, u64); +pub struct OrderTag(pub(crate) u64, pub(crate) u64); impl OrderTag { pub fn stream() -> OrderTagStream { diff --git a/src/send.rs b/src/send.rs index c40787f..ea6cf9f 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -7,7 +7,7 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use log::*; -use futures::AsyncWriteExt; +use futures::{AsyncWriteExt, Future}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -36,15 +36,21 @@ pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; -pub(crate) type SendStream = (RequestID, RequestPriority, ByteStream); +pub(crate) type SendStream = (RequestID, RequestPriority, Option<OrderTag>, ByteStream); struct SendQueue { - items: Vec<(u8, VecDeque<SendQueueItem>)>, + items: Vec<(u8, SendQueuePriority)>, +} + +struct SendQueuePriority { + items: VecDeque<SendQueueItem>, + order: HashMap<u64, VecDeque<u64>>, } struct SendQueueItem { id: RequestID, prio: RequestPriority, + order_tag: Option<OrderTag>, data: ByteStreamReader, } @@ -59,11 +65,11 @@ impl SendQueue { let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { Ok(i) => i, Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); + self.items.insert(i, (prio, SendQueuePriority::new())); i } }; - self.items[pos_prio].1.push_back(item); + self.items[pos_prio].1.push(item); } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) @@ -75,29 +81,34 @@ impl SendQueue { } } -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataFrame); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { - for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { - let mut ready_item = None; - for (j, item) in items_at_prio.iter_mut().enumerate() { - let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); - match Pin::new(&mut item_reader).poll(ctx) { - Poll::Pending => (), - Poll::Ready(ready_v) => { - ready_item = Some((j, ready_v)); - break; - } +impl SendQueuePriority { + fn new() -> Self { + Self { + items: VecDeque::new(), + order: HashMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.entry(stream).or_default(); + let i = order_vec.iter().take_while(|o2| **o2 < order).count(); + order_vec.insert(i, order); + } + self.items.push_back(item); + } + fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> { + for (j, item) in self.items.iter_mut().enumerate() { + if let Some(OrderTag(stream, order)) = item.order_tag { + if order > *self.order.get(&stream).unwrap().front().unwrap() { + continue; } } - if let Some((j, bytes_or_err)) = ready_item { - let item = items_at_prio.remove(j).unwrap(); + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) { let id = item.id; let eos = item.data.eos(); @@ -106,15 +117,47 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { _ => unreachable!(), }); + if eos || packet.is_err() { + if let Some(OrderTag(stream, order)) = item.order_tag { + assert_eq!( + self.order.get_mut(&stream).unwrap().pop_front(), + Some(order) + ) + } + self.items.remove(j); + } + let data_frame = DataFrame::from_packet(packet, !eos); - if !eos && !matches!(data_frame, DataFrame::Error(_)) { - items_at_prio.push_back(item); - } else if items_at_prio.is_empty() { + return Poll::Ready((id, data_frame)); + } + } + + Poll::Pending + } + fn dump(&self, prio: u8) -> String { + self.items + .iter() + .map(|i| format!("[{} {} {:?}]", prio, i.id, i.order_tag)) + .collect::<Vec<_>>() + .join(" ") + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataFrame); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { + if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) { + if items_at_prio.is_empty() { self.queue.items.remove(i); } - - return Poll::Ready((id, data_frame)); + return Poll::Ready(res); } } // If the queue is empty, this futures is eternally pending. @@ -200,8 +243,7 @@ pub(crate) trait SendLoop: Sync { sending .items .iter() - .map(|(_, i)| i.iter().map(|x| x.id)) - .flatten() + .map(|(prio, i)| i.dump(*prio)) .collect::<Vec<_>>() ); @@ -217,12 +259,14 @@ pub(crate) trait SendLoop: Sync { // recv_fut is cancellation-safe according to tokio doc, // send_fut is cancellation-safe as implemented above? tokio::select! { + biased; // always read incomming channel first if it has data sth = recv_fut => { - if let Some((id, prio, data)) = sth { + if let Some((id, prio, order_tag, data)) = sth { trace!("send_loop: add stream {} to send", id); sending.push(SendQueueItem { id, prio, + order_tag, data: ByteStreamReader::new(data), }); } else { diff --git a/src/server.rs b/src/server.rs index c23c9e4..f8c3f98 100644 --- a/src/server.rs +++ b/src/server.rs @@ -186,7 +186,7 @@ impl RecvLoop for ServerConn { let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_stream)) + .send((id, prio, resp_order, resp_stream)) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); |