aboutsummaryrefslogtreecommitdiff
path: root/src/api
diff options
context:
space:
mode:
Diffstat (limited to 'src/api')
-rw-r--r--src/api/Cargo.toml26
-rw-r--r--src/api/admin/api_server.rs56
-rw-r--r--src/api/admin/bucket.rs4
-rw-r--r--src/api/admin/cluster.rs2
-rw-r--r--src/api/admin/router.rs3
-rw-r--r--src/api/generic_server.rs16
-rw-r--r--src/api/k2v/api_server.rs3
-rw-r--r--src/api/k2v/batch.rs96
-rw-r--r--src/api/k2v/error.rs2
-rw-r--r--src/api/k2v/index.rs9
-rw-r--r--src/api/k2v/item.rs19
-rw-r--r--src/api/k2v/router.rs8
-rw-r--r--src/api/s3/bucket.rs2
-rw-r--r--src/api/s3/error.rs2
-rw-r--r--src/api/s3/list.rs11
-rw-r--r--src/api/s3/post_object.rs5
-rw-r--r--src/api/s3/put.rs96
-rw-r--r--src/api/signature/error.rs2
18 files changed, 268 insertions, 94 deletions
diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml
index dba0bbef..9babec02 100644
--- a/src/api/Cargo.toml
+++ b/src/api/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "garage_api"
-version = "0.8.1"
+version = "0.8.2"
authors = ["Alex Auvolat <alex@adnab.me>"]
edition = "2018"
license = "AGPL-3.0"
@@ -14,35 +14,35 @@ path = "lib.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-garage_model = { version = "0.8.1", path = "../model" }
-garage_table = { version = "0.8.1", path = "../table" }
-garage_block = { version = "0.8.1", path = "../block" }
-garage_util = { version = "0.8.1", path = "../util" }
-garage_rpc = { version = "0.8.1", path = "../rpc" }
+garage_model = { version = "0.8.2", path = "../model" }
+garage_table = { version = "0.8.2", path = "../table" }
+garage_block = { version = "0.8.2", path = "../block" }
+garage_util = { version = "0.8.2", path = "../util" }
+garage_rpc = { version = "0.8.2", path = "../rpc" }
async-trait = "0.1.7"
-base64 = "0.13"
+base64 = "0.21"
bytes = "1.0"
chrono = "0.4"
crypto-common = "0.1"
err-derive = "0.3"
hex = "0.4"
hmac = "0.12"
-idna = "0.2"
-tracing = "0.1.30"
+idna = "0.3"
+tracing = "0.1"
md-5 = "0.10"
nom = "7.1"
sha2 = "0.10"
futures = "0.3"
futures-util = "0.3"
-pin-project = "1.0.11"
+pin-project = "1.0.12"
tokio = { version = "1.0", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "time", "macros", "sync", "signal", "fs"] }
tokio-stream = "0.1"
form_urlencoded = "1.0.0"
http = "0.2"
-httpdate = "0.3"
+httpdate = "1.0"
http-range = "0.1"
hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] }
multer = "2.0"
@@ -51,8 +51,8 @@ roxmltree = "0.14"
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
serde_json = "1.0"
-quick-xml = { version = "0.21", features = [ "serialize" ] }
-url = "2.1"
+quick-xml = { version = "0.26", features = [ "serialize" ] }
+url = "2.3"
opentelemetry = "0.17"
opentelemetry-prometheus = { version = "0.10", optional = true }
diff --git a/src/api/admin/api_server.rs b/src/api/admin/api_server.rs
index 2d325fb1..58dd38d8 100644
--- a/src/api/admin/api_server.rs
+++ b/src/api/admin/api_server.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
@@ -77,6 +78,60 @@ impl AdminApiServer {
.body(Body::empty())?)
}
+ async fn handle_check_website_enabled(
+ &self,
+ req: Request<Body>,
+ ) -> Result<Response<Body>, Error> {
+ let query_params: HashMap<String, String> = req
+ .uri()
+ .query()
+ .map(|v| {
+ url::form_urlencoded::parse(v.as_bytes())
+ .into_owned()
+ .collect()
+ })
+ .unwrap_or_else(HashMap::new);
+
+ let has_domain_key = query_params.contains_key("domain");
+
+ if !has_domain_key {
+ return Err(Error::bad_request("No domain query string found"));
+ }
+
+ let domain = query_params
+ .get("domain")
+ .ok_or_internal_error("Could not parse domain query string")?;
+
+ let bucket_id = self
+ .garage
+ .bucket_helper()
+ .resolve_global_bucket_name(&domain)
+ .await?
+ .ok_or(HelperError::NoSuchBucket(domain.to_string()))?;
+
+ let bucket = self
+ .garage
+ .bucket_helper()
+ .get_existing_bucket(bucket_id)
+ .await?;
+
+ let bucket_state = bucket.state.as_option().unwrap();
+ let bucket_website_config = bucket_state.website_config.get();
+
+ match bucket_website_config {
+ Some(_v) => {
+ Ok(Response::builder()
+ .status(StatusCode::OK)
+ .body(Body::from(format!(
+ "Bucket '{domain}' is authorized for website hosting"
+ )))?)
+ }
+ None => Err(Error::bad_request(format!(
+ "Bucket '{domain}' is not authorized for website hosting"
+ ))),
+ }
+ }
+
fn handle_health(&self) -> Result<Response<Body>, Error> {
let health = self.garage.system.health();
@@ -174,6 +229,7 @@ impl ApiHandler for AdminApiServer {
match endpoint {
Endpoint::Options => self.handle_options(&req),
+ Endpoint::CheckWebsiteEnabled => self.handle_check_website_enabled(req).await,
Endpoint::Health => self.handle_health(),
Endpoint::Metrics => self.handle_metrics(),
Endpoint::GetClusterStatus => handle_get_cluster_status(&self.garage).await,
diff --git a/src/api/admin/bucket.rs b/src/api/admin/bucket.rs
index 65034852..e60f07ca 100644
--- a/src/api/admin/bucket.rs
+++ b/src/api/admin/bucket.rs
@@ -167,7 +167,7 @@ async fn bucket_info_results(
let quotas = state.quotas.get();
let res =
GetBucketInfoResult {
- id: hex::encode(&bucket.id),
+ id: hex::encode(bucket.id),
global_aliases: state
.aliases
.items()
@@ -575,6 +575,6 @@ pub async fn handle_local_unalias_bucket(
// ---- HELPER ----
fn parse_bucket_id(id: &str) -> Result<Uuid, Error> {
- let id_hex = hex::decode(&id).ok_or_bad_request("Invalid bucket id")?;
+ let id_hex = hex::decode(id).ok_or_bad_request("Invalid bucket id")?;
Ok(Uuid::try_from(&id_hex).ok_or_bad_request("Invalid bucket id")?)
}
diff --git a/src/api/admin/cluster.rs b/src/api/admin/cluster.rs
index 540c6009..b2508d2e 100644
--- a/src/api/admin/cluster.rs
+++ b/src/api/admin/cluster.rs
@@ -20,6 +20,7 @@ pub async fn handle_get_cluster_status(garage: &Arc<Garage>) -> Result<Response<
node: hex::encode(garage.system.id),
garage_version: garage_util::version::garage_version(),
garage_features: garage_util::version::garage_features(),
+ rust_version: garage_util::version::rust_version(),
db_engine: garage.db.engine(),
known_nodes: garage
.system
@@ -106,6 +107,7 @@ struct GetClusterStatusResponse {
node: String,
garage_version: &'static str,
garage_features: Option<&'static [&'static str]>,
+ rust_version: &'static str,
db_engine: String,
known_nodes: HashMap<String, KnownNodeResp>,
layout: GetClusterLayoutResponse,
diff --git a/src/api/admin/router.rs b/src/api/admin/router.rs
index 62e6abc3..0dcb1546 100644
--- a/src/api/admin/router.rs
+++ b/src/api/admin/router.rs
@@ -17,6 +17,7 @@ router_match! {@func
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Endpoint {
Options,
+ CheckWebsiteEnabled,
Health,
Metrics,
GetClusterStatus,
@@ -91,6 +92,7 @@ impl Endpoint {
let res = router_match!(@gen_path_parser (req.method(), path, query) [
OPTIONS _ => Options,
+ GET "/check" => CheckWebsiteEnabled,
GET "/health" => Health,
GET "/metrics" => Metrics,
GET "/v0/status" => GetClusterStatus,
@@ -136,6 +138,7 @@ impl Endpoint {
pub fn authorization_type(&self) -> Authorization {
match self {
Self::Health => Authorization::None,
+ Self::CheckWebsiteEnabled => Authorization::None,
Self::Metrics => Authorization::MetricsToken,
_ => Authorization::AdminToken,
}
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs
index 62fe4e5a..d0354d28 100644
--- a/src/api/generic_server.rs
+++ b/src/api/generic_server.rs
@@ -19,6 +19,7 @@ use opentelemetry::{
};
use garage_util::error::Error as GarageError;
+use garage_util::forwarded_headers;
use garage_util::metrics::{gen_trace_id, RecordDuration};
pub(crate) trait ApiEndpoint: Send + Sync + 'static {
@@ -125,7 +126,20 @@ impl<A: ApiHandler> ApiServer<A> {
addr: SocketAddr,
) -> Result<Response<Body>, GarageError> {
let uri = req.uri().clone();
- info!("{} {} {}", addr, req.method(), uri);
+
+ if let Ok(forwarded_for_ip_addr) =
+ forwarded_headers::handle_forwarded_for_headers(&req.headers())
+ {
+ info!(
+ "{} (via {}) {} {}",
+ forwarded_for_ip_addr,
+ addr,
+ req.method(),
+ uri
+ );
+ } else {
+ info!("{} {} {}", addr, req.method(), uri);
+ }
debug!("{:?}", req);
let tracer = opentelemetry::global::tracer("garage");
diff --git a/src/api/k2v/api_server.rs b/src/api/k2v/api_server.rs
index 084867b5..bb85b2e7 100644
--- a/src/api/k2v/api_server.rs
+++ b/src/api/k2v/api_server.rs
@@ -164,6 +164,9 @@ impl ApiHandler for K2VApiServer {
Endpoint::InsertBatch {} => handle_insert_batch(garage, bucket_id, req).await,
Endpoint::ReadBatch {} => handle_read_batch(garage, bucket_id, req).await,
Endpoint::DeleteBatch {} => handle_delete_batch(garage, bucket_id, req).await,
+ Endpoint::PollRange { partition_key } => {
+ handle_poll_range(garage, bucket_id, &partition_key, req).await
+ }
Endpoint::Options => unreachable!(),
};
diff --git a/src/api/k2v/batch.rs b/src/api/k2v/batch.rs
index 78035362..26d678da 100644
--- a/src/api/k2v/batch.rs
+++ b/src/api/k2v/batch.rs
@@ -1,10 +1,10 @@
use std::sync::Arc;
+use base64::prelude::*;
use hyper::{Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use garage_util::data::*;
-use garage_util::error::Error as GarageError;
use garage_table::{EnumerationOrder, TableSchema};
@@ -25,15 +25,13 @@ pub async fn handle_insert_batch(
let mut items2 = vec![];
for it in items {
- let ct = it
- .ct
- .map(|s| CausalContext::parse(&s))
- .transpose()
- .ok_or_bad_request("Invalid causality token")?;
+ let ct = it.ct.map(|s| CausalContext::parse_helper(&s)).transpose()?;
let v = match it.v {
- Some(vs) => {
- DvvsValue::Value(base64::decode(vs).ok_or_bad_request("Invalid base64 value")?)
- }
+ Some(vs) => DvvsValue::Value(
+ BASE64_STANDARD
+ .decode(vs)
+ .ok_or_bad_request("Invalid base64 value")?,
+ ),
None => DvvsValue::Deleted,
};
items2.push((it.pk, it.sk, ct, v));
@@ -65,10 +63,7 @@ pub async fn handle_read_batch(
resps.push(resp?);
}
- let resp_json = serde_json::to_string_pretty(&resps).map_err(GarageError::from)?;
- Ok(Response::builder()
- .status(StatusCode::OK)
- .body(Body::from(resp_json))?)
+ Ok(json_ok_response(&resps)?)
}
async fn handle_read_batch_query(
@@ -160,10 +155,7 @@ pub async fn handle_delete_batch(
resps.push(resp?);
}
- let resp_json = serde_json::to_string_pretty(&resps).map_err(GarageError::from)?;
- Ok(Response::builder()
- .status(StatusCode::OK)
- .body(Body::from(resp_json))?)
+ Ok(json_ok_response(&resps)?)
}
async fn handle_delete_batch_query(
@@ -257,6 +249,53 @@ async fn handle_delete_batch_query(
})
}
+pub(crate) async fn handle_poll_range(
+ garage: Arc<Garage>,
+ bucket_id: Uuid,
+ partition_key: &str,
+ req: Request<Body>,
+) -> Result<Response<Body>, Error> {
+ use garage_model::k2v::sub::PollRange;
+
+ let query = parse_json_body::<PollRangeQuery>(req).await?;
+
+ let timeout_msec = query.timeout.unwrap_or(300).clamp(1, 600) * 1000;
+
+ let resp = garage
+ .k2v
+ .rpc
+ .poll_range(
+ PollRange {
+ partition: K2VItemPartition {
+ bucket_id,
+ partition_key: partition_key.to_string(),
+ },
+ start: query.start,
+ end: query.end,
+ prefix: query.prefix,
+ },
+ query.seen_marker,
+ timeout_msec,
+ )
+ .await?;
+
+ if let Some((items, seen_marker)) = resp {
+ let resp = PollRangeResponse {
+ items: items
+ .into_iter()
+ .map(|(_k, i)| ReadBatchResponseItem::from(i))
+ .collect::<Vec<_>>(),
+ seen_marker,
+ };
+
+ Ok(json_ok_response(&resp)?)
+ } else {
+ Ok(Response::builder()
+ .status(StatusCode::NOT_MODIFIED)
+ .body(Body::empty())?)
+ }
+}
+
#[derive(Deserialize)]
struct InsertBatchItem {
pk: String,
@@ -322,7 +361,7 @@ impl ReadBatchResponseItem {
.values()
.iter()
.map(|v| match v {
- DvvsValue::Value(x) => Some(base64::encode(x)),
+ DvvsValue::Value(x) => Some(BASE64_STANDARD.encode(x)),
DvvsValue::Deleted => None,
})
.collect::<Vec<_>>();
@@ -361,3 +400,24 @@ struct DeleteBatchResponse {
#[serde(rename = "deletedItems")]
deleted_items: usize,
}
+
+#[derive(Deserialize)]
+struct PollRangeQuery {
+ #[serde(default)]
+ prefix: Option<String>,
+ #[serde(default)]
+ start: Option<String>,
+ #[serde(default)]
+ end: Option<String>,
+ #[serde(default)]
+ timeout: Option<u64>,
+ #[serde(default, rename = "seenMarker")]
+ seen_marker: Option<String>,
+}
+
+#[derive(Serialize)]
+struct PollRangeResponse {
+ items: Vec<ReadBatchResponseItem>,
+ #[serde(rename = "seenMarker")]
+ seen_marker: String,
+}
diff --git a/src/api/k2v/error.rs b/src/api/k2v/error.rs
index 42491466..4eb017ab 100644
--- a/src/api/k2v/error.rs
+++ b/src/api/k2v/error.rs
@@ -19,7 +19,7 @@ pub enum Error {
// Category: cannot process
/// Authorization Header Malformed
- #[error(display = "Authorization header malformed, expected scope: {}", _0)]
+ #[error(display = "Authorization header malformed, unexpected scope: {}", _0)]
AuthorizationHeaderMalformed(String),
/// The object requested don't exists
diff --git a/src/api/k2v/index.rs b/src/api/k2v/index.rs
index 210950bf..6c1d4a91 100644
--- a/src/api/k2v/index.rs
+++ b/src/api/k2v/index.rs
@@ -1,10 +1,9 @@
use std::sync::Arc;
-use hyper::{Body, Response, StatusCode};
+use hyper::{Body, Response};
use serde::Serialize;
use garage_util::data::*;
-use garage_util::error::Error as GarageError;
use garage_rpc::ring::Ring;
use garage_table::util::*;
@@ -12,6 +11,7 @@ use garage_table::util::*;
use garage_model::garage::Garage;
use garage_model::k2v::item_table::{BYTES, CONFLICTS, ENTRIES, VALUES};
+use crate::helpers::*;
use crate::k2v::error::*;
use crate::k2v::range::read_range;
@@ -68,10 +68,7 @@ pub async fn handle_read_index(
next_start,
};
- let resp_json = serde_json::to_string_pretty(&resp).map_err(GarageError::from)?;
- Ok(Response::builder()
- .status(StatusCode::OK)
- .body(Body::from(resp_json))?)
+ Ok(json_ok_response(&resp)?)
}
#[derive(Serialize)]
diff --git a/src/api/k2v/item.rs b/src/api/k2v/item.rs
index f85138c7..e13a0f30 100644
--- a/src/api/k2v/item.rs
+++ b/src/api/k2v/item.rs
@@ -1,5 +1,6 @@
use std::sync::Arc;
+use base64::prelude::*;
use http::header;
use hyper::{Body, Request, Response, StatusCode};
@@ -81,7 +82,7 @@ impl ReturnFormat {
.iter()
.map(|v| match v {
DvvsValue::Deleted => serde_json::Value::Null,
- DvvsValue::Value(v) => serde_json::Value::String(base64::encode(v)),
+ DvvsValue::Value(v) => serde_json::Value::String(BASE64_STANDARD.encode(v)),
})
.collect::<Vec<_>>();
let json_body =
@@ -133,9 +134,8 @@ pub async fn handle_insert_item(
.get(X_GARAGE_CAUSALITY_TOKEN)
.map(|s| s.to_str())
.transpose()?
- .map(CausalContext::parse)
- .transpose()
- .ok_or_bad_request("Invalid causality token")?;
+ .map(CausalContext::parse_helper)
+ .transpose()?;
let body = hyper::body::to_bytes(req.into_body()).await?;
let value = DvvsValue::Value(body.to_vec());
@@ -169,9 +169,8 @@ pub async fn handle_delete_item(
.get(X_GARAGE_CAUSALITY_TOKEN)
.map(|s| s.to_str())
.transpose()?
- .map(CausalContext::parse)
- .transpose()
- .ok_or_bad_request("Invalid causality token")?;
+ .map(CausalContext::parse_helper)
+ .transpose()?;
let value = DvvsValue::Deleted;
@@ -208,15 +207,17 @@ pub async fn handle_poll_item(
let causal_context =
CausalContext::parse(&causality_token).ok_or_bad_request("Invalid causality token")?;
+ let timeout_msec = timeout_secs.unwrap_or(300).clamp(1, 600) * 1000;
+
let item = garage
.k2v
.rpc
- .poll(
+ .poll_item(
bucket_id,
partition_key,
sort_key,
causal_context,
- timeout_secs.unwrap_or(300) * 1000,
+ timeout_msec,
)
.await?;
diff --git a/src/api/k2v/router.rs b/src/api/k2v/router.rs
index e7a3dd69..1cc58be5 100644
--- a/src/api/k2v/router.rs
+++ b/src/api/k2v/router.rs
@@ -32,6 +32,9 @@ pub enum Endpoint {
causality_token: String,
timeout: Option<u64>,
},
+ PollRange {
+ partition_key: String,
+ },
ReadBatch {
},
ReadIndex {
@@ -113,6 +116,7 @@ impl Endpoint {
@gen_parser
(query.keyword.take().unwrap_or_default(), partition_key, query, None),
key: [
+ POLL_RANGE => PollRange,
],
no_key: [
EMPTY => ReadBatch,
@@ -142,6 +146,7 @@ impl Endpoint {
@gen_parser
(query.keyword.take().unwrap_or_default(), partition_key, query, None),
key: [
+ POLL_RANGE => PollRange,
],
no_key: [
EMPTY => InsertBatch,
@@ -234,7 +239,8 @@ impl Endpoint {
generateQueryParameters! {
keywords: [
"delete" => DELETE,
- "search" => SEARCH
+ "search" => SEARCH,
+ "poll_range" => POLL_RANGE
],
fields: [
"prefix" => prefix,
diff --git a/src/api/s3/bucket.rs b/src/api/s3/bucket.rs
index 8471385f..733981e1 100644
--- a/src/api/s3/bucket.rs
+++ b/src/api/s3/bucket.rs
@@ -305,7 +305,7 @@ fn parse_create_bucket_xml(xml_bytes: &[u8]) -> Option<Option<String>> {
let mut ret = None;
for item in cbc.children() {
if item.has_tag_name("LocationConstraint") {
- if ret != None {
+ if ret.is_some() {
return None;
}
ret = Some(item.text()?.to_string());
diff --git a/src/api/s3/error.rs b/src/api/s3/error.rs
index 67009d63..c50cff9f 100644
--- a/src/api/s3/error.rs
+++ b/src/api/s3/error.rs
@@ -21,7 +21,7 @@ pub enum Error {
// Category: cannot process
/// Authorization Header Malformed
- #[error(display = "Authorization header malformed, expected scope: {}", _0)]
+ #[error(display = "Authorization header malformed, unexpected scope: {}", _0)]
AuthorizationHeaderMalformed(String),
/// The object requested don't exists
diff --git a/src/api/s3/list.rs b/src/api/s3/list.rs
index e5f486c8..5cb0d65a 100644
--- a/src/api/s3/list.rs
+++ b/src/api/s3/list.rs
@@ -3,6 +3,7 @@ use std::collections::{BTreeMap, BTreeSet};
use std::iter::{Iterator, Peekable};
use std::sync::Arc;
+use base64::prelude::*;
use hyper::{Body, Response};
use garage_util::data::*;
@@ -129,11 +130,11 @@ pub async fn handle_list(
next_continuation_token: match (query.is_v2, &pagination) {
(true, Some(RangeBegin::AfterKey { key })) => Some(s3_xml::Value(format!(
"]{}",
- base64::encode(key.as_bytes())
+ BASE64_STANDARD.encode(key.as_bytes())
))),
(true, Some(RangeBegin::IncludingKey { key, .. })) => Some(s3_xml::Value(format!(
"[{}",
- base64::encode(key.as_bytes())
+ BASE64_STANDARD.encode(key.as_bytes())
))),
_ => None,
},
@@ -583,14 +584,16 @@ impl ListObjectsQuery {
(Some(token), _) => match &token[..1] {
"[" => Ok(RangeBegin::IncludingKey {
key: String::from_utf8(
- base64::decode(token[1..].as_bytes())
+ BASE64_STANDARD
+ .decode(token[1..].as_bytes())
.ok_or_bad_request("Invalid continuation token")?,
)?,
fallback_key: None,
}),
"]" => Ok(RangeBegin::AfterKey {
key: String::from_utf8(
- base64::decode(token[1..].as_bytes())
+ BASE64_STANDARD
+ .decode(token[1..].as_bytes())
.ok_or_bad_request("Invalid continuation token")?,
)?,
}),
diff --git a/src/api/s3/post_object.rs b/src/api/s3/post_object.rs
index d063faa4..f2098ab0 100644
--- a/src/api/s3/post_object.rs
+++ b/src/api/s3/post_object.rs
@@ -4,6 +4,7 @@ use std::ops::RangeInclusive;
use std::sync::Arc;
use std::task::{Context, Poll};
+use base64::prelude::*;
use bytes::Bytes;
use chrono::{DateTime, Duration, Utc};
use futures::{Stream, StreamExt};
@@ -138,7 +139,9 @@ pub async fn handle_post_object(
.get_existing_bucket(bucket_id)
.await?;
- let decoded_policy = base64::decode(&policy).ok_or_bad_request("Invalid policy")?;
+ let decoded_policy = BASE64_STANDARD
+ .decode(policy)
+ .ok_or_bad_request("Invalid policy")?;
let decoded_policy: Policy =
serde_json::from_slice(&decoded_policy).ok_or_bad_request("Invalid policy")?;
diff --git a/src/api/s3/put.rs b/src/api/s3/put.rs
index 97b8e4e3..350ab884 100644
--- a/src/api/s3/put.rs
+++ b/src/api/s3/put.rs
@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::Arc;
+use base64::prelude::*;
use futures::prelude::*;
use hyper::body::{Body, Bytes};
use hyper::header::{HeaderMap, HeaderValue};
@@ -119,6 +120,17 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>(
return Ok((version_uuid, data_md5sum_hex));
}
+ // The following consists in many steps that can each fail.
+ // Keep track that some cleanup will be needed if things fail
+ // before everything is finished (cleanup is done using the Drop trait).
+ let mut interrupted_cleanup = InterruptedCleanup(Some((
+ garage.clone(),
+ bucket.id,
+ key.into(),
+ version_uuid,
+ version_timestamp,
+ )));
+
// Write version identifier in object table so that we have a trace
// that we are uploading something
let mut object_version = ObjectVersion {
@@ -139,44 +151,27 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>(
// Transfer data and verify checksum
let first_block_hash = async_blake2sum(first_block.clone()).await;
- let tx_result = (|| async {
- let (total_size, data_md5sum, data_sha256sum) = read_and_put_blocks(
- &garage,
- &version,
- 1,
- first_block,
- first_block_hash,
- &mut chunker,
- )
- .await?;
-
- ensure_checksum_matches(
- data_md5sum.as_slice(),
- data_sha256sum,
- content_md5.as_deref(),
- content_sha256,
- )?;
-
- check_quotas(&garage, bucket, key, total_size).await?;
+ let (total_size, data_md5sum, data_sha256sum) = read_and_put_blocks(
+ &garage,
+ &version,
+ 1,
+ first_block,
+ first_block_hash,
+ &mut chunker,
+ )
+ .await?;
- Ok((total_size, data_md5sum))
- })()
- .await;
+ ensure_checksum_matches(
+ data_md5sum.as_slice(),
+ data_sha256sum,
+ content_md5.as_deref(),
+ content_sha256,
+ )?;
- // If something went wrong, clean up
- let (total_size, md5sum_arr) = match tx_result {
- Ok(rv) => rv,
- Err(e) => {
- // Mark object as aborted, this will free the blocks further down
- object_version.state = ObjectVersionState::Aborted;
- let object = Object::new(bucket.id, key.into(), vec![object_version.clone()]);
- garage.object_table.insert(&object).await?;
- return Err(e);
- }
- };
+ check_quotas(&garage, bucket, key, total_size).await?;
// Save final object state, marked as Complete
- let md5sum_hex = hex::encode(md5sum_arr);
+ let md5sum_hex = hex::encode(data_md5sum);
object_version.state = ObjectVersionState::Complete(ObjectVersionData::FirstBlock(
ObjectVersionMeta {
headers,
@@ -188,6 +183,10 @@ pub(crate) async fn save_stream<S: Stream<Item = Result<Bytes, Error>> + Unpin>(
let object = Object::new(bucket.id, key.into(), vec![object_version]);
garage.object_table.insert(&object).await?;
+ // We were not interrupted, everything went fine.
+ // We won't have to clean up on drop.
+ interrupted_cleanup.cancel();
+
Ok((version_uuid, md5sum_hex))
}
@@ -209,7 +208,7 @@ fn ensure_checksum_matches(
}
}
if let Some(expected_md5) = content_md5 {
- if expected_md5.trim_matches('"') != base64::encode(data_md5sum) {
+ if expected_md5.trim_matches('"') != BASE64_STANDARD.encode(data_md5sum) {
return Err(Error::bad_request("Unable to validate content-md5"));
} else {
trace!("Successfully validated content-md5");
@@ -426,6 +425,33 @@ pub fn put_response(version_uuid: Uuid, md5sum_hex: String) -> Response<Body> {
.unwrap()
}
+struct InterruptedCleanup(Option<(Arc<Garage>, Uuid, String, Uuid, u64)>);
+
+impl InterruptedCleanup {
+ fn cancel(&mut self) {
+ drop(self.0.take());
+ }
+}
+impl Drop for InterruptedCleanup {
+ fn drop(&mut self) {
+ if let Some((garage, bucket_id, key, version_uuid, version_ts)) = self.0.take() {
+ tokio::spawn(async move {
+ let object_version = ObjectVersion {
+ uuid: version_uuid,
+ timestamp: version_ts,
+ state: ObjectVersionState::Aborted,
+ };
+ let object = Object::new(bucket_id, key, vec![object_version]);
+ if let Err(e) = garage.object_table.insert(&object).await {
+ warn!("Cannot cleanup after aborted PutObject: {}", e);
+ }
+ });
+ }
+ }
+}
+
+// ----
+
pub async fn handle_create_multipart_upload(
garage: Arc<Garage>,
req: &Request<Body>,
diff --git a/src/api/signature/error.rs b/src/api/signature/error.rs
index f5a067bd..f0d7c816 100644
--- a/src/api/signature/error.rs
+++ b/src/api/signature/error.rs
@@ -11,7 +11,7 @@ pub enum Error {
Common(CommonError),
/// Authorization Header Malformed
- #[error(display = "Authorization header malformed, expected scope: {}", _0)]
+ #[error(display = "Authorization header malformed, unexpected scope: {}", _0)]
AuthorizationHeaderMalformed(String),
// Category: bad request