diff options
author | Alex Auvolat <alex@adnab.me> | 2020-12-02 13:30:47 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2020-12-02 13:30:47 +0100 |
commit | d4de2ffc40fe9d003f12139053ca070eda0b7221 (patch) | |
tree | e95476f0b7a6d1c75cc462b3ea7eee74c4faf09f /src/proto.rs | |
download | netapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.tar.gz netapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.zip |
First commit
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/src/proto.rs b/src/proto.rs new file mode 100644 index 0000000..58c914e --- /dev/null +++ b/src/proto.rs @@ -0,0 +1,251 @@ +use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::sync::Arc; + +use log::trace; + +use async_trait::async_trait; + +use async_std::io::prelude::WriteExt; +use async_std::io::ReadExt; + +use tokio::io::{ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, watch}; + +use crate::error::*; + +use kuska_handshake::async_std::{BoxStreamRead, BoxStreamWrite, TokioCompat}; + +const MAX_CHUNK_SIZE: usize = 0x4000; + +pub mod prio { + pub const HIGH: u8 = 0x20; + pub const NORMAL: u8 = 0x40; + pub const BACKGROUND: u8 = 0x80; + + pub const PRIMARY: u8 = 0x00; + pub const SECONDARY: u8 = 0x01; +} + +pub type RequestID = u16; +pub type RequestPriority = u8; + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + data: Vec<u8>, + cursor: usize, +} + +struct SendQueue { + items: BTreeMap<u8, VecDeque<SendQueueItem>>, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: BTreeMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let mut items_at_prio = self + .items + .remove(&prio) + .unwrap_or(VecDeque::with_capacity(4)); + items_at_prio.push_back(item); + self.items.insert(prio, items_at_prio); + } + fn pop(&mut self) -> Option<SendQueueItem> { + match self.items.pop_first() { + None => None, + Some((prio, mut items_at_prio)) => { + let ret = items_at_prio.pop_front(); + if !items_at_prio.is_empty() { + self.items.insert(prio, items_at_prio); + } + ret + } + } + } +} + +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop( + self: Arc<Self>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, + mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>, + mut must_exit: watch::Receiver<bool>, + ) -> Result<(), Error> { + let mut sending = SendQueue::new(); + while !*must_exit.borrow() { + if let Ok((id, prio, data)) = msg_recv.try_recv() { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } else if let Some(mut item) = sending.pop() { + trace!( + "send_loop: sending bytes for {} ({} bytes, {} already sent)", + item.id, + item.data.len(), + item.cursor + ); + let header_id = u16::to_be_bytes(item.id); + if write_all_or_exit(&header_id[..], &mut write, &mut must_exit) + .await? + .is_none() + { + break; + } + + if item.data.len() - item.cursor > MAX_CHUNK_SIZE { + let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); + if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) + .await? + .is_none() + { + break; + } + + let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; + if write_all_or_exit( + &item.data[item.cursor..new_cursor], + &mut write, + &mut must_exit, + ) + .await? + .is_none() + { + break; + } + item.cursor = new_cursor; + + sending.push(item); + } else { + let send_len = (item.data.len() - item.cursor) as u16; + + let header_size = u16::to_be_bytes(send_len); + if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) + .await? + .is_none() + { + break; + } + + if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit) + .await? + .is_none() + { + break; + } + } + write.flush().await.log_err("Could not flush in send_loop"); + } else { + let (id, prio, data) = msg_recv + .recv() + .await + .ok_or(Error::Message("Connection closed.".into()))?; + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } + } + Ok(()) + } +} + +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>); + + async fn recv_loop( + self: Arc<Self>, + mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>, + mut must_exit: watch::Receiver<bool>, + ) -> Result<(), Error> { + let mut receiving = HashMap::new(); + while !*must_exit.borrow() { + trace!("recv_loop: reading packet"); + let mut header_id = [0u8; 2]; + if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit) + .await? + .is_none() + { + break; + } + let id = RequestID::from_be_bytes(header_id); + trace!("recv_loop: got header id: {:04x}", id); + + let mut header_size = [0u8; 2]; + if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit) + .await? + .is_none() + { + break; + } + let size = RequestID::from_be_bytes(header_size); + trace!("recv_loop: got header size: {:04x}", id); + + let has_cont = (size & 0x8000) != 0; + let size = size & !0x8000; + + let mut next_slice = vec![0; size as usize]; + if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit) + .await? + .is_none() + { + break; + } + trace!("recv_loop: read {} bytes", size); + + let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]); + msg_bytes.extend_from_slice(&next_slice[..]); + + if has_cont { + receiving.insert(id, msg_bytes); + } else { + tokio::spawn(self.clone().recv_handler(id, msg_bytes)); + } + } + Ok(()) + } +} + +async fn read_exact_or_exit( + buf: &mut [u8], + read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>, + must_exit: &mut watch::Receiver<bool>, +) -> Result<Option<()>, Error> { + tokio::select!( + res = read.read_exact(buf) => Ok(Some(res?)), + _ = await_exit(must_exit) => Ok(None), + ) +} + +async fn write_all_or_exit( + buf: &[u8], + write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>, + must_exit: &mut watch::Receiver<bool>, +) -> Result<Option<()>, Error> { + tokio::select!( + res = write.write_all(buf) => Ok(Some(res?)), + _ = await_exit(must_exit) => Ok(None), + ) +} + +async fn await_exit(must_exit: &mut watch::Receiver<bool>) { + loop { + if must_exit.recv().await == Some(true) { + return; + } + } +} |