aboutsummaryrefslogblamecommitdiff
path: root/src/stream.rs
blob: 5ba2ed49e558a7612015600d5ee49f3a0f2500e9 (plain) (tree)
1
2
3
4
5
6
7
8
9






                               

                                               











                                                                                          

       































































                                                                                                  
                                                
                                                                                                      





                                   
                                             

         





























































                                                                                                    
                                                                                          




                                                                                          



                                                                                   











                                                                  





















                                                    
                               















                                                                                        
                                                                                   




                                                                                       
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;

use futures::Future;
use futures::{Stream, StreamExt, TryStreamExt};
use tokio::io::AsyncRead;

/// A stream of associated data.
///
/// When sent through Netapp, the data may be re-chunked: a chunk may be split into
/// smaller chunks and consecutive chunks may get merged, but data and error codes are
/// never reordered relative to each other.
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning at the protocol level; it's up to your application to define their semantics.
/// (The `asyncread_stream` / `stream_asyncread` adapters in this module use codes 100-110
/// to carry `std::io::ErrorKind`s.)
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;

/// One item of a [`ByteStream`]: either a chunk of data (`Ok`) or an error code (`Err`).
pub type Packet = Result<Bytes, u8>;

// ----

/// Reader that pulls chunks from a [`ByteStream`] into an internal buffer and hands
/// them back out in caller-chosen sizes (`read_exact`, `read_u32`, ...).
pub struct ByteStreamReader {
	/// Underlying source of packets.
	stream: ByteStream,
	/// Chunks received from `stream` but not yet consumed, in arrival order.
	buf: VecDeque<Bytes>,
	/// Total number of bytes across all chunks currently in `buf`.
	buf_len: usize,
	/// Set once `stream` has yielded its end (`None`) or an error packet; the reader
	/// does not poll `stream` again after that.
	eos: bool,
	/// Error code yielded by `stream`, if it yielded one.
	err: Option<u8>,
}

impl ByteStreamReader {
	/// Creates a reader that pulls data from `stream`, starting with an empty buffer.
	pub fn new(stream: ByteStream) -> Self {
		ByteStreamReader {
			stream,
			buf: VecDeque::with_capacity(8),
			buf_len: 0,
			eos: false,
			err: None,
		}
	}

