aboutsummaryrefslogtreecommitdiff
path: root/src/recv.rs
blob: 2be8728eba952cadc869c074617b354e249a61ad (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use log::trace;

use futures::AsyncReadExt;
use tokio::sync::mpsc;

use crate::error::*;
use crate::send::*;
use crate::stream::*;

/// Structure to warn when the sender is dropped before the end of the stream was reached,
/// e.g. when the connection to a remote peer drops while data is being transmitted.
///
/// Wraps the sending half of the channel that feeds one incoming stream. `inner` is
/// `Some` while the stream is open; `end()` sets it to `None` once the stream is
/// complete, so the `Drop` impl can tell a properly finished stream from one that was
/// cut short and notify the stream's consumer in the latter case.
struct Sender {
	/// Channel to the stream handed to `recv_handler`; `None` after `end()` was called.
	inner: Option<mpsc::Sender<Packet>>,
}

impl Sender {
	/// Wraps `inner`; until `end()` is called, dropping the result counts as an
	/// unfinished stream.
	fn new(inner: mpsc::Sender<Packet>) -> Self {
		Self { inner: Some(inner) }
	}

	/// Forwards `packet` to the stream's consumer, waiting for channel capacity if needed.
	///
	/// If the consumer has dropped its receiver, the packet is silently discarded: the
	/// caller still has to read the rest of this stream's chunks off the wire.
	///
	/// # Panics
	/// Panics if called after `end()`.
	async fn send(&self, packet: Packet) {
		let channel = self.inner.as_ref().unwrap();
		let _ = channel.send(packet).await;
	}

	/// Marks the stream as complete by releasing the channel, which also disarms
	/// the drop-time "stream was cut short" notification.
	fn end(&mut self) {
		drop(self.inner.take());
	}
}

impl Drop for Sender {
	fn drop(&mut self) {
		if let Some(inner) = self.inner.take() {
			let _ = inner.blocking_send(Err(255));
		}
	}
}

/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn), provides a method `.recv_loop()`
/// and the prototype of a handler for incoming streams, `.recv_handler()`, which
/// implementors must fill in.
///
/// `.recv_loop()` reads framed chunks from a byte source in a loop. Each chunk starts
/// with a big-endian `RequestID` header followed by a big-endian `ChunkLength` header.
/// The length word carries two flags:
/// - `ERROR_MARKER`: the chunk has no payload; the remaining low bits (truncated to
///   `u8`) are an error code forwarded as an `Err` packet;
/// - `CHUNK_HAS_CONTINUATION`: more chunks for the same request id follow.
///
/// Otherwise the length word (with the continuation flag cleared) is the number of
/// payload bytes that follow. The first chunk seen for a request id opens a new stream,
/// which is handed to `.recv_handler()` immediately. That chunk and every later chunk
/// with the same id are then forwarded through a bounded channel, until a chunk
/// without the continuation flag closes the stream.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
	/// Called once per incoming stream, when its first chunk arrives. `stream` yields
	/// every packet received for request `id` (`Ok` with chunk data or `Err` with an
	/// error code).
	fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);

	/// Reads chunks from `read` until end of input and dispatches them to per-request
	/// streams.
	///
	/// Returns `Ok(())` when `read` reaches EOF at a chunk boundary (while reading a
	/// request id header). Returns an error if any read fails, including EOF in the
	/// middle of a chunk. Streams still open when this returns are dropped, which
	/// triggers `Sender`'s drop-time notification to their consumers.
	async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
	where
		R: AsyncReadExt + Unpin + Send + Sync,
	{
		// Streams whose final chunk (one without CHUNK_HAS_CONTINUATION) has not been
		// received yet, keyed by request id.
		let mut streams: HashMap<RequestID, Sender> = HashMap::new();
		loop {
			trace!("recv_loop: reading packet");
			let mut header_id = [0u8; RequestID::BITS as usize / 8];
			match read.read_exact(&mut header_id[..]).await {
				Ok(_) => (),
				// EOF between chunks: the peer closed the connection.
				Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
				Err(e) => return Err(e.into()),
			};
			let id = RequestID::from_be_bytes(header_id);
			trace!("recv_loop: got header id: {:04x}", id);

			let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
			read.read_exact(&mut header_size[..]).await?;
			let size = ChunkLength::from_be_bytes(header_size);
			trace!("recv_loop: got header size: {:04x}", size);

			let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
			let is_error = (size & ERROR_MARKER) != 0;
			let packet = if is_error {
				// Error chunks carry no payload; the low bits are the error code.
				Err((size & !ERROR_MARKER) as u8)
			} else {
				let size = size & !CHUNK_HAS_CONTINUATION;
				let mut next_slice = vec![0; size as usize];
				read.read_exact(&mut next_slice[..]).await?;
				trace!("recv_loop: read {} bytes", next_slice.len());
				Ok(Bytes::from(next_slice))
			};

			let mut sender = match streams.remove(&id) {
				Some(send) => send,
				None => {
					// First chunk of a new request: open its stream and hand it to the
					// handler before forwarding this chunk into it.
					let (send, recv) = mpsc::channel(4);
					self.recv_handler(
						id,
						Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
					);
					Sender::new(send)
				}
			};

			// Even if the receiving end has been dropped, keep the sender until the
			// last chunk of this stream has been read off the wire.
			sender.send(packet).await;

			if has_cont {
				streams.insert(id, sender);
			} else {
				sender.end();
			}
		}
		Ok(())
	}
}