use std::sync::Mutex; use futures::prelude::*; use futures::stream::BoxStream; use http_body_util::{BodyExt, StreamBody}; use hyper::body::{Bytes, Frame}; use serde::Deserialize; use tokio::sync::mpsc; use tokio::task; use opentelemetry::{ trace::{FutureExt as OtelFutureExt, TraceContextExt, Tracer}, Context, }; use super::*; use crate::signature::checksum::*; pub struct ReqBody { // why need mutex to be sync?? pub(crate) stream: Mutex, Error>>>, pub(crate) checksummer: Checksummer, pub(crate) expected_checksums: ExpectedChecksums, } pub type StreamingChecksumReceiver = task::JoinHandle>; impl ReqBody { pub fn add_expected_checksums(&mut self, more: ExpectedChecksums) { if more.md5.is_some() { self.expected_checksums.md5 = more.md5; } if more.sha256.is_some() { self.expected_checksums.sha256 = more.sha256; } if more.extra.is_some() { self.expected_checksums.extra = more.extra; } self.checksummer.add_expected(&self.expected_checksums); } pub fn add_md5(&mut self) { self.checksummer.add_md5(); } // ============ non-streaming ============= pub async fn json Deserialize<'a>>(self) -> Result { let body = self.collect().await?; let resp: T = serde_json::from_slice(&body).ok_or_bad_request("Invalid JSON")?; Ok(resp) } pub async fn collect(self) -> Result { self.collect_with_checksums().await.map(|(b, _)| b) } pub async fn collect_with_checksums(mut self) -> Result<(Bytes, Checksums), Error> { let stream: BoxStream<_> = self.stream.into_inner().unwrap(); let bytes = BodyExt::collect(StreamBody::new(stream)).await?.to_bytes(); self.checksummer.update(&bytes); let checksums = self.checksummer.finalize(); checksums.verify(&self.expected_checksums)?; Ok((bytes, checksums)) } // ============ streaming ============= pub fn streaming_with_checksums( self, ) -> ( BoxStream<'static, Result>, StreamingChecksumReceiver, ) { let Self { stream, mut checksummer, mut expected_checksums, } = self; let (frame_tx, mut frame_rx) = mpsc::channel::>(1); let join_checksums = tokio::spawn(async move { let tracer = opentelemetry::global::tracer("garage"); while let Some(frame) = frame_rx.recv().await { match frame.into_data() { Ok(data) => { checksummer = tokio::task::spawn_blocking(move || { checksummer.update(&data); checksummer }) .await .unwrap() } Err(frame) => { let trailers = frame.into_trailers().unwrap(); if let Some(cv) = request_checksum_value(&trailers)? { expected_checksums.extra = Some(cv); } break; } } } let checksums = checksummer.finalize(); checksums.verify(&expected_checksums)?; return Ok(checksums); }); let stream: BoxStream<_> = stream.into_inner().unwrap(); let stream = stream.filter_map(move |x| { let frame_tx = frame_tx.clone(); async move { match x { Err(e) => Some(Err(e)), Ok(frame) => { if frame.is_data() { let data = frame.data_ref().unwrap().clone(); let _ = frame_tx.send(frame).await; Some(Ok(data)) } else { let _ = frame_tx.send(frame).await; None } } } } }); (stream.boxed(), join_checksums) } }