aboutsummaryrefslogblamecommitdiff
path: root/src/api/helpers.rs
blob: 5f488912937b69193d31160e2c2d39393eb536b0 (plain) (tree)
1
2
3
4
5
6
7
8
9
                             
                                               
                                                


                            
                            
                                    
 
                                             
                                                   












                                                                    
                  

                                                                                       


















                                                                                                            
                                                                    


                                                                    
                                                                                      





                                                                           
                                                                              










                                                                                     
                                                               





                                                             





























                                                                                                
                                                                      




























                                                                                         
                                                 
                                                         
                                            




                                                                            
                                                                              
                                      
                                                                                          
 
                                                           
 
 
                                                                      
                                     
                                                  
 
                                                               


                                                                                       





                                                                                    

                                                                       
                                             
 














                                                                                           


                                                          



                     














































                                                                                              












































                                                                                      

































                                                                                 
 






                                      
use std::convert::Infallible;

use futures::{Stream, StreamExt, TryStreamExt};

use http_body_util::{BodyExt, Full as FullBody};
use hyper::{
	body::{Body, Bytes},
	Request, Response,
};
use idna::domain_to_unicode;
use serde::{Deserialize, Serialize};

use garage_util::error::Error as GarageError;

use crate::common_error::{CommonError as Error, *};

/// What kind of authorization is required to perform a given action
///
/// Apart from `None`, each variant names the bucket permission that the
/// caller must hold.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Authorization {
	/// No authorization is required
	None,
	/// Having Read permission on bucket
	Read,
	/// Having Write permission on bucket
	Write,
	/// Having Owner permission on bucket
	Owner,
}

/// Derive the bucket name from a virtual-hosted-style host name
///
/// Given a host such as "bucket.garage-site.tld" and the root domain
/// ".garage-site.tld" (the leading dot is optional), returns the part of the
/// host located before the root domain, here "bucket". The bucket part may
/// itself contain dots ("john.doe.garage.tld" gives "john.doe").
///
/// Returns `None` when the host is the root domain itself, or when it does not
/// end with the root domain on a label boundary (e.g. "not-garage.tld").
/// This behavior has been chosen to follow AWS S3 semantic.
pub fn host_to_bucket<'a>(host: &'a str, root: &str) -> Option<&'a str> {
	let suffix = root.trim_start_matches('.');
	// The root must be a full suffix of the host, and what remains must end
	// with the dot separating the bucket labels from the root labels.
	host.strip_suffix(suffix)?.strip_suffix('.')
}

/// Extract host from the authority section given by the HTTP host header
///
/// The HTTP host contains both a host and, optionally, a port.
/// Extracting the port is more complex than just finding the colon (:) symbol due to IPv6:
/// a bracketed literal such as `[::1]:3902` contains colons itself, so for an authority
/// starting with `[` the port separator is only looked for right after the closing `]`.
/// The authority is scanned in a single pass, without collecting it into an intermediate
/// buffer.
///
/// The extracted host is passed through `domain_to_unicode`; the status returned
/// by that conversion is ignored and only the converted string is kept.
///
/// # Errors
///
/// Returns a bad request error if the authority is empty, if it starts with `[` but has
/// no closing `]`, or if the closing `]` is followed by anything other than `:`.
pub fn authority_to_host(authority: &str) -> Result<String, Error> {
	// `char_indices` yields byte offsets, which is what the slicing below needs.
	// Counting characters instead would cut at the wrong place (or panic on a
	// non-boundary index) as soon as the host contains a multi-byte character.
	let mut iter = authority.char_indices();
	let (_, first_char) = iter
		.next()
		.ok_or_else(|| Error::bad_request("Authority is empty".to_string()))?;

	// `split` is the character that directly follows the host part, if any
	let split = match first_char {
		'[' => {
			// Bracketed IPv6 literal: jump to the closing bracket, then look at
			// the character right after it
			let mut iter = iter.skip_while(|(_, c)| c != &']');
			match iter.next() {
				Some((_, ']')) => iter.next(),
				_ => {
					return Err(Error::bad_request(format!(
						"Authority {} has an illegal format",
						authority
					)))
				}
			}
		}
		_ => iter.find(|(_, c)| *c == ':'),
	};

	let authority = match split {
		// A port follows: keep everything before the colon
		Some((i, ':')) => Ok(&authority[..i]),
		// No port: the whole authority is the host
		None => Ok(authority),
		// Something other than a port follows the closing bracket
		Some((_, _)) => Err(Error::bad_request(format!(
			"Authority {} has an illegal format",
			authority
		))),
	};
	authority.map(|h| domain_to_unicode(h).0)
}