	/// Returns a future resolving to exactly `read_len` bytes.
	///
	/// # Errors
	///
	/// Resolves to [`ReadExactError::Stream`] if the stream yielded an error code before
	/// `read_len` bytes were buffered, or to [`ReadExactError::UnexpectedEos`] if the stream
	/// ended cleanly before that.
	pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
		ByteStreamReadExact {
			reader: self,
			read_len,
			fail_on_eos: true,
		}
	}

	/// Like [`read_exact`](Self::read_exact), except that if the stream ends cleanly before
	/// `read_len` bytes are available, it resolves to everything still buffered instead
	/// (which may be shorter than `read_len`, or empty).
	///
	/// # Errors
	///
	/// Resolves to [`ReadExactError::Stream`] if the stream yielded an error code before
	/// `read_len` bytes were buffered.
	pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
		ByteStreamReadExact {
			reader: self,
			read_len,
			fail_on_eos: false,
		}
	}

	/// Reads a single byte; fails like [`read_exact`](Self::read_exact).
	pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
		Ok(self.read_exact(1).await?[0])
	}

	/// Reads a big-endian `u16`; fails like [`read_exact`](Self::read_exact).
	pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
		let bytes = self.read_exact(2).await?;
		let mut b = [0u8; 2];
		b.copy_from_slice(&bytes[..]);
		Ok(u16::from_be_bytes(b))
	}

	/// Reads a big-endian `u32`; fails like [`read_exact`](Self::read_exact).
	pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
		let bytes = self.read_exact(4).await?;
		let mut b = [0u8; 4];
		b.copy_from_slice(&bytes[..]);
		Ok(u32::from_be_bytes(b))
	}

	/// Turns the reader back into a [`ByteStream`].
	///
	/// Buffered, not-yet-consumed chunks are yielded first. They are followed by the
	/// recorded error code if the stream failed, by nothing if it ended cleanly, or
	/// otherwise by the rest of the underlying stream.
	pub fn into_stream(self) -> ByteStream {
		let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
		if let Some(err) = self.err {
			Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
		} else if self.eos {
			Box::pin(buf_stream)
		} else {
			Box::pin(buf_stream.chain(self.stream))
		}
	}

	/// Removes everything currently buffered and returns it as a single `Bytes`.
	pub fn take_buffer(&mut self) -> Bytes {
		let bytes = Self::join_chunks(self.buf.make_contiguous(), self.buf_len);
		self.buf.clear();
		self.buf_len = 0;
		bytes
	}

	/// Returns `true` once the reader has observed the end of the underlying stream
	/// (clean or with an error) and every buffered byte has been consumed.
	pub fn eos(&self) -> bool {
		self.buf_len == 0 && self.eos
	}

	/// Takes exactly `read_len` bytes off the front of the buffer. Returns `None`, leaving
	/// the buffer untouched, if fewer bytes are buffered.
	fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
		if self.buf_len < read_len {
			return None;
		}

		let mut slices = Vec::with_capacity(self.buf.len());
		let mut taken = 0;
		while taken < read_len {
			// Cannot fail: `buf_len >= read_len > taken` means some bytes are still buffered.
			let front = self
				.buf
				.pop_front()
				.expect("buf_len out of sync with buffered chunks");
			if taken + front.len() <= read_len {
				taken += front.len();
				self.buf_len -= front.len();
				slices.push(front);
			} else {
				// Split this chunk and put the part beyond `read_len` back at the front.
				let front_take = read_len - taken;
				slices.push(front.slice(..front_take));
				self.buf.push_front(front.slice(front_take..));
				self.buf_len -= front_take;
				break;
			}
		}
		Some(Self::join_chunks(&slices, read_len))
	}

	/// Joins `chunks`, which hold `total_len` bytes altogether, into one `Bytes`.
	///
	/// No chunks gives an empty `Bytes`. A lone chunk is returned without copying: it is a
	/// reference-counted clone and shares memory with the original. Several chunks are
	/// copied once into a buffer allocated at its final size.
	fn join_chunks(chunks: &[Bytes], total_len: usize) -> Bytes {
		match chunks {
			[] => Bytes::new(),
			[single] => single.clone(),
			_ => {
				let mut joined = Vec::with_capacity(total_len);
				for chunk in chunks {
					joined.extend_from_slice(chunk);
				}
				Bytes::from(joined)
			}
		}
	}
}

/// Error produced by the `ByteStreamReader` read operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadExactError {
	/// The stream ended cleanly before the requested number of bytes was available.
	UnexpectedEos,
	/// The stream yielded this error code before the requested bytes were available.
	Stream(u8),
}

/// Future returned by `ByteStreamReader::read_exact` and
/// `ByteStreamReader::read_exact_or_eos`.
#[pin_project::pin_project]
pub struct ByteStreamReadExact<'a> {
	/// Reader to take bytes from. A `&mut` reference is always `Unpin`, so pinning
	/// this field imposes no real constraint.
	#[pin]
	reader: &'a mut ByteStreamReader,
	/// Number of bytes requested.
	read_len: usize,
	/// If `true`, the stream ending cleanly before `read_len` bytes are available is
	/// reported as `ReadExactError::UnexpectedEos`; if `false`, whatever is still
	/// buffered is returned instead.
	fail_on_eos: bool,
}

impl<'a> Future for ByteStreamReadExact<'a> {
	type Output = Result<Bytes, ReadExactError>;

	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
		let mut this = self.project();

