diff options
Diffstat (limited to 'src/api/helpers.rs')
-rw-r--r-- | src/api/helpers.rs | 191 |
1 files changed, 187 insertions, 4 deletions
diff --git a/src/api/helpers.rs b/src/api/helpers.rs index c2709bb3..642dbc42 100644 --- a/src/api/helpers.rs +++ b/src/api/helpers.rs @@ -1,5 +1,21 @@ -use crate::Error; +use hyper::{Body, Request, Response}; use idna::domain_to_unicode; +use serde::{Deserialize, Serialize}; + +use crate::common_error::{CommonError as Error, *}; + +/// What kind of authorization is required to perform a given action +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Authorization { + /// No authorization is required + None, + /// Having Read permission on bucket + Read, + /// Having Write permission on bucket + Write, + /// Having Owner permission on bucket + Owner, +} /// Host to bucket /// @@ -31,7 +47,7 @@ pub fn authority_to_host(authority: &str) -> Result<String, Error> { let mut iter = authority.chars().enumerate(); let (_, first_char) = iter .next() - .ok_or_else(|| Error::BadRequest("Authority is empty".to_string()))?; + .ok_or_else(|| Error::bad_request("Authority is empty".to_string()))?; let split = match first_char { '[' => { @@ -39,7 +55,7 @@ pub fn authority_to_host(authority: &str) -> Result<String, Error> { match iter.next() { Some((_, ']')) => iter.next(), _ => { - return Err(Error::BadRequest(format!( + return Err(Error::bad_request(format!( "Authority {} has an illegal format", authority ))) @@ -52,7 +68,7 @@ pub fn authority_to_host(authority: &str) -> Result<String, Error> { let authority = match split { Some((i, ':')) => Ok(&authority[..i]), None => Ok(authority), - Some((_, _)) => Err(Error::BadRequest(format!( + Some((_, _)) => Err(Error::bad_request(format!( "Authority {} has an illegal format", authority ))), @@ -60,11 +76,135 @@ pub fn authority_to_host(authority: &str) -> Result<String, Error> { authority.map(|h| domain_to_unicode(h).0) } +/// Extract the bucket name and the key name from an HTTP path and possibly a bucket provided in +/// the host header of the request +/// +/// S3 internally manages only buckets and keys. This function splits +/// an HTTP path to get the corresponding bucket name and key. +pub fn parse_bucket_key<'a>( + path: &'a str, + host_bucket: Option<&'a str>, +) -> Result<(&'a str, Option<&'a str>), Error> { + let path = path.trim_start_matches('/'); + + if let Some(bucket) = host_bucket { + if !path.is_empty() { + return Ok((bucket, Some(path))); + } else { + return Ok((bucket, None)); + } + } + + let (bucket, key) = match path.find('/') { + Some(i) => { + let key = &path[i + 1..]; + if !key.is_empty() { + (&path[..i], Some(key)) + } else { + (&path[..i], None) + } + } + None => (path, None), + }; + if bucket.is_empty() { + return Err(Error::bad_request("No bucket specified")); + } + Ok((bucket, key)) +} + +const UTF8_BEFORE_LAST_CHAR: char = '\u{10FFFE}'; + +/// Compute the key after the prefix +pub fn key_after_prefix(pfx: &str) -> Option<String> { + let mut next = pfx.to_string(); + while !next.is_empty() { + let tail = next.pop().unwrap(); + if tail >= char::MAX { + continue; + } + + // Circumvent a limitation of RangeFrom that overflow earlier than needed + // See: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html + let new_tail = if tail == UTF8_BEFORE_LAST_CHAR { + char::MAX + } else { + (tail..).nth(1).unwrap() + }; + + next.push(new_tail); + return Some(next); + } + + None +} + +pub async fn parse_json_body<T: for<'de> Deserialize<'de>>(req: Request<Body>) -> Result<T, Error> { + let body = hyper::body::to_bytes(req.into_body()).await?; + let resp: T = serde_json::from_slice(&body).ok_or_bad_request("Invalid JSON")?; + Ok(resp) +} + +pub fn json_ok_response<T: Serialize>(res: &T) -> Result<Response<Body>, Error> { + let resp_json = serde_json::to_string_pretty(res).map_err(garage_util::error::Error::from)?; + Ok(Response::builder() + .status(hyper::StatusCode::OK) + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::from(resp_json))?) +} + #[cfg(test)] mod tests { use super::*; #[test] + fn parse_bucket_containing_a_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/my_bucket/a/super/file.jpg", None)?; + assert_eq!(bucket, "my_bucket"); + assert_eq!(key.expect("key must be set"), "a/super/file.jpg"); + Ok(()) + } + + #[test] + fn parse_bucket_containing_no_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/my_bucket/", None)?; + assert_eq!(bucket, "my_bucket"); + assert!(key.is_none()); + let (bucket, key) = parse_bucket_key("/my_bucket", None)?; + assert_eq!(bucket, "my_bucket"); + assert!(key.is_none()); + Ok(()) + } + + #[test] + fn parse_bucket_containing_no_bucket() { + let parsed = parse_bucket_key("", None); + assert!(parsed.is_err()); + let parsed = parse_bucket_key("/", None); + assert!(parsed.is_err()); + let parsed = parse_bucket_key("////", None); + assert!(parsed.is_err()); + } + + #[test] + fn parse_bucket_with_vhost_and_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("/a/super/file.jpg", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert_eq!(key.expect("key must be set"), "a/super/file.jpg"); + Ok(()) + } + + #[test] + fn parse_bucket_with_vhost_no_key() -> Result<(), Error> { + let (bucket, key) = parse_bucket_key("", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert!(key.is_none()); + let (bucket, key) = parse_bucket_key("/", Some("my-bucket"))?; + assert_eq!(bucket, "my-bucket"); + assert!(key.is_none()); + Ok(()) + } + + #[test] fn authority_to_host_with_port() -> Result<(), Error> { let domain = authority_to_host("[::1]:3902")?; assert_eq!(domain, "[::1]"); @@ -111,4 +251,47 @@ mod tests { assert_eq!(host_to_bucket("not-garage.tld", "garage.tld"), None); assert_eq!(host_to_bucket("not-garage.tld", ".garage.tld"), None); } + + #[test] + fn test_key_after_prefix() { + use std::iter::FromIterator; + + assert_eq!(UTF8_BEFORE_LAST_CHAR as u32, (char::MAX as u32) - 1); + assert_eq!(key_after_prefix("a/b/").unwrap().as_str(), "a/b0"); + assert_eq!(key_after_prefix("€").unwrap().as_str(), "₭"); + assert_eq!( + key_after_prefix("").unwrap().as_str(), + String::from(char::from_u32(0x10FFFE).unwrap()) + ); + + // When the last character is the biggest UTF8 char + let a = String::from_iter(['a', char::MAX].iter()); + assert_eq!(key_after_prefix(a.as_str()).unwrap().as_str(), "b"); + + // When all characters are the biggest UTF8 char + let b = String::from_iter([char::MAX; 3].iter()); + assert!(key_after_prefix(b.as_str()).is_none()); + + // Check utf8 surrogates + let c = String::from('\u{D7FF}'); + assert_eq!( + key_after_prefix(c.as_str()).unwrap().as_str(), + String::from('\u{E000}') + ); + + // Check the character before the biggest one + let d = String::from('\u{10FFFE}'); + assert_eq!( + key_after_prefix(d.as_str()).unwrap().as_str(), + String::from(char::MAX) + ); + } +} + +#[derive(Serialize)] +pub(crate) struct CustomApiErrorBody { + pub(crate) code: String, + pub(crate) message: String, + pub(crate) region: String, + pub(crate) path: String, } |