/// Extract the bucket name and the key name from an HTTP path and possibly a bucket provided in
/// the host header of the request
///
/// S3 internally manages only buckets and keys. This function splits
/// an HTTP path to get the corresponding bucket name and key.
pub fn parse_bucket_key<'a>(
	path: &'a str,
	host_bucket: Option<&'a str>,
) -> Result<(&'a str, Option<&'a str>), Error> {
	let path = path.trim_start_matches('/');

	if let Some(bucket) = host_bucket {
		if !path.is_empty() {
			return Ok((bucket, Some(path)));
		} else {
			return Ok((bucket, None));
		}
	}

	let (bucket, key) = match path.find('/') {
		Some(i) => {
			let key = &path[i + 1..];
			if !key.is_empty() {
				(&path[..i], Some(key))
			} else {
				(&path[..i], None)
			}
		}
		None => (path, None),
	};
	if bucket.is_empty() {
		return Err(Error::bad_request("No bucket specified"));
	}
	Ok((bucket, key))
}

/// The character located just before `char::MAX`
const UTF8_BEFORE_LAST_CHAR: char = '\u{10FFFE}';

/// Compute the key after the prefix
///
/// Trailing `char::MAX` characters are dropped and the last remaining character
/// is replaced by its successor. Returns `None` when there is no such character,
/// i.e. when the prefix is empty or made only of `char::MAX`.
pub fn key_after_prefix(pfx: &str) -> Option<String> {
	// Last character that can still be incremented, with its byte offset
	let (offset, last) = pfx.char_indices().rev().find(|&(_, c)| c != char::MAX)?;

	let successor = match last {
		// Circumvent a limitation of RangeFrom that overflow earlier than needed
		// See: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html
		UTF8_BEFORE_LAST_CHAR => char::MAX,
		other => (other..).nth(1).unwrap(),
	};

	let mut next = String::with_capacity(offset + successor.len_utf8());
	next.push_str(&pfx[..offset]);
	next.push(successor);
	Some(next)
}

// =============== body helpers =================

/// Body type from `http_body_util::Empty`, with `bytes::Bytes` as data type
pub type EmptyBody = http_body_util::Empty<bytes::Bytes>;
/// Body type produced by `error_body`: a single in-memory chunk of bytes
pub type ErrorBody = FullBody<bytes::Bytes>;
/// Boxed (type-erased) body yielding `bytes::Bytes`, generic over its error type `E`
pub type BoxBody<E> = http_body_util::combinators::BoxBody<bytes::Bytes, E>;

/// Build a boxed body whose content is the bytes of the given string
pub fn string_body<E>(s: String) -> BoxBody<E> {
	let payload = bytes::Bytes::from(s);
	bytes_body(payload)
}
/// Build a boxed body made of the given bytes
pub fn bytes_body<E>(b: bytes::Bytes) -> BoxBody<E> {
	let full = FullBody::new(b);
	// The inner body's error type is `Infallible`, so this conversion can never run
	BoxBody::new(full.map_err(|never: Infallible| match never {}))
}
pub fn empty_body<E>() -> BoxBody<E> {
	BoxBody::new(http_body_util::Empty::new().map_err(|_: Infallible| unreachable!()))
}
/// Build an `ErrorBody` whose content is the bytes of the given string
pub fn error_body(s: String) -> ErrorBody {
	let payload = bytes::Bytes::from(s);
	ErrorBody::new(payload)
}

/// Read the whole request body and deserialize it as JSON into a `T`
///
/// The body is entirely collected in memory before being parsed.
///
/// Errors raised while reading the body are converted into `E`; a body that
/// cannot be deserialized into `T` gives a bad request error ("Invalid JSON").
pub async fn parse_json_body<T, B, E>(req: Request<B>) -> Result<T, E>
where
	T: for<'de> Deserialize<'de>,
	B: Body,
	E: From<<B as Body>::Error> + From<Error>,
{
	let collected = req.into_body().collect().await?;
	let raw = collected.to_bytes();
	let parsed = serde_json::from_slice::<T>(&raw).ok_or_bad_request("Invalid JSON")?;
	Ok(parsed)
}

pub fn json_ok_response<E, T: Serialize>(res: &T) -> Result<Response<BoxBody<E>>, E>
where
	E: From<Error>,
{
	let resp_json = serde_json::to_string_pretty(res)
		.map_err(GarageError::from)
		.map_err(Error::from)?;
	Ok(Response::builder()
		.status(hyper::StatusCode::OK)
		.header(http::header::CONTENT_TYPE, "application/json")
		.body(string_body(resp_json))
		.unwrap())
}

pub fn body_stream<B, E>(body: B) -> impl Stream<Item = Result<Bytes, E>>
where
	B: Body<Data = Bytes>,
	<B as Body>::Error: Into<E>,
	E: From<Error>,
{
	let stream = http_body_util::BodyStream::new(body);
	let stream = TryStreamExt::map_err(stream, Into::into);
	stream.map(|x| {
		x.and_then(|f| {
			f.into_data()
				.map_err(|_| E::from(Error::bad_request("non-data frame")))
		})
	})
}

