aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-12-02 13:30:47 +0100
committerAlex Auvolat <alex@adnab.me>2020-12-02 13:30:47 +0100
commitd4de2ffc40fe9d003f12139053ca070eda0b7221 (patch)
treee95476f0b7a6d1c75cc462b3ea7eee74c4faf09f /src/proto.rs
downloadnetapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.tar.gz
netapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.zip
First commit
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs251
1 files changed, 251 insertions, 0 deletions
diff --git a/src/proto.rs b/src/proto.rs
new file mode 100644
index 0000000..58c914e
--- /dev/null
+++ b/src/proto.rs
@@ -0,0 +1,251 @@
+use std::collections::{BTreeMap, HashMap, VecDeque};
+use std::sync::Arc;
+
+use log::trace;
+
+use async_trait::async_trait;
+
+use async_std::io::prelude::WriteExt;
+use async_std::io::ReadExt;
+
+use tokio::io::{ReadHalf, WriteHalf};
+use tokio::net::TcpStream;
+use tokio::sync::{mpsc, watch};
+
+use crate::error::*;
+
+use kuska_handshake::async_std::{BoxStreamRead, BoxStreamWrite, TokioCompat};
+
+const MAX_CHUNK_SIZE: usize = 0x4000;
+
+pub mod prio {
+ pub const HIGH: u8 = 0x20;
+ pub const NORMAL: u8 = 0x40;
+ pub const BACKGROUND: u8 = 0x80;
+
+ pub const PRIMARY: u8 = 0x00;
+ pub const SECONDARY: u8 = 0x01;
+}
+
+pub type RequestID = u16;
+pub type RequestPriority = u8;
+
+struct SendQueueItem {
+ id: RequestID,
+ prio: RequestPriority,
+ data: Vec<u8>,
+ cursor: usize,
+}
+
+struct SendQueue {
+ items: BTreeMap<u8, VecDeque<SendQueueItem>>,
+}
+
+impl SendQueue {
+ fn new() -> Self {
+ Self {
+ items: BTreeMap::new(),
+ }
+ }
+ fn push(&mut self, item: SendQueueItem) {
+ let prio = item.prio;
+ let mut items_at_prio = self
+ .items
+ .remove(&prio)
+ .unwrap_or(VecDeque::with_capacity(4));
+ items_at_prio.push_back(item);
+ self.items.insert(prio, items_at_prio);
+ }
+ fn pop(&mut self) -> Option<SendQueueItem> {
+ match self.items.pop_first() {
+ None => None,
+ Some((prio, mut items_at_prio)) => {
+ let ret = items_at_prio.pop_front();
+ if !items_at_prio.is_empty() {
+ self.items.insert(prio, items_at_prio);
+ }
+ ret
+ }
+ }
+ }
+}
+
+#[async_trait]
+pub(crate) trait SendLoop: Sync {
+ async fn send_loop(
+ self: Arc<Self>,
+ mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
+ mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
+ mut must_exit: watch::Receiver<bool>,
+ ) -> Result<(), Error> {
+ let mut sending = SendQueue::new();
+ while !*must_exit.borrow() {
+ if let Ok((id, prio, data)) = msg_recv.try_recv() {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ } else if let Some(mut item) = sending.pop() {
+ trace!(
+ "send_loop: sending bytes for {} ({} bytes, {} already sent)",
+ item.id,
+ item.data.len(),
+ item.cursor
+ );
+ let header_id = u16::to_be_bytes(item.id);
+ if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+
+ if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
+ let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
+ if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+
+ let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
+ if write_all_or_exit(
+ &item.data[item.cursor..new_cursor],
+ &mut write,
+ &mut must_exit,
+ )
+ .await?
+ .is_none()
+ {
+ break;
+ }
+ item.cursor = new_cursor;
+
+ sending.push(item);
+ } else {
+ let send_len = (item.data.len() - item.cursor) as u16;
+
+ let header_size = u16::to_be_bytes(send_len);
+ if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+
+ if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+ }
+ write.flush().await.log_err("Could not flush in send_loop");
+ } else {
+ let (id, prio, data) = msg_recv
+ .recv()
+ .await
+ .ok_or(Error::Message("Connection closed.".into()))?;
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ }
+ }
+ Ok(())
+ }
+}
+
+#[async_trait]
+pub(crate) trait RecvLoop: Sync + 'static {
+ async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
+
+ async fn recv_loop(
+ self: Arc<Self>,
+ mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
+ mut must_exit: watch::Receiver<bool>,
+ ) -> Result<(), Error> {
+ let mut receiving = HashMap::new();
+ while !*must_exit.borrow() {
+ trace!("recv_loop: reading packet");
+ let mut header_id = [0u8; 2];
+ if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+ let id = RequestID::from_be_bytes(header_id);
+ trace!("recv_loop: got header id: {:04x}", id);
+
+ let mut header_size = [0u8; 2];
+ if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+ let size = RequestID::from_be_bytes(header_size);
+ trace!("recv_loop: got header size: {:04x}", id);
+
+ let has_cont = (size & 0x8000) != 0;
+ let size = size & !0x8000;
+
+ let mut next_slice = vec![0; size as usize];
+ if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
+ .await?
+ .is_none()
+ {
+ break;
+ }
+ trace!("recv_loop: read {} bytes", size);
+
+ let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
+ msg_bytes.extend_from_slice(&next_slice[..]);
+
+ if has_cont {
+ receiving.insert(id, msg_bytes);
+ } else {
+ tokio::spawn(self.clone().recv_handler(id, msg_bytes));
+ }
+ }
+ Ok(())
+ }
+}
+
+async fn read_exact_or_exit(
+ buf: &mut [u8],
+ read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
+ must_exit: &mut watch::Receiver<bool>,
+) -> Result<Option<()>, Error> {
+ tokio::select!(
+ res = read.read_exact(buf) => Ok(Some(res?)),
+ _ = await_exit(must_exit) => Ok(None),
+ )
+}
+
+async fn write_all_or_exit(
+ buf: &[u8],
+ write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
+ must_exit: &mut watch::Receiver<bool>,
+) -> Result<Option<()>, Error> {
+ tokio::select!(
+ res = write.write_all(buf) => Ok(Some(res?)),
+ _ = await_exit(must_exit) => Ok(None),
+ )
+}
+
+async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
+ loop {
+ if must_exit.recv().await == Some(true) {
+ return;
+ }
+ }
+}