diff options
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 617 |
1 files changed, 0 insertions, 617 deletions
diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 92d8d80..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use log::trace; - -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures::{Stream, StreamExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; -use crate::util::{AssociatedStream, Packet}; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -const ERROR_MARKER: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: DataReader, -} - -#[pin_project::pin_project] -struct DataReader { - #[pin] - reader: AssociatedStream, - packet: Packet, - pos: usize, - buf: Vec<u8>, - eos: bool, -} - -impl From<AssociatedStream> for DataReader { - fn from(data: AssociatedStream) -> DataReader { - DataReader { - reader: data, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - } - } -} - -enum DataFrame { - Data { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actual lenght of data - len: usize, - }, - Error(u8), -} - -struct DataReaderItem { - data: DataFrame, - /// whethere there may be more data comming from this stream. Can be used for some - /// optimization. It's an error to set it to false if there is more data, but it is correct - /// (albeit sub-optimal) to set it to true if there is nothing coming after - may_have_more: bool, -} - -impl DataReaderItem { - fn empty_last() -> Self { - DataReaderItem { - data: DataFrame::Data { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, - }, - may_have_more: false, - } - } - - fn header(&self) -> [u8; 2] { - let continuation = if self.may_have_more { - CHUNK_HAS_CONTINUATION - } else { - 0 - }; - let len = match self.data { - DataFrame::Data { len, .. } => len as u16, - DataFrame::Error(e) => e as u16 | ERROR_MARKER, - }; - - ChunkLength::to_be_bytes(len | continuation) - } - - fn data(&self) -> &[u8] { - match self.data { - DataFrame::Data { ref data, len } => &data[..len], - DataFrame::Error(_) => &[], - } - } -} - -impl Stream for DataReader { - type Item = DataReaderItem; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let mut this = self.project(); - - if *this.eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - - loop { - let packet = match this.packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *this.packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *this.pos; - let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - this.buf - .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); - *this.pos += to_read; - if this.buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { - *this.packet = p; - *this.pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if this.packet.is_err() && !this.buf.is_empty() { - break; - } - } else { - *this.eos = true; - break; - } - } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = this.buf.len(); - body[..len].copy_from_slice(this.buf); - this.buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*this.eos, - })) - } -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque<SendQueueItem>)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] - fn pop(&mut self) -> Option<SendQueueItem> { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - - // this is like an async fn, but hand implemented - fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { - SendQueuePollNextReady { queue: self } - } -} - -struct SendQueuePollNextReady<'a> { - queue: &'a mut SendQueue, -} - -impl<'a> futures::Future for SendQueuePollNextReady<'a> { - type Output = (RequestID, DataReaderItem); - - fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { - for i in 0..self.queue.items.len() { - let (_prio, items_at_prio) = &mut self.queue.items[i]; - - for _ in 0..items_at_prio.len() { - let mut item = items_at_prio.pop_front().unwrap(); - - match Pin::new(&mut item.data).poll_next(ctx) { - Poll::Pending => items_at_prio.push_back(item), - Poll::Ready(Some(data)) => { - let id = item.id; - if data.may_have_more { - self.queue.push(item); - } else { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - } - return Poll::Ready((id, data)); - } - Poll::Ready(None) => { - if items_at_prio.is_empty() { - // this priority level is empty, remove it - self.queue.items.remove(i); - } - return Poll::Ready((item.id, DataReaderItem::empty_last())); - } - } - } - } - // TODO what do we do if self.queue is empty? We won't get scheduled again. - Poll::Pending - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop<W>( - self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, - mut write: BoxStreamWrite<W>, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - let recv_fut = msg_recv.recv(); - futures::pin_mut!(recv_fut); - let send_fut = sending.next_ready(); - - // recv_fut is cancellation-safe according to tokio doc, - // send_fut is cancellation-safe as implemented above? - use futures::future::Either; - match futures::future::select(recv_fut, send_fut).await { - Either::Left((sth, _send_fut)) => { - if let Some((id, prio, data)) = sth { - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; - }; - } - Either::Right(((id, data), _recv_fut)) => { - trace!("send_loop: sending bytes for {}", id); - - let header_id = RequestID::to_be_bytes(id); - write.write_all(&header_id[..]).await?; - - write.write_all(&data.header()).await?; - write.write_all(data.data()).await?; - write.flush().await?; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -pub(crate) struct Framing { - direct: Vec<u8>, - stream: Option<AssociatedStream>, -} - -impl Framing { - pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } - } - - pub fn into_stream(self) -> AssociatedStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; - - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); - - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } - - pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>( - mut stream: S, - ) -> Result<Self, Error> { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); - } - - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet.drain(..4); - - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); - - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet.drain(..max_cp); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } - - let stream: AssociatedStream = if packet.is_empty() { - Box::pin(stream) - } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; - - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) - } -} - -/// Structure to warn when the sender is dropped before end of stream was reached, like when -/// connection to some remote drops while transmitting data -struct Sender { - inner: UnboundedSender<Packet>, - closed: bool, -} - -impl Sender { - fn new(inner: UnboundedSender<Packet>) -> Self { - Sender { - inner, - closed: false, - } - } - - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); - } - - fn end(&mut self) { - self.closed = true; - } -} - -impl Drop for Sender { - fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); - } - self.inner.close_channel(); - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>); - - async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut streams: HashMap<RequestID, Sender> = HashMap::new(); - loop { - trace!("recv_loop: reading packet"); - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; - let packet = if is_error { - Err(size as u8) - } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) - }; - - let mut sender = if let Some(send) = streams.remove(&(id)) { - send - } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); - Sender::new(send) - }; - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); - - if has_cont { - streams.insert(id, sender); - } else { - sender.end(); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn empty_data() -> DataReader { - type Item = Packet; - let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = - Box::pin(futures::stream::empty::<Packet>()); - stream.into() - } - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: empty_data(), - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: empty_data(), - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: empty_data(), - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: empty_data(), - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: empty_data(), - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} |