use quick_xml::de::from_reader; use std::sync::Arc; use http::header::{ ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, }; use hyper::{header::HeaderName, Body, Method, Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use crate::error::*; use crate::s3_xml::{to_xml_with_header, xmlns_tag, IntValue, Value}; use crate::signature::verify_signed_content; use garage_model::bucket_table::{Bucket, CorsRule as GarageCorsRule}; use garage_model::garage::Garage; use garage_table::*; use garage_util::data::*; pub async fn handle_get_cors(bucket: &Bucket) -> Result, Error> { let param = bucket .params() .ok_or_internal_error("Bucket should not be deleted at this point")?; if let Some(cors) = param.cors_config.get() { let wc = CorsConfiguration { xmlns: (), cors_rules: cors .iter() .map(CorsRule::from_garage_cors_rule) .collect::>(), }; let xml = to_xml_with_header(&wc)?; Ok(Response::builder() .status(StatusCode::OK) .header(http::header::CONTENT_TYPE, "application/xml") .body(Body::from(xml))?) } else { Ok(Response::builder() .status(StatusCode::NO_CONTENT) .body(Body::empty())?) } } pub async fn handle_delete_cors( garage: Arc, bucket_id: Uuid, ) -> Result, Error> { let mut bucket = garage .bucket_table .get(&EmptyKey, &bucket_id) .await? .ok_or(Error::NoSuchBucket)?; let param = bucket .params_mut() .ok_or_internal_error("Bucket should not be deleted at this point")?; param.cors_config.update(None); garage.bucket_table.insert(&bucket).await?; Ok(Response::builder() .status(StatusCode::NO_CONTENT) .body(Body::empty())?) } pub async fn handle_put_cors( garage: Arc, bucket_id: Uuid, req: Request, content_sha256: Option, ) -> Result, Error> { let body = hyper::body::to_bytes(req.into_body()).await?; if let Some(content_sha256) = content_sha256 { verify_signed_content(content_sha256, &body[..])?; } let mut bucket = garage .bucket_table .get(&EmptyKey, &bucket_id) .await? .ok_or(Error::NoSuchBucket)?; let param = bucket .params_mut() .ok_or_internal_error("Bucket should not be deleted at this point")?; let conf: CorsConfiguration = from_reader(&body as &[u8])?; conf.validate()?; param .cors_config .update(Some(conf.into_garage_cors_config()?)); garage.bucket_table.insert(&bucket).await?; Ok(Response::builder() .status(StatusCode::OK) .body(Body::empty())?) } pub async fn handle_options( garage: Arc, req: &Request, bucket_name: Option, ) -> Result, Error> { let bucket = if let Some(bn) = bucket_name { let helper = garage.bucket_helper(); let bucket_id = helper .resolve_global_bucket_name(&bn) .await? .ok_or(Error::NoSuchBucket)?; garage .bucket_table .get(&EmptyKey, &bucket_id) .await? .filter(|b| !b.state.is_deleted()) .ok_or(Error::NoSuchBucket)? } else { // The only supported API call that doesn't use a bucket name is ListBuckets, // which we want to allow in all cases return Ok(Response::builder() .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*") .header(ACCESS_CONTROL_ALLOW_METHODS, "GET") .status(StatusCode::OK) .body(Body::empty())?); }; let origin = req .headers() .get("Origin") .ok_or_bad_request("Missing Origin header")? .to_str()?; let request_method = req .headers() .get(ACCESS_CONTROL_REQUEST_METHOD) .ok_or_bad_request("Missing Access-Control-Request-Method header")? .to_str()?; let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) { Some(h) => h.to_str()?.split(',').map(|h| h.trim()).collect::>(), None => vec![], }; if let Some(cors_config) = bucket.params().unwrap().cors_config.get() { let matching_rule = cors_config .iter() .find(|rule| cors_rule_matches(rule, origin, request_method, request_headers.iter())); if let Some(rule) = matching_rule { let mut resp = Response::builder() .status(StatusCode::OK) .body(Body::empty())?; add_cors_headers(&mut resp, rule).ok_or_internal_error("Invalid CORS configuration")?; return Ok(resp); } } Err(Error::Forbidden("This CORS request is not allowed.".into())) } pub fn find_matching_cors_rule<'a>( bucket: &'a Bucket, req: &Request, ) -> Result, Error> { if let Some(cors_config) = bucket.params().unwrap().cors_config.get() { if let Some(origin) = req.headers().get("Origin") { let origin = origin.to_str()?; let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) { Some(h) => h.to_str()?.split(',').map(|h| h.trim()).collect::>(), None => vec![], }; return Ok(cors_config.iter().find(|rule| { cors_rule_matches( rule, origin, &req.method().to_string(), request_headers.iter(), ) })); } } Ok(None) } fn cors_rule_matches<'a, HI, S>( rule: &GarageCorsRule, origin: &'a str, method: &'a str, mut request_headers: HI, ) -> bool where HI: Iterator, S: AsRef, { rule.allow_origins.iter().any(|x| x == "*" || x == origin) && rule.allow_methods.iter().any(|x| x == "*" || x == method) && request_headers.all(|h| { rule.allow_headers .iter() .any(|x| x == "*" || x == h.as_ref()) }) } pub fn add_cors_headers( resp: &mut Response, rule: &GarageCorsRule, ) -> Result<(), http::header::InvalidHeaderValue> { let h = resp.headers_mut(); h.insert( ACCESS_CONTROL_ALLOW_ORIGIN, rule.allow_origins.join(", ").parse()?, ); h.insert( ACCESS_CONTROL_ALLOW_METHODS, rule.allow_methods.join(", ").parse()?, ); h.insert( ACCESS_CONTROL_ALLOW_HEADERS, rule.allow_headers.join(", ").parse()?, ); h.insert( ACCESS_CONTROL_EXPOSE_HEADERS, rule.expose_headers.join(", ").parse()?, ); Ok(()) } // ---- SERIALIZATION AND DESERIALIZATION TO/FROM S3 XML ---- #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] #[serde(rename = "CORSConfiguration")] pub struct CorsConfiguration { #[serde(serialize_with = "xmlns_tag", skip_deserializing)] pub xmlns: (), #[serde(rename = "CORSRule")] pub cors_rules: Vec, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub struct CorsRule { #[serde(rename = "ID")] pub id: Option, #[serde(rename = "MaxAgeSeconds")] pub max_age_seconds: Option, #[serde(rename = "AllowedOrigin")] pub allowed_origins: Vec, #[serde(rename = "AllowedMethod")] pub allowed_methods: Vec, #[serde(rename = "AllowedHeader", default)] pub allowed_headers: Vec, #[serde(rename = "ExposeHeader", default)] pub expose_headers: Vec, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub struct AllowedMethod { #[serde(rename = "AllowedMethod")] pub allowed_method: Value, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub struct AllowedHeader { #[serde(rename = "AllowedHeader")] pub allowed_header: Value, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub struct ExposeHeader { #[serde(rename = "ExposeHeader")] pub expose_header: Value, } impl CorsConfiguration { pub fn validate(&self) -> Result<(), Error> { for r in self.cors_rules.iter() { r.validate()?; } Ok(()) } pub fn into_garage_cors_config(self) -> Result, Error> { Ok(self .cors_rules .iter() .map(CorsRule::to_garage_cors_rule) .collect()) } } impl CorsRule { pub fn validate(&self) -> Result<(), Error> { for method in self.allowed_methods.iter() { method .0 .parse::() .ok_or_bad_request("Invalid CORSRule method")?; } for header in self .allowed_headers .iter() .chain(self.expose_headers.iter()) { header .0 .parse::() .ok_or_bad_request("Invalid HTTP header name")?; } Ok(()) } pub fn to_garage_cors_rule(&self) -> GarageCorsRule { let convert_vec = |vval: &[Value]| vval.iter().map(|x| x.0.to_owned()).collect::>(); GarageCorsRule { id: self.id.as_ref().map(|x| x.0.to_owned()), max_age_seconds: self.max_age_seconds.as_ref().map(|x| x.0 as u64), allow_origins: convert_vec(&self.allowed_origins), allow_methods: convert_vec(&self.allowed_methods), allow_headers: convert_vec(&self.allowed_headers), expose_headers: convert_vec(&self.expose_headers), } } pub fn from_garage_cors_rule(rule: &GarageCorsRule) -> Self { let convert_vec = |vval: &[String]| { vval.iter() .map(|x| Value(x.clone())) .collect::>() }; Self { id: rule.id.as_ref().map(|x| Value(x.clone())), max_age_seconds: rule.max_age_seconds.map(|x| IntValue(x as i64)), allowed_origins: convert_vec(&rule.allow_origins), allowed_methods: convert_vec(&rule.allow_methods), allowed_headers: convert_vec(&rule.allow_headers), expose_headers: convert_vec(&rule.expose_headers), } } } #[cfg(test)] mod tests { use super::*; use quick_xml::de::from_str; #[test] fn test_deserialize() -> Result<(), Error> { let message = r#" http://www.example.com PUT POST DELETE * * GET qsdfjklm 12345 https://perdu.com GET DELETE * * "#; let conf: CorsConfiguration = from_str(message).unwrap(); let ref_value = CorsConfiguration { xmlns: (), cors_rules: vec![ CorsRule { id: None, max_age_seconds: None, allowed_origins: vec!["http://www.example.com".into()], allowed_methods: vec!["PUT".into(), "POST".into(), "DELETE".into()], allowed_headers: vec!["*".into()], expose_headers: vec![], }, CorsRule { id: None, max_age_seconds: None, allowed_origins: vec!["*".into()], allowed_methods: vec!["GET".into()], allowed_headers: vec![], expose_headers: vec![], }, CorsRule { id: Some("qsdfjklm".into()), max_age_seconds: Some(IntValue(12345)), allowed_origins: vec!["https://perdu.com".into()], allowed_methods: vec!["GET".into(), "DELETE".into()], allowed_headers: vec!["*".into()], expose_headers: vec!["*".into()], }, ], }; assert_eq! { ref_value, conf }; let message2 = to_xml_with_header(&ref_value)?; let cleanup = |c: &str| c.replace(char::is_whitespace, ""); assert_eq!(cleanup(message), cleanup(&message2)); Ok(()) } }