diff options
Diffstat (limited to 'src/api/common')
-rw-r--r-- | src/api/common/signature/body.rs | 13 | ||||
-rw-r--r-- | src/api/common/signature/checksum.rs | 81 | ||||
-rw-r--r-- | src/api/common/signature/streaming.rs | 58 |
3 files changed, 84 insertions, 68 deletions
diff --git a/src/api/common/signature/body.rs b/src/api/common/signature/body.rs index 512d02b3..4279d7b5 100644 --- a/src/api/common/signature/body.rs +++ b/src/api/common/signature/body.rs @@ -17,6 +17,7 @@ pub struct ReqBody { pub(crate) stream: Mutex<BoxStream<'static, Result<Frame<Bytes>, Error>>>, pub(crate) checksummer: Checksummer, pub(crate) expected_checksums: ExpectedChecksums, + pub(crate) trailer_algorithm: Option<ChecksumAlgorithm>, } pub type StreamingChecksumReceiver = task::JoinHandle<Result<Checksums, Error>>; @@ -74,6 +75,7 @@ impl ReqBody { stream, mut checksummer, mut expected_checksums, + trailer_algorithm, } = self; let (frame_tx, mut frame_rx) = mpsc::channel::<Frame<Bytes>>(1); @@ -91,18 +93,21 @@ impl ReqBody { } Err(frame) => { let trailers = frame.into_trailers().unwrap(); - if let Some(cv) = request_checksum_value(&trailers)? { - expected_checksums.extra = Some(cv); - } + let algo = trailer_algorithm.unwrap(); + expected_checksums.extra = Some(extract_checksum_value(&trailers, algo)?); break; } } } + if trailer_algorithm.is_some() && expected_checksums.extra.is_none() { + return Err(Error::bad_request("trailing checksum was not sent")); + } + let checksums = checksummer.finalize(); checksums.verify(&expected_checksums)?; - return Ok(checksums); + Ok(checksums) }); let stream: BoxStream<_> = stream.into_inner().unwrap(); diff --git a/src/api/common/signature/checksum.rs b/src/api/common/signature/checksum.rs index 890c0452..3c5e7c53 100644 --- a/src/api/common/signature/checksum.rs +++ b/src/api/common/signature/checksum.rs @@ -12,10 +12,10 @@ use http::{HeaderMap, HeaderName, HeaderValue}; use garage_util::data::*; -use garage_model::s3::object_table::{ChecksumAlgorithm, ChecksumValue}; - use super::*; +pub use garage_model::s3::object_table::{ChecksumAlgorithm, ChecksumValue}; + pub const CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5"); pub const X_AMZ_CHECKSUM_ALGORITHM: HeaderName = @@ -198,17 +198,23 @@ impl Checksums { // ---- +pub fn parse_checksum_algorithm(algo: &str) -> Result<ChecksumAlgorithm, Error> { + match algo { + "CRC32" => Ok(ChecksumAlgorithm::Crc32), + "CRC32C" => Ok(ChecksumAlgorithm::Crc32c), + "SHA1" => Ok(ChecksumAlgorithm::Sha1), + "SHA256" => Ok(ChecksumAlgorithm::Sha256), + _ => Err(Error::bad_request("invalid checksum algorithm")), + } +} + /// Extract the value of the x-amz-checksum-algorithm header pub fn request_checksum_algorithm( headers: &HeaderMap<HeaderValue>, ) -> Result<Option<ChecksumAlgorithm>, Error> { match headers.get(X_AMZ_CHECKSUM_ALGORITHM) { None => Ok(None), - Some(x) if x == "CRC32" => Ok(Some(ChecksumAlgorithm::Crc32)), - Some(x) if x == "CRC32C" => Ok(Some(ChecksumAlgorithm::Crc32c)), - Some(x) if x == "SHA1" => Ok(Some(ChecksumAlgorithm::Sha1)), - Some(x) if x == "SHA256" => Ok(Some(ChecksumAlgorithm::Sha256)), - _ => Err(Error::bad_request("invalid checksum algorithm")), + Some(x) => parse_checksum_algorithm(x.to_str()?).map(Some), } } @@ -231,37 +237,17 @@ pub fn request_checksum_value( ) -> Result<Option<ChecksumValue>, Error> { let mut ret = vec![]; - if let Some(crc32_str) = headers.get(X_AMZ_CHECKSUM_CRC32) { - let crc32 = BASE64_STANDARD - .decode(&crc32_str) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_bad_request("invalid x-amz-checksum-crc32 header")?; - ret.push(ChecksumValue::Crc32(crc32)) + if headers.contains_key(X_AMZ_CHECKSUM_CRC32) { + ret.push(extract_checksum_value(headers, ChecksumAlgorithm::Crc32)?); } - if let Some(crc32c_str) = headers.get(X_AMZ_CHECKSUM_CRC32C) { - let crc32c = BASE64_STANDARD - .decode(&crc32c_str) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_bad_request("invalid x-amz-checksum-crc32c header")?; - ret.push(ChecksumValue::Crc32c(crc32c)) + if headers.contains_key(X_AMZ_CHECKSUM_CRC32C) { + ret.push(extract_checksum_value(headers, ChecksumAlgorithm::Crc32c)?); } - if let Some(sha1_str) = headers.get(X_AMZ_CHECKSUM_SHA1) { - let sha1 = BASE64_STANDARD - .decode(&sha1_str) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_bad_request("invalid x-amz-checksum-sha1 header")?; - ret.push(ChecksumValue::Sha1(sha1)) + if headers.contains_key(X_AMZ_CHECKSUM_SHA1) { + ret.push(extract_checksum_value(headers, ChecksumAlgorithm::Sha1)?); } - if let Some(sha256_str) = headers.get(X_AMZ_CHECKSUM_SHA256) { - let sha256 = BASE64_STANDARD - .decode(&sha256_str) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_bad_request("invalid x-amz-checksum-sha256 header")?; - ret.push(ChecksumValue::Sha256(sha256)) + if headers.contains_key(X_AMZ_CHECKSUM_SHA256) { + ret.push(extract_checksum_value(headers, ChecksumAlgorithm::Sha256)?); } if ret.len() > 1 { @@ -274,44 +260,43 @@ pub fn request_checksum_value( /// Checks for the presence of x-amz-checksum-algorithm /// if so extract the corresponding x-amz-checksum-* value -pub fn request_checksum_algorithm_value( +pub fn extract_checksum_value( headers: &HeaderMap<HeaderValue>, -) -> Result<Option<ChecksumValue>, Error> { - match headers.get(X_AMZ_CHECKSUM_ALGORITHM) { - Some(x) if x == "CRC32" => { + algo: ChecksumAlgorithm, +) -> Result<ChecksumValue, Error> { + match algo { + ChecksumAlgorithm::Crc32 => { let crc32 = headers .get(X_AMZ_CHECKSUM_CRC32) .and_then(|x| BASE64_STANDARD.decode(&x).ok()) .and_then(|x| x.try_into().ok()) .ok_or_bad_request("invalid x-amz-checksum-crc32 header")?; - Ok(Some(ChecksumValue::Crc32(crc32))) + Ok(ChecksumValue::Crc32(crc32)) } - Some(x) if x == "CRC32C" => { + ChecksumAlgorithm::Crc32c => { let crc32c = headers .get(X_AMZ_CHECKSUM_CRC32C) .and_then(|x| BASE64_STANDARD.decode(&x).ok()) .and_then(|x| x.try_into().ok()) .ok_or_bad_request("invalid x-amz-checksum-crc32c header")?; - Ok(Some(ChecksumValue::Crc32c(crc32c))) + Ok(ChecksumValue::Crc32c(crc32c)) } - Some(x) if x == "SHA1" => { + ChecksumAlgorithm::Sha1 => { let sha1 = headers .get(X_AMZ_CHECKSUM_SHA1) .and_then(|x| BASE64_STANDARD.decode(&x).ok()) .and_then(|x| x.try_into().ok()) .ok_or_bad_request("invalid x-amz-checksum-sha1 header")?; - Ok(Some(ChecksumValue::Sha1(sha1))) + Ok(ChecksumValue::Sha1(sha1)) } - Some(x) if x == "SHA256" => { + ChecksumAlgorithm::Sha256 => { let sha256 = headers .get(X_AMZ_CHECKSUM_SHA256) .and_then(|x| BASE64_STANDARD.decode(&x).ok()) .and_then(|x| x.try_into().ok()) .ok_or_bad_request("invalid x-amz-checksum-sha256 header")?; - Ok(Some(ChecksumValue::Sha256(sha256))) + Ok(ChecksumValue::Sha256(sha256)) } - Some(_) => Err(Error::bad_request("invalid x-amz-checksum-algorithm")), - None => Ok(None), } } diff --git a/src/api/common/signature/streaming.rs b/src/api/common/signature/streaming.rs index 75f3bf80..70b6e004 100644 --- a/src/api/common/signature/streaming.rs +++ b/src/api/common/signature/streaming.rs @@ -5,7 +5,7 @@ use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; use futures::prelude::*; use futures::task; use hmac::Mac; -use http::header::{HeaderValue, CONTENT_ENCODING}; +use http::header::{HeaderMap, HeaderValue, CONTENT_ENCODING}; use hyper::body::{Bytes, Frame, Incoming as IncomingBody}; use hyper::Request; @@ -64,10 +64,16 @@ pub fn parse_streaming_body( } // If trailer header is announced, add the calculation of the requested checksum - if trailer { - let algo = request_trailer_checksum_algorithm(req.headers())?; + let trailer_algorithm = if trailer { + let algo = Some( + request_trailer_checksum_algorithm(req.headers())? + .ok_or_bad_request("Missing x-amz-trailer header")?, + ); checksummer = checksummer.add(algo); - } + algo + } else { + None + }; // For signed variants, determine signing parameters let sign_params = if signed { @@ -123,6 +129,7 @@ pub fn parse_streaming_body( stream: Mutex::new(signed_payload_stream.boxed()), checksummer, expected_checksums, + trailer_algorithm, } })) } @@ -132,6 +139,7 @@ pub fn parse_streaming_body( stream: Mutex::new(stream.boxed()), checksummer, expected_checksums, + trailer_algorithm: None, } })), } @@ -185,6 +193,8 @@ fn compute_streaming_trailer_signature( } mod payload { + use http::{HeaderName, HeaderValue}; + use garage_util::data::Hash; use nom::bytes::streaming::{tag, take_while}; @@ -252,19 +262,21 @@ mod payload { #[derive(Debug, Clone)] pub struct TrailerChunk { - pub header_name: Vec<u8>, - pub header_value: Vec<u8>, + pub header_name: HeaderName, + pub header_value: HeaderValue, pub signature: Option<Hash>, } impl TrailerChunk { fn parse_content(input: &[u8]) -> nom::IResult<&[u8], Self, Error<&[u8]>> { - let (input, header_name) = try_parse!(take_while( - |c: u8| c.is_ascii_alphanumeric() || c == b'-' + let (input, header_name) = try_parse!(map_res( + take_while(|c: u8| c.is_ascii_alphanumeric() || c == b'-'), + HeaderName::from_bytes )(input)); let (input, _) = try_parse!(tag(b":")(input)); - let (input, header_value) = try_parse!(take_while( - |c: u8| c.is_ascii_alphanumeric() || b"+/=".contains(&c) + let (input, header_value) = try_parse!(map_res( + take_while(|c: u8| c.is_ascii_alphanumeric() || b"+/=".contains(&c)), + HeaderValue::from_bytes )(input)); // Possible '\n' after the header value, depends on clients @@ -276,8 +288,8 @@ mod payload { Ok(( input, TrailerChunk { - header_name: header_name.to_vec(), - header_value: header_value.to_vec(), + header_name, + header_value, signature: None, }, )) @@ -371,6 +383,7 @@ where buf: bytes::BytesMut, signing: Option<SignParams>, has_trailer: bool, + done: bool, } impl<S> StreamingPayloadStream<S> @@ -383,6 +396,7 @@ where buf: bytes::BytesMut::new(), signing, has_trailer, + done: false, } } @@ -448,6 +462,10 @@ where let mut this = self.project(); + if *this.done { + return Poll::Ready(None); + } + loop { let (input, payload) = match Self::parse_next(this.buf, this.signing.is_some(), *this.has_trailer) { @@ -499,17 +517,23 @@ where if data.is_empty() { // if there was a trailer, it would have been returned by the parser assert!(!*this.has_trailer); + *this.done = true; return Poll::Ready(None); } return Poll::Ready(Some(Ok(Frame::data(data)))); } StreamingPayloadChunk::Trailer(trailer) => { + trace!( + "In StreamingPayloadStream::poll_next: got trailer {:?}", + trailer + ); + if let Some(signing) = this.signing.as_mut() { let data = [ - &trailer.header_name[..], + trailer.header_name.as_ref(), &b":"[..], - &trailer.header_value[..], + trailer.header_value.as_ref(), &b"\n"[..], ] .concat(); @@ -529,10 +553,12 @@ where } *this.buf = input.into(); + *this.done = true; - // TODO: handle trailer + let mut trailers_map = HeaderMap::new(); + trailers_map.insert(trailer.header_name, trailer.header_value); - return Poll::Ready(None); + return Poll::Ready(Some(Ok(Frame::trailers(trailers_map)))); } } } |