		loop {
			if let Some(bytes) = this.reader.try_get(*this.read_len) {
				return Poll::Ready(Ok(bytes));
			}
			if let Some(err) = this.reader.err {
				return Poll::Ready(Err(ReadExactError::Stream(err)));
			}
			if this.reader.eos {
				if *this.fail_on_eos {
					return Poll::Ready(Err(ReadExactError::UnexpectedEos));
				} else {
					return Poll::Ready(Ok(this.reader.take_buffer()));
				}
			}

			match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) {
				Some(Ok(slice)) => {
					if !slice.is_empty() {
						this.reader.buf_len += slice.len();
						this.reader.buf.push_back(slice);
					}
				}
				Some(Err(e)) => {
					this.reader.err = Some(e);
					this.reader.eos = true;
				}
				None => {
					this.reader.eos = true;
				}
			}
		}
	}
}

// ----

/// Single-byte codes used to carry a `std::io::ErrorKind` through a `ByteStream`.
///
/// Both conversion functions below read this one table, so encoding and decoding are
/// guaranteed to stay exact inverses of each other.
const IO_ERROR_CODES: [(u8, std::io::ErrorKind); 10] = [
	(101, std::io::ErrorKind::ConnectionAborted),
	(102, std::io::ErrorKind::BrokenPipe),
	(103, std::io::ErrorKind::WouldBlock),
	(104, std::io::ErrorKind::InvalidInput),
	(105, std::io::ErrorKind::InvalidData),
	(106, std::io::ErrorKind::TimedOut),
	(107, std::io::ErrorKind::Interrupted),
	(108, std::io::ErrorKind::UnexpectedEof),
	(109, std::io::ErrorKind::OutOfMemory),
	(110, std::io::ErrorKind::ConnectionReset),
];

/// Code sent for any `std::io::ErrorKind` that has no entry in `IO_ERROR_CODES`.
const IO_ERROR_OTHER: u8 = 100;

/// Converts an error code received from a `ByteStream` into an I/O error.
///
/// Codes listed in `IO_ERROR_CODES` map back to their `ErrorKind`. Every other code,
/// including `IO_ERROR_OTHER` and 255 (stream cut), becomes `ErrorKind::Other`. Only a
/// code is transmitted, so the error always carries a fixed placeholder message.
fn u8_to_io_error(v: u8) -> std::io::Error {
	use std::io::{Error, ErrorKind};
	let kind = IO_ERROR_CODES
		.iter()
		.find(|(code, _)| *code == v)
		.map_or(ErrorKind::Other, |&(_, kind)| kind);
	Error::new(kind, "(in netapp stream)")
}

/// Converts an I/O error into a single-byte code for a `ByteStream`.
///
/// Only the error's kind is kept (its message is dropped). Kinds without an entry in
/// `IO_ERROR_CODES` are sent as `IO_ERROR_OTHER`.
fn io_error_to_u8(e: std::io::Error) -> u8 {
	let kind = e.kind();
	IO_ERROR_CODES
		.iter()
		.find(|(_, k)| *k == kind)
		.map_or(IO_ERROR_OTHER, |&(code, _)| code)
}

/// Adapts an `AsyncRead` into a [`ByteStream`].
///
/// Data read from `reader` is yielded as `Ok` chunks. I/O errors are turned into `Err`
/// packets carrying the code chosen by `io_error_to_u8`.
pub fn asyncread_stream<R>(reader: R) -> ByteStream
where
	R: AsyncRead + Send + Sync + 'static,
{
	let packets = tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8);
	Box::pin(packets)
}

/// Adapts a [`ByteStream`] into an `AsyncRead`.
///
/// Data chunks are exposed as readable bytes. An `Err` packet surfaces as the I/O error
/// built by `u8_to_io_error`.
pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static {
	let io_stream = stream.map_err(u8_to_io_error);
	tokio_util::io::StreamReader::new(io_stream)
}