diff options
Diffstat (limited to 'src/api/s3/get.rs')
-rw-r--r-- | src/api/s3/get.rs | 168 |
1 files changed, 114 insertions, 54 deletions
diff --git a/src/api/s3/get.rs b/src/api/s3/get.rs index 5e682726..53f0a345 100644 --- a/src/api/s3/get.rs +++ b/src/api/s3/get.rs @@ -1,17 +1,20 @@ //! Function related to GET and HEAD requests +use std::convert::TryInto; use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; use futures::future; use futures::stream::{self, StreamExt}; use http::header::{ - ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, IF_MODIFIED_SINCE, - IF_NONE_MATCH, LAST_MODIFIED, RANGE, + ACCEPT_RANGES, CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, + CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, EXPIRES, IF_MODIFIED_SINCE, IF_NONE_MATCH, + LAST_MODIFIED, RANGE, }; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{body::Body, Request, Response, StatusCode}; use tokio::sync::mpsc; -use garage_rpc::rpc_helper::{netapp::stream::ByteStream, OrderTag}; +use garage_block::manager::BlockStream; +use garage_rpc::rpc_helper::OrderTag; use garage_table::EmptyKey; use garage_util::data::*; use garage_util::error::OkOrMessage; @@ -20,10 +23,22 @@ use garage_model::garage::Garage; use garage_model::s3::object_table::*; use garage_model::s3::version_table::*; +use crate::helpers::*; +use crate::s3::api_server::ResBody; use crate::s3::error::*; const X_AMZ_MP_PARTS_COUNT: &str = "x-amz-mp-parts-count"; +#[derive(Default)] +pub struct GetObjectOverrides { + pub(crate) response_cache_control: Option<String>, + pub(crate) response_content_disposition: Option<String>, + pub(crate) response_content_encoding: Option<String>, + pub(crate) response_content_language: Option<String>, + pub(crate) response_content_type: Option<String>, + pub(crate) response_expires: Option<String>, +} + fn object_headers( version: &ObjectVersion, version_meta: &ObjectVersionMeta, @@ -49,11 +64,37 @@ fn object_headers( resp } +/// Override headers according to specific query parameters, see +/// section "Overriding response header values through the request" in +/// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html +fn getobject_override_headers( + overrides: GetObjectOverrides, + resp: &mut http::response::Builder, +) -> Result<(), Error> { + // TODO: this only applies for signed requests, so when we support + // anonymous access in the future we will have to do a permission check here + let overrides = [ + (CACHE_CONTROL, overrides.response_cache_control), + (CONTENT_DISPOSITION, overrides.response_content_disposition), + (CONTENT_ENCODING, overrides.response_content_encoding), + (CONTENT_LANGUAGE, overrides.response_content_language), + (CONTENT_TYPE, overrides.response_content_type), + (EXPIRES, overrides.response_expires), + ]; + for (hdr, val_opt) in overrides { + if let Some(val) = val_opt { + let val = val.try_into().ok_or_bad_request("invalid header value")?; + resp.headers_mut().unwrap().insert(hdr, val); + } + } + Ok(()) +} + fn try_answer_cached( version: &ObjectVersion, version_meta: &ObjectVersionMeta, - req: &Request<Body>, -) -> Option<Response<Body>> { + req: &Request<impl Body>, +) -> Option<Response<ResBody>> { // <trinity> It is possible, and is even usually the case, [that both If-None-Match and // If-Modified-Since] are present in a request. In this situation If-None-Match takes // precedence and If-Modified-Since is ignored (as per 6.Precedence from rfc7232). The rational @@ -80,7 +121,7 @@ fn try_answer_cached( Some( Response::builder() .status(StatusCode::NOT_MODIFIED) - .body(Body::empty()) + .body(empty_body()) .unwrap(), ) } else { @@ -91,11 +132,11 @@ fn try_answer_cached( /// Handle HEAD request pub async fn handle_head( garage: Arc<Garage>, - req: &Request<Body>, + req: &Request<impl Body>, bucket_id: Uuid, key: &str, part_number: Option<u64>, -) -> Result<Response<Body>, Error> { +) -> Result<Response<ResBody>, Error> { let object = garage .object_table .get(&bucket_id, &key.to_string()) @@ -138,7 +179,7 @@ pub async fn handle_head( ) .header(X_AMZ_MP_PARTS_COUNT, "1") .status(StatusCode::PARTIAL_CONTENT) - .body(Body::empty())?) + .body(empty_body())?) } ObjectVersionData::FirstBlock(_, _) => { let version = garage @@ -163,7 +204,7 @@ pub async fn handle_head( ) .header(X_AMZ_MP_PARTS_COUNT, format!("{}", version.n_parts()?)) .status(StatusCode::PARTIAL_CONTENT) - .body(Body::empty())?) + .body(empty_body())?) } _ => unreachable!(), } @@ -171,18 +212,19 @@ pub async fn handle_head( Ok(object_headers(object_version, version_meta) .header(CONTENT_LENGTH, format!("{}", version_meta.size)) .status(StatusCode::OK) - .body(Body::empty())?) + .body(empty_body())?) } } /// Handle GET request pub async fn handle_get( garage: Arc<Garage>, - req: &Request<Body>, + req: &Request<impl Body>, bucket_id: Uuid, key: &str, part_number: Option<u64>, -) -> Result<Response<Body>, Error> { + overrides: GetObjectOverrides, +) -> Result<Response<ResBody>, Error> { let object = garage .object_table .get(&bucket_id, &key.to_string()) @@ -233,18 +275,18 @@ pub async fn handle_get( (None, None) => (), } - let resp_builder = object_headers(last_v, last_v_meta) + let mut resp_builder = object_headers(last_v, last_v_meta) .header(CONTENT_LENGTH, format!("{}", last_v_meta.size)) .status(StatusCode::OK); + getobject_override_headers(overrides, &mut resp_builder)?; match &last_v_data { ObjectVersionData::DeleteMarker => unreachable!(), ObjectVersionData::Inline(_, bytes) => { - let body: Body = Body::from(bytes.to_vec()); - Ok(resp_builder.body(body)?) + Ok(resp_builder.body(bytes_body(bytes.to_vec().into()))?) } ObjectVersionData::FirstBlock(_, first_block_hash) => { - let (tx, rx) = mpsc::channel(2); + let (tx, rx) = mpsc::channel::<BlockStream>(2); let order_stream = OrderTag::stream(); let first_block_hash = *first_block_hash; @@ -282,20 +324,12 @@ pub async fn handle_get( { Ok(()) => (), Err(e) => { - let err = std::io::Error::new( - std::io::ErrorKind::Other, - format!("Error while getting object data: {}", e), - ); - let _ = tx - .send(Box::pin(stream::once(future::ready(Err(err))))) - .await; + let _ = tx.send(error_stream_item(e)).await; } } }); - let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx).flatten(); - - let body = hyper::body::Body::wrap_stream(body_stream); + let body = response_body_from_block_stream(rx); Ok(resp_builder.body(body)?) } } @@ -308,7 +342,10 @@ async fn handle_get_range( version_meta: &ObjectVersionMeta, begin: u64, end: u64, -) -> Result<Response<Body>, Error> { +) -> Result<Response<ResBody>, Error> { + // Here we do not use getobject_override_headers because we don't + // want to add any overridden headers (those should not be added + // when returning PARTIAL_CONTENT) let resp_builder = object_headers(version, version_meta) .header(CONTENT_LENGTH, format!("{}", end - begin)) .header( @@ -321,7 +358,7 @@ async fn handle_get_range( ObjectVersionData::DeleteMarker => unreachable!(), ObjectVersionData::Inline(_meta, bytes) => { if end as usize <= bytes.len() { - let body: Body = Body::from(bytes[begin as usize..end as usize].to_vec()); + let body = bytes_body(bytes[begin as usize..end as usize].to_vec().into()); Ok(resp_builder.body(body)?) } else { Err(Error::internal_error( @@ -348,7 +385,8 @@ async fn handle_get_part( version_data: &ObjectVersionData, version_meta: &ObjectVersionMeta, part_number: u64, -) -> Result<Response<Body>, Error> { +) -> Result<Response<ResBody>, Error> { + // Same as for get_range, no getobject_override_headers let resp_builder = object_headers(object_version, version_meta).status(StatusCode::PARTIAL_CONTENT); @@ -364,7 +402,7 @@ async fn handle_get_part( format!("bytes {}-{}/{}", 0, bytes.len() - 1, bytes.len()), ) .header(X_AMZ_MP_PARTS_COUNT, "1") - .body(Body::from(bytes.to_vec()))?) + .body(bytes_body(bytes.to_vec().into()))?) } ObjectVersionData::FirstBlock(_, _) => { let version = garage @@ -392,7 +430,7 @@ async fn handle_get_part( } fn parse_range_header( - req: &Request<Body>, + req: &Request<impl Body>, total_size: u64, ) -> Result<Option<http_range::HttpRange>, Error> { let range = match req.headers().get(RANGE) { @@ -434,7 +472,7 @@ fn body_from_blocks_range( all_blocks: &[(VersionBlockKey, VersionBlock)], begin: u64, end: u64, -) -> Body { +) -> ResBody { // We will store here the list of blocks that have an intersection with the requested // range, as well as their "true offset", which is their actual offset in the complete // file (whereas block.offset designates the offset of the block WITHIN THE PART @@ -456,17 +494,17 @@ fn body_from_blocks_range( } let order_stream = OrderTag::stream(); - let body_stream = futures::stream::iter(blocks) - .enumerate() - .map(move |(i, (block, block_offset))| { + let (tx, rx) = mpsc::channel::<BlockStream>(2); + + tokio::spawn(async move { + match async { let garage = garage.clone(); - async move { - garage + for (i, (block, block_offset)) in blocks.iter().enumerate() { + let block_stream = garage .block_manager .rpc_get_block_streaming(&block.hash, Some(order_stream.order(i as u64))) - .await - .unwrap_or_else(|e| error_stream(i, e)) - .scan(block_offset, move |chunk_offset, chunk| { + .await? + .scan(*block_offset, move |chunk_offset, chunk| { let r = match chunk { Ok(chunk_bytes) => { let chunk_len = chunk_bytes.len() as u64; @@ -502,20 +540,42 @@ fn body_from_blocks_range( }; futures::future::ready(r) }) - .filter_map(futures::future::ready) + .filter_map(futures::future::ready); + + let block_stream: BlockStream = Box::pin(block_stream); + tx.send(Box::pin(block_stream)) + .await + .ok_or_message("channel closed")?; } - }) - .buffered(2) - .flatten(); - hyper::body::Body::wrap_stream(body_stream) + Ok::<(), Error>(()) + } + .await + { + Ok(()) => (), + Err(e) => { + let _ = tx.send(error_stream_item(e)).await; + } + } + }); + + response_body_from_block_stream(rx) +} + +fn response_body_from_block_stream(rx: mpsc::Receiver<BlockStream>) -> ResBody { + let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx) + .flatten() + .map(|x| { + x.map(hyper::body::Frame::data) + .map_err(|e| Error::from(garage_util::error::Error::from(e))) + }); + ResBody::new(http_body_util::StreamBody::new(body_stream)) } -fn error_stream(i: usize, e: garage_util::error::Error) -> ByteStream { - Box::pin(futures::stream::once(async move { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!("Could not get block {}: {}", i, e), - )) - })) +fn error_stream_item<E: std::fmt::Display>(e: E) -> BlockStream { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("Error while getting object data: {}", e), + ); + Box::pin(stream::once(future::ready(Err(err)))) } |