diff options
Diffstat (limited to 'src/api/common/helpers.rs')
-rw-r--r-- | src/api/common/helpers.rs | 371 |
1 files changed, 371 insertions, 0 deletions
diff --git a/src/api/common/helpers.rs b/src/api/common/helpers.rs new file mode 100644 index 00000000..c8586de4 --- /dev/null +++ b/src/api/common/helpers.rs @@ -0,0 +1,371 @@ +use std::convert::Infallible; +use std::sync::Arc; + +use futures::{Stream, StreamExt, TryStreamExt}; + +use http_body_util::{BodyExt, Full as FullBody}; +use hyper::{ + body::{Body, Bytes}, + Request, Response, +}; +use idna::domain_to_unicode; +use serde::{Deserialize, Serialize}; + +use garage_model::bucket_table::BucketParams; +use garage_model::garage::Garage; +use garage_model::key_table::Key; +use garage_util::data::Uuid; +use garage_util::error::Error as GarageError; + +use crate::common_error::{CommonError as Error, *}; + +/// What kind of authorization is required to perform a given action. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Authorization { + /// No authorization is required. + None, + /// Having Read permission on the bucket. + Read, + /// Having Write permission on the bucket. + Write, + /// Having Owner permission on the bucket. + Owner, +} + +/// The values which are known for each request related to a bucket. +pub struct ReqCtx { + pub garage: Arc<Garage>, + pub bucket_id: Uuid, + pub bucket_name: String, + pub bucket_params: BucketParams, + pub api_key: Key, +} + +/// Host to bucket +/// +/// Convert a host, like "bucket.garage-site.tld", to the corresponding bucket "bucket", +/// considering that ".garage-site.tld" is the "root domain". For domains not matching +/// the provided root domain, or equal to it, no bucket is returned. +/// This behavior has been chosen to follow AWS S3 semantics. 
+pub fn host_to_bucket<'a>(host: &'a str, root: &str) -> Option<&'a str> { + let root = root.trim_start_matches('.'); + let label_root = root.chars().filter(|c| c == &'.').count() + 1; + let root = root.rsplit('.'); + let mut host = host.rsplitn(label_root + 1, '.'); + for root_part in root { + let host_part = host.next()?; + if root_part != host_part { + return None; + } + } + host.next() +} + +/// Extract host from the authority section given by the HTTP host header +/// +/// The HTTP host contains both a host and a port. +/// Extracting the port is more complex than just finding the colon (:) symbol due to IPv6 +/// We do not use the collect pattern as there is no way in std rust to collect over a stack allocated value +/// check here: <https://docs.rs/collect_slice/1.2.0/collect_slice/> +pub fn authority_to_host(authority: &str) -> Result<String, Error> { + let mut iter = authority.chars().enumerate(); + let (_, first_char) = iter + .next() + .ok_or_else(|| Error::bad_request("Authority is empty".to_string()))?; + + let split = match first_char { + '[' => { + let mut iter = iter.skip_while(|(_, c)| c != &']'); + match iter.next() { + Some((_, ']')) => iter.next(), + _ => { + return Err(Error::bad_request(format!( + "Authority {} has an illegal format", + authority + ))) + } + } + } + _ => iter.find(|(_, c)| *c == ':'), + }; + + let authority = match split { + Some((i, ':')) => Ok(&authority[..i]), + None => Ok(authority), + Some((_, _)) => Err(Error::bad_request(format!( + "Authority {} has an illegal format", + authority + ))), + }; + authority.map(|h| domain_to_unicode(h).0) +} + +/// Extract the bucket name and the key name from an HTTP path and possibly a bucket provided in +/// the host header of the request +/// +/// S3 internally manages only buckets and keys. This function splits +/// an HTTP path to get the corresponding bucket name and key. 
+pub fn parse_bucket_key<'a>( + path: &'a str, + host_bucket: Option<&'a str>, +) -> Result<(&'a str, Option<&'a str>), Error> { + let path = path.trim_start_matches('/'); + + if let Some(bucket) = host_bucket { + if !path.is_empty() { + return Ok((bucket, Some(path))); + } else { + return Ok((bucket, None)); + } + } + + let (bucket, key) = match path.find('/') { + Some(i) => { + let key = &path[i + 1..]; + if !key.is_empty() { + (&path[..i], Some(key)) + } else { + (&path[..i], None) + } + } + None => (path, None), + }; + if bucket.is_empty() { + return Err(Error::bad_request("No bucket specified")); + } + Ok((bucket, key)) +} + +const UTF8_BEFORE_LAST_CHAR: char = '\u{10FFFE}'; + +/// Compute the key after the prefix +pub fn key_after_prefix(pfx: &str) -> Option<String> { + let mut next = pfx.to_string(); + while !next.is_empty() { + let tail = next.pop().unwrap(); + if tail >= char::MAX { + continue; + } + + // Circumvent a limitation of RangeFrom that overflow earlier than needed + // See: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html + let new_tail = if tail == UTF8_BEFORE_LAST_CHAR { + char::MAX + } else { + (tail..).nth(1).unwrap() + }; + + next.push(new_tail); + return Some(next); + } + + None +} + +// =============== body helpers ================= + +pub type EmptyBody = http_body_util::Empty<bytes::Bytes>; +pub type ErrorBody = FullBody<bytes::Bytes>; +pub type BoxBody<E> = http_body_util::combinators::BoxBody<bytes::Bytes, E>; + +pub fn string_body<E>(s: String) -> BoxBody<E> { + bytes_body(bytes::Bytes::from(s.into_bytes())) +} +pub fn bytes_body<E>(b: bytes::Bytes) -> BoxBody<E> { + BoxBody::new(FullBody::new(b).map_err(|_: Infallible| unreachable!())) +} +pub fn empty_body<E>() -> BoxBody<E> { + BoxBody::new(http_body_util::Empty::new().map_err(|_: Infallible| unreachable!())) +} +pub fn error_body(s: String) -> ErrorBody { + ErrorBody::from(bytes::Bytes::from(s.into_bytes())) +} + +pub async fn parse_json_body<T, B, E>(req: Request<B>) -> 
Result<T, E> +where + T: for<'de> Deserialize<'de>, + B: Body, + E: From<<B as Body>::Error> + From<Error>, +{ + let body = req.into_body().collect().await?.to_bytes(); + let resp: T = serde_json::from_slice(&body).ok_or_bad_request("Invalid JSON")?; + Ok(resp) +} + +pub fn json_ok_response<E, T: Serialize>(res: &T) -> Result<Response<BoxBody<E>>, E> +where + E: From<Error>, +{ + let resp_json = serde_json::to_string_pretty(res) + .map_err(GarageError::from) + .map_err(Error::from)?; + Ok(Response::builder() + .status(hyper::StatusCode::OK) + .header(http::header::CONTENT_TYPE, "application/json") + .body(string_body(resp_json)) + .unwrap()) +} + +pub fn body_stream<B, E>(body: B) -> impl Stream<Item = Result<Bytes, E>> +where + B: Body<Data = Bytes>, + <B as Body>::Error: Into<E>, + E: From<Error>, +{ + let stream = http_body_util::BodyStream::new(body); + let stream = TryStreamExt::map_err(stream, Into::into); + stream.map(|x| { + x.and_then(|f| { + f.into_data() + .map_err(|_| E::from(Error::bad_request("non-data frame"))) + }) + }) +} + +pub fn is_default<T: Default + PartialEq>(v: &T) -> bool { + *v == T::default() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_bucket_containing_a_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/my_bucket/a/super/file.jpg", None)?; + assert_eq!(bucket, "my_bucket"); + assert_eq!(key.expect("key must be set"), "a/super/file.jpg"); + Ok(()) + } + + #[test] + fn parse_bucket_containing_no_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/my_bucket/", None)?; + assert_eq!(bucket, "my_bucket"); + assert!(key.is_none()); + let (bucket, key) = parse_bucket_key("/my_bucket", None)?; + assert_eq!(bucket, "my_bucket"); + assert!(key.is_none()); + Ok(()) + } + + #[test] + fn parse_bucket_containing_no_bucket() { + let parsed = parse_bucket_key("", None); + assert!(parsed.is_err()); + let parsed = parse_bucket_key("/", None); + assert!(parsed.is_err()); + let parsed = 
parse_bucket_key("////", None); + assert!(parsed.is_err()); + } + + #[test] + fn parse_bucket_with_vhost_and_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/a/super/file.jpg", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert_eq!(key.expect("key must be set"), "a/super/file.jpg"); + Ok(()) + } + + #[test] + fn parse_bucket_with_vhost_no_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert!(key.is_none()); + let (bucket, key) = parse_bucket_key("/", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert!(key.is_none()); + Ok(()) + } + + #[test] + fn authority_to_host_with_port() -> Result<(), Error> { + let domain = authority_to_host("[::1]:3902")?; + assert_eq!(domain, "[::1]"); + let domain2 = authority_to_host("garage.tld:65200")?; + assert_eq!(domain2, "garage.tld"); + let domain3 = authority_to_host("127.0.0.1:80")?; + assert_eq!(domain3, "127.0.0.1"); + Ok(()) + } + + #[test] + fn authority_to_host_without_port() -> Result<(), Error> { + let domain = authority_to_host("[::1]")?; + assert_eq!(domain, "[::1]"); + let domain2 = authority_to_host("garage.tld")?; + assert_eq!(domain2, "garage.tld"); + let domain3 = authority_to_host("127.0.0.1")?; + assert_eq!(domain3, "127.0.0.1"); + assert!(authority_to_host("[").is_err()); + assert!(authority_to_host("[hello").is_err()); + Ok(()) + } + + #[test] + fn host_to_bucket_test() { + assert_eq!( + host_to_bucket("john.doe.garage.tld", ".garage.tld").unwrap(), + "john.doe" + ); + + assert_eq!( + host_to_bucket("john.doe.garage.tld", "garage.tld").unwrap(), + "john.doe" + ); + + assert_eq!(host_to_bucket("john.doe.com", "garage.tld"), None); + + assert_eq!(host_to_bucket("john.doe.com", ".garage.tld"), None); + + assert_eq!(host_to_bucket("garage.tld", "garage.tld"), None); + + assert_eq!(host_to_bucket("garage.tld", ".garage.tld"), None); + + assert_eq!(host_to_bucket("not-garage.tld", 
"garage.tld"), None); + assert_eq!(host_to_bucket("not-garage.tld", ".garage.tld"), None); + } + + #[test] + fn test_key_after_prefix() { + use std::iter::FromIterator; + + assert_eq!(UTF8_BEFORE_LAST_CHAR as u32, (char::MAX as u32) - 1); + assert_eq!(key_after_prefix("a/b/").unwrap().as_str(), "a/b0"); + assert_eq!(key_after_prefix("€").unwrap().as_str(), "₭"); + assert_eq!( + key_after_prefix("").unwrap().as_str(), + String::from(char::from_u32(0x10FFFE).unwrap()) + ); + + // When the last character is the biggest UTF8 char + let a = String::from_iter(['a', char::MAX].iter()); + assert_eq!(key_after_prefix(a.as_str()).unwrap().as_str(), "b"); + + // When all characters are the biggest UTF8 char + let b = String::from_iter([char::MAX; 3].iter()); + assert!(key_after_prefix(b.as_str()).is_none()); + + // Check utf8 surrogates + let c = String::from('\u{D7FF}'); + assert_eq!( + key_after_prefix(c.as_str()).unwrap().as_str(), + String::from('\u{E000}') + ); + + // Check the character before the biggest one + let d = String::from('\u{10FFFE}'); + assert_eq!( + key_after_prefix(d.as_str()).unwrap().as_str(), + String::from(char::MAX) + ); + } +} + +#[derive(Serialize)] +pub struct CustomApiErrorBody { + pub code: String, + pub message: String, + pub region: String, + pub path: String, +} |