diff options
author | Alex Auvolat <alex@adnab.me> | 2023-04-25 12:34:26 +0200 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2023-04-25 12:34:26 +0200 |
commit | fa78d806e3ae40031e80eebb86e4eb1756d7baea (patch) | |
tree | 144662fb430c484093f6f9a585a2441c2ff26494 /src/api | |
parent | 654999e254e6c1f46bb5d668bc1230f226575716 (diff) | |
parent | a16eb7e4b8344d2f58c09a249b7b1bd17d339a35 (diff) | |
download | garage-fa78d806e3ae40031e80eebb86e4eb1756d7baea.tar.gz garage-fa78d806e3ae40031e80eebb86e4eb1756d7baea.zip |
Merge branch 'main' into next
Diffstat (limited to 'src/api')
-rw-r--r-- | src/api/Cargo.toml | 26 | ||||
-rw-r--r-- | src/api/admin/api_server.rs | 56 | ||||
-rw-r--r-- | src/api/admin/bucket.rs | 4 | ||||
-rw-r--r-- | src/api/admin/cluster.rs | 2 | ||||
-rw-r--r-- | src/api/admin/router.rs | 3 | ||||
-rw-r--r-- | src/api/generic_server.rs | 16 | ||||
-rw-r--r-- | src/api/k2v/api_server.rs | 3 | ||||
-rw-r--r-- | src/api/k2v/batch.rs | 96 | ||||
-rw-r--r-- | src/api/k2v/error.rs | 2 | ||||
-rw-r--r-- | src/api/k2v/index.rs | 9 | ||||
-rw-r--r-- | src/api/k2v/item.rs | 19 | ||||
-rw-r--r-- | src/api/k2v/router.rs | 8 | ||||
-rw-r--r-- | src/api/s3/bucket.rs | 2 | ||||
-rw-r--r-- | src/api/s3/error.rs | 2 | ||||
-rw-r--r-- | src/api/s3/list.rs | 11 | ||||
-rw-r--r-- | src/api/s3/post_object.rs | 5 | ||||
-rw-r--r-- | src/api/s3/put.rs | 96 | ||||
-rw-r--r-- | src/api/signature/error.rs | 2 |
18 files changed, 268 insertions, 94 deletions
diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index dba0bbef..9babec02 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "garage_api" -version = "0.8.1" +version = "0.8.2" authors = ["Alex Auvolat <alex@adnab.me>"] edition = "2018" license = "AGPL-3.0" @@ -14,35 +14,35 @@ path = "lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -garage_model = { version = "0.8.1", path = "../model" } -garage_table = { version = "0.8.1", path = "../table" } -garage_block = { version = "0.8.1", path = "../block" } -garage_util = { version = "0.8.1", path = "../util" } -garage_rpc = { version = "0.8.1", path = "../rpc" } +garage_model = { version = "0.8.2", path = "../model" } +garage_table = { version = "0.8.2", path = "../table" } +garage_block = { version = "0.8.2", path = "../block" } +garage_util = { version = "0.8.2", path = "../util" } +garage_rpc = { version = "0.8.2", path = "../rpc" } async-trait = "0.1.7" -base64 = "0.13" +base64 = "0.21" bytes = "1.0" chrono = "0.4" crypto-common = "0.1" err-derive = "0.3" hex = "0.4" hmac = "0.12" -idna = "0.2" -tracing = "0.1.30" +idna = "0.3" +tracing = "0.1" md-5 = "0.10" nom = "7.1" sha2 = "0.10" futures = "0.3" futures-util = "0.3" -pin-project = "1.0.11" +pin-project = "1.0.12" tokio = { version = "1.0", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "time", "macros", "sync", "signal", "fs"] } tokio-stream = "0.1" form_urlencoded = "1.0.0" http = "0.2" -httpdate = "0.3" +httpdate = "1.0" http-range = "0.1" hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] } multer = "2.0" @@ -51,8 +51,8 @@ roxmltree = "0.14" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" serde_json = "1.0" -quick-xml = { version = "0.21", features = [ "serialize" ] } -url = "2.1" +quick-xml = { version = "0.26", features = [ "serialize" ] } +url = "2.3" opentelemetry = "0.17" opentelemetry-prometheus = { version = "0.10", optional = true } diff --git a/src/api/admin/api_server.rs b/src/api/admin/api_server.rs index 2d325fb1..58dd38d8 100644 --- a/src/api/admin/api_server.rs +++ b/src/api/admin/api_server.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -77,6 +78,60 @@ impl AdminApiServer { .body(Body::empty())?) } + async fn handle_check_website_enabled( + &self, + req: Request<Body>, + ) -> Result<Response<Body>, Error> { + let query_params: HashMap<String, String> = req + .uri() + .query() + .map(|v| { + url::form_urlencoded::parse(v.as_bytes()) + .into_owned() + .collect() + }) + .unwrap_or_else(HashMap::new); + + let has_domain_key = query_params.contains_key("domain"); + + if !has_domain_key { + return Err(Error::bad_request("No domain query string found")); + } + + let domain = query_params + .get("domain") + .ok_or_internal_error("Could not parse domain query string")?; + + let bucket_id = self + .garage + .bucket_helper() + .resolve_global_bucket_name(&domain) + .await? + .ok_or(HelperError::NoSuchBucket(domain.to_string()))?; + + let bucket = self + .garage + .bucket_helper() + .get_existing_bucket(bucket_id) + .await?; + + let bucket_state = bucket.state.as_option().unwrap(); + let bucket_website_config = bucket_state.website_config.get(); + + match bucket_website_config { + Some(_v) => { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(format!( + "Bucket '{domain}' is authorized for website hosting" + )))?) + } + None => Err(Error::bad_request(format!( + "Bucket '{domain}' is not authorized for website hosting" + ))), + } + } + fn handle_health(&self) -> Result<Response<Body>, Error> { let health = self.garage.system.health(); @@ -174,6 +229,7 @@ impl ApiHandler for AdminApiServer { match endpoint { Endpoint::Options => self.handle_options(&req), + Endpoint::CheckWebsiteEnabled => self.handle_check_website_enabled(req).await, Endpoint::Health => self.handle_health(), Endpoint::Metrics => self.handle_metrics(), Endpoint::GetClusterStatus => handle_get_cluster_status(&self.garage).await, diff --git a/src/api/admin/bucket.rs b/src/api/admin/bucket.rs index 65034852..e60f07ca 100644 --- a/src/api/admin/bucket.rs +++ b/src/api/admin/bucket.rs @@ -167,7 +167,7 @@ async fn bucket_info_results( let quotas = state.quotas.get(); let res = GetBucketInfoResult { - id: hex::encode(&bucket.id), + id: hex::encode(bucket.id), global_aliases: state .aliases .items() @@ -575,6 +575,6 @@ pub async fn handle_local_unalias_bucket( // ---- HELPER ---- fn parse_bucket_id(id: &str) -> Result<Uuid, Error> { - let id_hex = hex::decode(&id).ok_or_bad_request("Invalid bucket id")?; + let id_hex = hex::decode(id).ok_or_bad_request("Invalid bucket id")?; Ok(Uuid::try_from(&id_hex).ok_or_bad_request("Invalid bucket id")?) } diff --git a/src/api/admin/cluster.rs b/src/api/admin/cluster.rs index 540c6009..b2508d2e 100644 --- a/src/api/admin/cluster.rs +++ b/src/api/admin/cluster.rs @@ -20,6 +20,7 @@ pub async fn handle_get_cluster_status(garage: &Arc<Garage>) -> Result<Response< node: hex::encode(garage.system.id), garage_version: garage_util::version::garage_version(), garage_features: garage_util::version::garage_features(), + rust_version: garage_util::version::rust_version(), db_engine: garage.db.engine(), known_nodes: garage .system @@ -106,6 +107,7 @@ struct GetClusterStatusResponse { node: String, garage_version: &'static str, garage_features: Option<&'static [&'static str]>, + rust_version: &'static str, db_engine: String, known_nodes: HashMap<String, KnownNodeResp>, layout: GetClusterLayoutResponse, diff --git a/src/api/admin/router.rs b/src/api/admin/router.rs index 62e6abc3..0dcb1546 100644 --- a/src/api/admin/router.rs +++ b/src/api/admin/router.rs @@ -17,6 +17,7 @@ router_match! {@func #[derive(Debug, Clone, PartialEq, Eq)] pub enum Endpoint { Options, + CheckWebsiteEnabled, Health, Metrics, GetClusterStatus, @@ -91,6 +92,7 @@ impl Endpoint { let res = router_match!(@gen_path_parser (req.method(), path, query) [ OPTIONS _ => Options, + GET "/check" => CheckWebsiteEnabled, GET "/health" => Health, GET "/metrics" => Metrics, GET "/v0/status" => GetClusterStatus, @@ -136,6 +138,7 @@ impl Endpoint { pub fn authorization_type(&self) -> Authorization { match self { Self::Health => Authorization::None, + Self::CheckWebsiteEnabled => Authorization::None, Self::Metrics => Authorization::MetricsToken, _ => Authorization::AdminToken, } diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index 62fe4e5a..d0354d28 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -19,6 +19,7 @@ use opentelemetry::{ }; use garage_util::error::Error as GarageError; +use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; pub(crate) trait ApiEndpoint: Send + Sync + 'static { @@ -125,7 +126,20 @@ impl<A: ApiHandler> ApiServer<A> { addr: SocketAddr, ) -> Result<Response<Body>, GarageError> { let uri = req.uri().clone(); - info!("{} {} {}", addr, req.method(), uri); + + if let Ok(forwarded_for_ip_addr) = + forwarded_headers::handle_forwarded_for_headers(&req.headers()) + { + info!( + "{} (via {}) {} {}", + forwarded_for_ip_addr, + addr, + req.method(), + uri + ); + } else { + info!("{} {} {}", addr, req.method(), uri); + } debug!("{:?}", req); let tracer = opentelemetry::global::tracer("garage"); diff --git a/src/api/k2v/api_server.rs b/src/api/k2v/api_server.rs index 084867b5..bb85b2e7 100644 --- a/src/api/k2v/api_server.rs +++ b/src/api/k2v/api_server.rs @@ -164,6 +164,9 @@ impl ApiHandler for K2VApiServer { Endpoint::InsertBatch {} => handle_insert_batch(garage, bucket_id, req).await, Endpoint::ReadBatch {} => handle_read_batch(garage, bucket_id, req).await, Endpoint::DeleteBatch {} => handle_delete_batch(garage, bucket_id, req).await, + Endpoint::PollRange { partition_key } => { + handle_poll_range(garage, bucket_id, &partition_key, req).await + } Endpoint::Options => unreachable!(), }; diff --git a/src/api/k2v/batch.rs b/src/api/k2v/batch.rs index 78035362..26d678da 100644 --- a/src/api/k2v/batch.rs +++ b/src/api/k2v/batch.rs @@ -1,10 +1,10 @@ use std::sync::Arc; +use base64::prelude::*; use hyper::{Body, Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use garage_util::data::*; -use garage_util::error::Error as GarageError; use garage_table::{EnumerationOrder, TableSchema}; @@ -25,15 +25,13 @@ pub async fn handle_insert_batch( let mut items2 = vec![]; for it in items { - let ct = it - .ct - .map(|s| CausalContext::parse(&s)) - .transpose() - .ok_or_bad_request("Invalid causality token")?; + let ct = it.ct.map(|s| CausalContext::parse_helper(&s)).transpose()?; let v = match it.v { - Some(vs) => { - DvvsValue::Value(base64::decode(vs).ok_or_bad_request("Invalid base64 value")?) - } + Some(vs) => DvvsValue::Value( + BASE64_STANDARD + .decode(vs) + .ok_or_bad_request("Invalid base64 value")?, + ), None => DvvsValue::Deleted, }; items2.push((it.pk, it.sk, ct, v)); @@ -65,10 +63,7 @@ pub async fn handle_read_batch( resps.push(resp?); } - let resp_json = serde_json::to_string_pretty(&resps).map_err(GarageError::from)?; - Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::from(resp_json))?) + Ok(json_ok_response(&resps)?) } async fn handle_read_batch_query( @@ -160,10 +155,7 @@ pub async fn handle_delete_batch( resps.push(resp?); } - let resp_json = serde_json::to_string_pretty(&resps).map_err(GarageError::from)?; - Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::from(resp_json))?) + Ok(json_ok_response(&resps)?) } async fn handle_delete_batch_query( @@ -257,6 +249,53 @@ async fn handle_delete_batch_query( }) } +pub(crate) async fn handle_poll_range( + garage: Arc<Garage>, + bucket_id: Uuid, + partition_key: &str, + req: Request<Body>, +) -> Result<Response<Body>, Error> { + use garage_model::k2v::sub::PollRange; + + let query = parse_json_body::<PollRangeQuery>(req).await?; + + let timeout_msec = query.timeout.unwrap_or(300).clamp(1, 600) * 1000; + + let resp = garage + .k2v + .rpc + .poll_range( + PollRange { + partition: K2VItemPartition { + bucket_id, + partition_key: partition_key.to_string(), + }, + start: query.start, + end: query.end, + prefix: query.prefix, + }, + query.seen_marker, + timeout_msec, + ) + .await?; + + if let Some((items, seen_marker)) = resp { + let resp = PollRangeResponse { + items: items + .into_iter() + .map(|(_k, i)| ReadBatchResponseItem::from(i)) + .collect::<Vec<_>>(), + seen_marker, + }; + + Ok(json_ok_response(&resp)?) + } else { + Ok(Response::builder() + .status(StatusCode::NOT_MODIFIED) + .body(Body::empty())?) + } +} + #[derive(Deserialize)] struct InsertBatchItem { pk: String, @@ -322,7 +361,7 @@ impl ReadBatchResponseItem { .values() .iter() .map(|v| match v { - DvvsValue::Value(x) => Some(base64::encode(x)), + DvvsValue::Value(x) => Some(BASE64_STANDARD.encode(x)), DvvsValue::Deleted => None, }) .collect::<Vec<_>>(); @@ -361,3 +400,24 @@ struct DeleteBatchResponse { #[serde(rename = "deletedItems")] deleted_items: usize, } + +#[derive(Deserialize)] +struct PollRangeQuery { + #[serde(default)] + prefix: Option<String>, + #[serde(default)] + start: Option<String>, + #[serde(default)] + end: Option<String>, + #[serde(default)] + timeout: Option<u64>, + #[serde(default, rename = "seenMarker")] + seen_marker: Option<String>, +} + +#[derive(Serialize)] +struct PollRangeResponse { + items: Vec<ReadBatchResponseItem>, + #[serde(rename = "seenMarker")] + seen_marker: String, +} diff --git a/src/api/k2v/error.rs b/src/api/k2v/error.rs index 42491466..4eb017ab 100644 --- a/src/api/k2v/error.rs +++ b/src/api/k2v/error.rs @@ -19,7 +19,7 @@ pub enum Error { // Category: cannot process /// Authorization Header Malformed - #[error(display = "Authorization header malformed, expected scope: {}", _0)] + #[error(display = "Authorization header malformed, unexpected scope: {}", _0)] AuthorizationHeaderMalformed(String), /// The object requested don't exists diff --git a/src/api/k2v/index.rs b/src/api/k2v/index.rs index 210950bf..6c1d4a91 100644 --- a/src/api/k2v/index.rs +++ b/src/api/k2v/index.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use hyper::{Body, Response, StatusCode}; +use hyper::{Body, Response}; use serde::Serialize; use garage_util::data::*; -use garage_util::error::Error as GarageError; use garage_rpc::ring::Ring; use garage_table::util::*; @@ -12,6 +11,7 @@ use garage_table::util::*; use garage_model::garage::Garage; use garage_model::k2v::item_table::{BYTES, CONFLICTS, ENTRIES, VALUES}; +use crate::helpers::*; use crate::k2v::error::*; use crate::k2v::range::read_range; @@ -68,10 +68,7 @@ pub async fn handle_read_index( next_start, }; - let resp_json = serde_json::to_string_pretty(&resp).map_err(GarageError::from)?; - Ok(Response::builder() - .status(StatusCode::OK) - .body(Body::from(resp_json))?) + Ok(json_ok_response(&resp)?) } #[derive(Serialize)] diff --git a/src/api/k2v/item.rs b/src/api/k2v/item.rs index f85138c7..e13a0f30 100644 --- a/src/api/k2v/item.rs +++ b/src/api/k2v/item.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use base64::prelude::*; use http::header; use hyper::{Body, Request, Response, StatusCode}; @@ -81,7 +82,7 @@ impl ReturnFormat { .iter() .map(|v| match v { DvvsValue::Deleted => serde_json::Value::Null, - DvvsValue::Value(v) => serde_json::Value::String(base64::encode(v)), + DvvsValue::Value(v) => serde_json::Value::String(BASE64_STANDARD.encode(v)), }) .collect::<Vec<_>>(); let json_body = @@ -133,9 +134,8 @@ pub async fn handle_insert_item( .get(X_GARAGE_CAUSALITY_TOKEN) .map(|s| s.to_str()) .transpose()? - .map(CausalContext::parse) - .transpose() - .ok_or_bad_request("Invalid causality token")?; + .map(CausalContext::parse_helper) + .transpose()?; let body = hyper::body::to_bytes(req.into_body()).await?; let value = DvvsValue::Value(body.to_vec()); @@ -169,9 +169,8 @@ pub async fn handle_delete_item( .get(X_GARAGE_CAUSALITY_TOKEN) .map(|s| s.to_str()) .transpose()? - .map(CausalContext::parse) - .transpose() - .ok_or_bad_request("Invalid causality token")?; + .map(CausalContext::parse_helper) + .transpose()?; let value = DvvsValue::Deleted; @@ -208,15 +207,17 @@ pub async fn handle_poll_item( let causal_context = CausalContext::parse(&causality_token).ok_or_bad_request("Invalid causality token")?; + let timeout_msec = timeout_secs.unwrap_or(300).clamp(1, 600) * 1000; + let item = garage .k2v .rpc - .poll( + .poll_item( bucket_id, partition_key, sort_key, causal_context, - timeout_secs.unwrap_or(300) * 1000, + timeout_msec, ) .await?; diff --git a/src/api/k2v/router.rs b/src/api/k2v/router.rs index e7a3dd69..1cc58be5 100644 --- a/src/api/k2v/router.rs +++ b/src/api/k2v/router.rs @@ -32,6 +32,9 @@ pub enum Endpoint { causality_token: String, timeout: Option<u64>, }, + PollRange { + partition_key: String, + }, ReadBatch { }, ReadIndex { @@ -113,6 +116,7 @@ impl Endpoint { @gen_parser (query.keyword.take().unwrap_or_default(), partition_key, query, None), key: [ + POLL_RANGE => PollRange, ], no_key: [ EMPTY => ReadBatch, @@ -142,6 +146,7 @@ impl Endpoint { @gen_parser (query.keyword.take().unwrap_or_default(), partition_key, query, None), key: [ + POLL_RANGE => PollRange, ], no_key: [ EMPTY => InsertBatch, @@ -234,7 +239,8 @@ impl Endpoint { generateQueryParameters! { keywords: [ "delete" => DELETE, - "search" => SEARCH + "search" => SEARCH, + "poll_range" => POLL_RANGE ], fields: [ "prefix" => prefix, diff --git a/src/api/s3/bucket.rs b/src/api/s3/bucket.rs index 8471385f..733981e1 100644 --- a/src/api/s3/bucket.rs +++ b/src/api/s3/bucket.rs @@ -305,7 +305,7 @@ fn parse_create_bucket_xml(xml_bytes: &[u8]) -> Option<Option<String>> { let mut ret = None; for item in cbc.children() { if item.has_tag_name("LocationConstraint") { - if ret != None { + if ret.is_some() { return None; } ret = Some(item.text()?.to_string()); diff --git a/src/api/s3/error.rs b/src/api/s3/error.rs index 67009d63..c50cff9f 100644 --- a/src/api/s3/error.rs +++ b/src/api/s3/error.rs @@ -21,7 +21,7 @@ pub enum Error { // Category: cannot process /// Authorization Header Malformed - #[error(display = "Authorization header malformed, expected scope: {}", _0)] + #[error(display = "Authorization header malformed, unexpected scope: {}", _0)] AuthorizationHeaderMalformed(String), /// The object requested don't exists diff --git a/src/api/s3/list.rs b/src/api/s3/list.rs index e5f486c8..5cb0d65a 100644 --- a/src/api/s3/list.rs +++ b/src/api/s3/list.rs @@ -3,6 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::iter::{Iterator, Peekable}; use std::sync::Arc; +use base64::prelude::*; use hyper::{Body, Response}; use garage_util::data::*; @@ -129,11 +130,11 @@ pub async fn handle_list( next_continuation_token: match (query.is_v2, &pagination) { (true, Some(RangeBegin::AfterKey { key })) => Some(s3_xml::Value(format!( "]{}", - base64::encode(key.as_bytes()) + BASE64_STANDARD.encode(key.as_bytes()) ))), (true, Some(RangeBegin::IncludingKey { key, .. })) => Some(s3_xml::Value(format!( "[{}", - base64::encode(key.as_bytes()) + BASE64_STANDARD.encode(key.as_bytes()) ))), _ => None, }, @@ -583,14 +584,16 @@ impl ListObjectsQuery { (Some(token), _) => match &token[..1] { "[" => Ok(RangeBegin::IncludingKey { key: String::from_utf8( - base64::decode(token[1..].as_bytes()) + BASE64_STANDARD + .decode(token[1..].as_bytes()) .ok_or_bad_request("Invalid continuation token")?, )?, fallback_key: None, }), "]" => Ok(RangeBegin::AfterKey { key: String::from_utf8( - base64::decode(token[1..].as_bytes()) + BASE64_STANDARD + .decode(token[1..].as_bytes()) .ok_or_bad_request("Invalid continuation token")?, )?, }), diff --git a/src/api/s3/post_object.rs b/src/api/s3/post_object.rs index d063faa4..f2098ab0 100644 --- a/src/api/s3/post_object.rs +++ b/src/api/s3/post_object.rs @@ -4,6 +4,7 @@ use std::ops::RangeInclusive; use std::sync::Arc; use std::task::{Context, Poll}; +use base64::prelude::*; use bytes::Bytes; use chrono::{DateTime, Duration, Utc}; use futures::{Stream, StreamExt}; @@ -138,7 +139,9 @@ pub async fn handle_post_object( .get_existing_bucket(bucket_id) .await?; - let decoded_policy = base64::decode(&policy).ok_or_bad_request("Invalid policy")?; + let decoded_policy = BASE64_STANDARD + .decode(policy) + .ok_or_bad_request("Invalid policy")?; let decoded_policy: Policy = serde_json::from_slice(&decoded_policy).ok_or_bad_request("Invalid policy")?; diff --git a/src/api/s3/put.rs b/src/api/s3/put.rs index 97b8e4e3..350ab884 100644 --- a/src/api/s3/put.rs +++ b/src/api/s3/put.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::sync::Arc; +use base64::prelude::*; use futures::prelude::*; use hyper::body::{Body, Bytes}; use hyper::header::{HeaderMap, HeaderValue}; @@ -119,6 +120,17 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>( return Ok((version_uuid, data_md5sum_hex)); } + // The following consists in many steps that can each fail. + // Keep track that some cleanup will be needed if things fail + // before everything is finished (cleanup is done using the Drop trait). + let mut interrupted_cleanup = InterruptedCleanup(Some(( + garage.clone(), + bucket.id, + key.into(), + version_uuid, + version_timestamp, + ))); + // Write version identifier in object table so that we have a trace // that we are uploading something let mut object_version = ObjectVersion { @@ -139,44 +151,27 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>( // Transfer data and verify checksum let first_block_hash = async_blake2sum(first_block.clone()).await; - let tx_result = (|| async { - let (total_size, data_md5sum, data_sha256sum) = read_and_put_blocks( - &garage, - &version, - 1, - first_block, - first_block_hash, - &mut chunker, - ) - .await?; - - ensure_checksum_matches( - data_md5sum.as_slice(), - data_sha256sum, - content_md5.as_deref(), - content_sha256, - )?; - - check_quotas(&garage, bucket, key, total_size).await?; + let (total_size, data_md5sum, data_sha256sum) = read_and_put_blocks( + &garage, + &version, + 1, + first_block, + first_block_hash, + &mut chunker, + ) + .await?; - Ok((total_size, data_md5sum)) - })() - .await; + ensure_checksum_matches( + data_md5sum.as_slice(), + data_sha256sum, + content_md5.as_deref(), + content_sha256, + )?; - // If something went wrong, clean up - let (total_size, md5sum_arr) = match tx_result { - Ok(rv) => rv, - Err(e) => { - // Mark object as aborted, this will free the blocks further down - object_version.state = ObjectVersionState::Aborted; - let object = Object::new(bucket.id, key.into(), vec![object_version.clone()]); - garage.object_table.insert(&object).await?; - return Err(e); - } - }; + check_quotas(&garage, bucket, key, total_size).await?; // Save final object state, marked as Complete - let md5sum_hex = hex::encode(md5sum_arr); + let md5sum_hex = hex::encode(data_md5sum); object_version.state = ObjectVersionState::Complete(ObjectVersionData::FirstBlock( ObjectVersionMeta { headers, @@ -188,6 +183,10 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>( let object = Object::new(bucket.id, key.into(), vec![object_version]); garage.object_table.insert(&object).await?; + // We were not interrupted, everything went fine. + // We won't have to clean up on drop. + interrupted_cleanup.cancel(); + Ok((version_uuid, md5sum_hex)) } @@ -209,7 +208,7 @@ fn ensure_checksum_matches( } } if let Some(expected_md5) = content_md5 { - if expected_md5.trim_matches('"') != base64::encode(data_md5sum) { + if expected_md5.trim_matches('"') != BASE64_STANDARD.encode(data_md5sum) { return Err(Error::bad_request("Unable to validate content-md5")); } else { trace!("Successfully validated content-md5"); @@ -426,6 +425,33 @@ pub fn put_response(version_uuid: Uuid, md5sum_hex: String) -> Response<Body> { .unwrap() } +struct InterruptedCleanup(Option<(Arc<Garage>, Uuid, String, Uuid, u64)>); + +impl InterruptedCleanup { + fn cancel(&mut self) { + drop(self.0.take()); + } +} +impl Drop for InterruptedCleanup { + fn drop(&mut self) { + if let Some((garage, bucket_id, key, version_uuid, version_ts)) = self.0.take() { + tokio::spawn(async move { + let object_version = ObjectVersion { + uuid: version_uuid, + timestamp: version_ts, + state: ObjectVersionState::Aborted, + }; + let object = Object::new(bucket_id, key, vec![object_version]); + if let Err(e) = garage.object_table.insert(&object).await { + warn!("Cannot cleanup after aborted PutObject: {}", e); + } + }); + } + } +} + +// ---- + pub async fn handle_create_multipart_upload( garage: Arc<Garage>, req: &Request<Body>, diff --git a/src/api/signature/error.rs b/src/api/signature/error.rs index f5a067bd..f0d7c816 100644 --- a/src/api/signature/error.rs +++ b/src/api/signature/error.rs @@ -11,7 +11,7 @@ pub enum Error { Common(CommonError), /// Authorization Header Malformed - #[error(display = "Authorization header malformed, expected scope: {}", _0)] + #[error(display = "Authorization header malformed, unexpected scope: {}", _0)] AuthorizationHeaderMalformed(String), // Category: bad request |