/// Tell whether `v` is equal to the default value of its type
pub fn is_default<T: Default + PartialEq>(v: &T) -> bool {
	let baseline = T::default();
	v == &baseline
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Path-style request: the first segment is the bucket, the rest is the key
	#[test]
	fn parse_bucket_containing_a_key() -> Result<(), Error> {
		let parsed = parse_bucket_key("/my_bucket/a/super/file.jpg", None)?;
		assert_eq!(parsed, ("my_bucket", Some("a/super/file.jpg")));
		Ok(())
	}

	/// Path-style request made of a bucket only, with or without trailing slash
	#[test]
	fn parse_bucket_containing_no_key() -> Result<(), Error> {
		for path in ["/my_bucket/", "/my_bucket"] {
			assert_eq!(parse_bucket_key(path, None)?, ("my_bucket", None));
		}
		Ok(())
	}

	/// Path-style request where no bucket can be found
	#[test]
	fn parse_bucket_containing_no_bucket() {
		for path in ["", "/", "////"] {
			assert!(parse_bucket_key(path, None).is_err());
		}
	}

	/// Bucket given by the host: the whole path is the key
	#[test]
	fn parse_bucket_with_vhost_and_key() -> Result<(), Error> {
		let parsed = parse_bucket_key("/a/super/file.jpg", Some("my-bucket"))?;
		assert_eq!(parsed, ("my-bucket", Some("a/super/file.jpg")));
		Ok(())
	}

	/// Bucket given by the host and an empty path: no key
	#[test]
	fn parse_bucket_with_vhost_no_key() -> Result<(), Error> {
		for path in ["", "/"] {
			assert_eq!(
				parse_bucket_key(path, Some("my-bucket"))?,
				("my-bucket", None)
			);
		}
		Ok(())
	}

	#[test]
	fn authority_to_host_with_port() -> Result<(), Error> {
		let cases = [
			("[::1]:3902", "[::1]"),
			("garage.tld:65200", "garage.tld"),
			("127.0.0.1:80", "127.0.0.1"),
		];
		for (authority, expected) in cases {
			assert_eq!(authority_to_host(authority)?, expected);
		}
		Ok(())
	}

	#[test]
	fn authority_to_host_without_port() -> Result<(), Error> {
		// Without port, the authority is returned as is
		for authority in ["[::1]", "garage.tld", "127.0.0.1"] {
			assert_eq!(authority_to_host(authority)?, authority);
		}
		// Unterminated IPv6 literals are rejected
		for authority in ["[", "[hello"] {
			assert!(authority_to_host(authority).is_err());
		}
		Ok(())
	}

	#[test]
	fn host_to_bucket_test() {
		// The root domain behaves the same with or without its leading dot
		for root in [".garage.tld", "garage.tld"] {
			assert_eq!(
				host_to_bucket("john.doe.garage.tld", root),
				Some("john.doe")
			);
			for host in ["john.doe.com", "garage.tld", "not-garage.tld"] {
				assert_eq!(host_to_bucket(host, root), None);
			}
		}
	}

	#[test]
	fn test_key_after_prefix() {
		assert_eq!(UTF8_BEFORE_LAST_CHAR as u32, (char::MAX as u32) - 1);
		assert_eq!(key_after_prefix("a/b/").as_deref(), Some("a/b0"));
		assert_eq!(key_after_prefix("€").as_deref(), Some("₭"));
		assert_eq!(
			key_after_prefix("\u{10FFFD}").as_deref(),
			Some("\u{10FFFE}")
		);

		// When the last character is the biggest UTF8 char
		let a: String = ['a', char::MAX].iter().collect();
		assert_eq!(key_after_prefix(&a).as_deref(), Some("b"));

		// When all characters are the biggest UTF8 char
		let b: String = std::iter::repeat(char::MAX).take(3).collect();
		assert_eq!(key_after_prefix(&b), None);

		// Check utf8 surrogates
		assert_eq!(key_after_prefix("\u{D7FF}").as_deref(), Some("\u{E000}"));

		// Check the character before the biggest one
		assert_eq!(
			key_after_prefix("\u{10FFFE}"),
			Some(char::MAX.to_string())
		);
	}
}

/// Serializable error description made of four string fields
///
/// NOTE(review): presumably used as the body of API error responses; the
/// meaning given for each field below is inferred from its name only, confirm
/// against the code that builds this struct.
#[derive(Serialize)]
pub(crate) struct CustomApiErrorBody {
	/// Error code (presumably a short machine-readable identifier)
	pub(crate) code: String,
	/// Error message (presumably human-readable)
	pub(crate) message: String,
	/// Region (presumably the one of the node answering the request)
	pub(crate) region: String,
	/// Path (presumably the one of the request that failed)
	pub(crate) path: String,
}