aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs150
1 files changed, 43 insertions, 107 deletions
diff --git a/src/proto.rs b/src/proto.rs
index b044280..d90042f 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -3,14 +3,14 @@ use std::sync::Arc;
use log::trace;
-use async_trait::async_trait;
-
use async_std::io::prelude::WriteExt;
use async_std::io::ReadExt;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
-use tokio::sync::{mpsc, watch};
+use tokio::sync::mpsc;
+
+use async_trait::async_trait;
use crate::error::*;
@@ -85,26 +85,33 @@ impl SendQueue {
}
}
}
+ fn is_empty(&self) -> bool {
+ self.items.iter().all(|(_k, v)| v.is_empty())
+ }
}
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
+ mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
- mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut sending = SendQueue::new();
- while !*must_exit.borrow() {
- if let Ok((id, prio, data)) = msg_recv.try_recv() {
- trace!("send_loop: got {}, {} bytes", id, data.len());
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
+ let mut should_exit = false;
+ while !should_exit || !sending.is_empty() {
+ if let Ok(sth) = msg_recv.try_recv() {
+ if let Some((id, prio, data)) = sth {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ } else {
+ should_exit = true;
+ }
} else if let Some(mut item) = sending.pop() {
trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
@@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
item.cursor
);
let header_id = u16::to_be_bytes(item.id);
- if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
- if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_size[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
- if write_all_or_exit(
- &item.data[item.cursor..new_cursor],
- &mut write,
- &mut must_exit,
- )
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&item.data[item.cursor..new_cursor]).await?;
item.cursor = new_cursor;
sending.push(item);
@@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync {
let send_len = (item.data.len() - item.cursor) as u16;
let header_size = u16::to_be_bytes(send_len);
- if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_size[..]).await?;
- if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&item.data[item.cursor..]).await?;
}
write.flush().await.log_err("Could not flush in send_loop");
} else {
- let (id, prio, data) = msg_recv
+ let sth = msg_recv
.recv()
.await
.ok_or(Error::Message("Connection closed.".into()))?;
- trace!("send_loop: got {}, {} bytes", id, data.len());
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
+ if let Some((id, prio, data)) = sth {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ } else {
+ should_exit = true;
+ }
}
}
Ok(())
@@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
+ // Returns true if we should stop receiving after this
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop(
self: Arc<Self>,
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
- mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut receiving = HashMap::new();
- while !*must_exit.borrow() {
+ loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; 2];
- if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut header_id[..]).await?;
let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; 2];
- if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut header_size[..]).await?;
let size = RequestID::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", id);
@@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
let size = size & !0x8000;
let mut next_slice = vec![0; size as usize];
- if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", size);
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
@@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
}
}
- Ok(())
- }
-}
-
-async fn read_exact_or_exit(
- buf: &mut [u8],
- read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
- must_exit: &mut watch::Receiver<bool>,
-) -> Result<Option<()>, Error> {
- tokio::select!(
- res = read.read_exact(buf) => Ok(Some(res?)),
- _ = await_exit(must_exit) => Ok(None),
- )
-}
-
-async fn write_all_or_exit(
- buf: &[u8],
- write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
- must_exit: &mut watch::Receiver<bool>,
-) -> Result<Option<()>, Error> {
- tokio::select!(
- res = write.write_all(buf) => Ok(Some(res?)),
- _ = await_exit(must_exit) => Ok(None),
- )
-}
-
-async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
- loop {
- if must_exit.recv().await == Some(true) {
- return;
- }
}
}