aboutsummaryrefslogtreecommitdiff
path: root/src/api/common/signature/body.rs
blob: 96be0d5bc71a003bdeb3a872c2e88af8a6cf96ed (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use std::sync::Mutex;

use futures::prelude::*;
use futures::stream::BoxStream;
use http_body_util::{BodyExt, StreamBody};
use hyper::body::{Bytes, Frame};
use serde::Deserialize;
use tokio::sync::mpsc;
use tokio::task;

use super::*;

use crate::signature::checksum::*;

/// An incoming request body together with the checksum state used to verify it.
///
/// The body is consumed either all at once (`collect`, `collect_with_checksums`,
/// `json`) or as a stream (`streaming_with_checksums`). Both paths feed the data
/// to `checksummer` and verify the result against `expected_checksums`.
pub struct ReqBody {
	// `BoxStream` is `Pin<Box<dyn Stream + Send>>`, which is `Send` but not `Sync`.
	// `Mutex<T>` is `Sync` whenever `T: Send`, so the mutex is what lets this
	// field be `Sync`. In this file it is never locked, only unwrapped with
	// `into_inner()`. NOTE(review): presumably some caller needs `ReqBody: Sync`
	// (e.g. a `&ReqBody` held across an `.await` in a `Send` future); confirm
	// before replacing the mutex with something lighter.
	pub(crate) stream: Mutex<BoxStream<'static, Result<Frame<Bytes>, Error>>>,
	/// Incremental hasher that is fed the body's data.
	pub(crate) checksummer: Checksummer,
	/// Checksums that the computed ones are verified against.
	pub(crate) expected_checksums: ExpectedChecksums,
	/// When set, a checksum for this algorithm is expected in a trailers frame
	/// at the end of the body.
	pub(crate) trailer_algorithm: Option<ChecksumAlgorithm>,
}

/// Handle to the background task spawned by `ReqBody::streaming_with_checksums`.
///
/// The task's own result is the checksums computed over the streamed data. It is
/// an error instead if they do not match the expected ones, or if an announced
/// trailing checksum was not sent. Awaiting the handle wraps that result in the
/// `JoinError` layer of a tokio `JoinHandle`.
pub type StreamingChecksumReceiver = task::JoinHandle<Result<Checksums, Error>>;

impl ReqBody {
	pub fn add_expected_checksums(&mut self, more: ExpectedChecksums) {
		if more.md5.is_some() {
			self.expected_checksums.md5 = more.md5;
		}
		if more.sha256.is_some() {
			self.expected_checksums.sha256 = more.sha256;
		}
		if more.extra.is_some() {
			self.expected_checksums.extra = more.extra;
		}
		self.checksummer.add_expected(&self.expected_checksums);
	}

	pub fn add_md5(&mut self) {
		self.checksummer.add_md5();
	}

	// ============ non-streaming =============

	pub async fn json<T: for<'a> Deserialize<'a>>(self) -> Result<T, Error> {
		let body = self.collect().await?;
		let resp: T = serde_json::from_slice(&body).ok_or_bad_request("Invalid JSON")?;
		Ok(resp)
	}

	pub async fn collect(self) -> Result<Bytes, Error> {
		self.collect_with_checksums().await.map(|(b, _)| b)
	}

	pub async fn collect_with_checksums(mut self) -> Result<(Bytes, Checksums), Error> {
		let stream: BoxStream<_> = self.stream.into_inner().unwrap();
		let bytes = BodyExt::collect(StreamBody::new(stream)).await?.to_bytes();

		self.checksummer.update(&bytes);
		let checksums = self.checksummer.finalize();
		checksums.verify(&self.expected_checksums)?;

		Ok((bytes, checksums))
	}

	// ============ streaming =============

	pub fn streaming_with_checksums(
		self,
	) -> (
		BoxStream<'static, Result<Bytes, Error>>,
		StreamingChecksumReceiver,
	) {
		let Self {
			stream,
			mut checksummer,
			mut expected_checksums,
			trailer_algorithm,
		} = self;

		let (frame_tx, mut frame_rx) = mpsc::channel::<Frame<Bytes>>(5);

		let join_checksums = tokio::spawn(async move {
			while let Some(frame) = frame_rx.recv().await {
				match frame.into_data() {
					Ok(data) => {
						checksummer = tokio::task::spawn_blocking(move || {
							checksummer.update(&data);
							checksummer
						})
						.await
						.unwrap()
					}
					Err(frame) => {
						let trailers = frame.into_trailers().unwrap();
						let algo = trailer_algorithm.unwrap();
						expected_checksums.extra = Some(extract_checksum_value(&trailers, algo)?);
						break;
					}
				}
			}

			if trailer_algorithm.is_some() && expected_checksums.extra.is_none() {
				return Err(Error::bad_request("trailing checksum was not sent"));
			}

			let checksums = checksummer.finalize();
			checksums.verify(&expected_checksums)?;

			Ok(checksums)
		});

		let stream: BoxStream<_> = stream.into_inner().unwrap();
		let stream = stream.filter_map(move |x| {
			let frame_tx = frame_tx.clone();
			async move {
				match x {
					Err(e) => Some(Err(e)),
					Ok(frame) => {
						if frame.is_data() {
							let data = frame.data_ref().unwrap().clone();
							let _ = frame_tx.send(frame).await;
							Some(Ok(data))
						} else {
							let _ = frame_tx.send(frame).await;
							None
						}
					}
				}
			}
		});

		(stream.boxed(), join_checksums)
	}
}