aboutsummaryrefslogtreecommitdiff
path: root/src/api/common/cors.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/api/common/cors.rs')
-rw-r--r--src/api/common/cors.rs170
1 files changed, 170 insertions, 0 deletions
diff --git a/src/api/common/cors.rs b/src/api/common/cors.rs
new file mode 100644
index 00000000..14369b56
--- /dev/null
+++ b/src/api/common/cors.rs
@@ -0,0 +1,170 @@
+use std::sync::Arc;
+
+use http::header::{
+ ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
+ ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD,
+};
+use hyper::{body::Body, body::Incoming as IncomingBody, Request, Response, StatusCode};
+
+use garage_model::bucket_table::{BucketParams, CorsRule as GarageCorsRule};
+use garage_model::garage::Garage;
+
+use crate::common_error::{
+ helper_error_as_internal, CommonError, OkOrBadRequest, OkOrInternalError,
+};
+use crate::helpers::*;
+
+pub fn find_matching_cors_rule<'a>(
+ bucket_params: &'a BucketParams,
+ req: &Request<impl Body>,
+) -> Result<Option<&'a GarageCorsRule>, CommonError> {
+ if let Some(cors_config) = bucket_params.cors_config.get() {
+ if let Some(origin) = req.headers().get("Origin") {
+ let origin = origin.to_str()?;
+ let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
+ Some(h) => h.to_str()?.split(',').map(|h| h.trim()).collect::<Vec<_>>(),
+ None => vec![],
+ };
+ return Ok(cors_config.iter().find(|rule| {
+ cors_rule_matches(rule, origin, req.method().as_ref(), request_headers.iter())
+ }));
+ }
+ }
+ Ok(None)
+}
+
+pub fn cors_rule_matches<'a, HI, S>(
+ rule: &GarageCorsRule,
+ origin: &'a str,
+ method: &'a str,
+ mut request_headers: HI,
+) -> bool
+where
+ HI: Iterator<Item = S>,
+ S: AsRef<str>,
+{
+ rule.allow_origins.iter().any(|x| x == "*" || x == origin)
+ && rule.allow_methods.iter().any(|x| x == "*" || x == method)
+ && request_headers.all(|h| {
+ rule.allow_headers
+ .iter()
+ .any(|x| x == "*" || x == h.as_ref())
+ })
+}
+
+pub fn add_cors_headers(
+ resp: &mut Response<impl Body>,
+ rule: &GarageCorsRule,
+) -> Result<(), http::header::InvalidHeaderValue> {
+ let h = resp.headers_mut();
+ h.insert(
+ ACCESS_CONTROL_ALLOW_ORIGIN,
+ rule.allow_origins.join(", ").parse()?,
+ );
+ h.insert(
+ ACCESS_CONTROL_ALLOW_METHODS,
+ rule.allow_methods.join(", ").parse()?,
+ );
+ h.insert(
+ ACCESS_CONTROL_ALLOW_HEADERS,
+ rule.allow_headers.join(", ").parse()?,
+ );
+ h.insert(
+ ACCESS_CONTROL_EXPOSE_HEADERS,
+ rule.expose_headers.join(", ").parse()?,
+ );
+ Ok(())
+}
+
+pub async fn handle_options_api(
+ garage: Arc<Garage>,
+ req: &Request<IncomingBody>,
+ bucket_name: Option<String>,
+) -> Result<Response<EmptyBody>, CommonError> {
+ // FIXME: CORS rules of buckets with local aliases are
+ // not taken into account.
+
+ // If the bucket name is a global bucket name,
+ // we try to apply the CORS rules of that bucket.
+ // If a user has a local bucket name that has
+ // the same name, its CORS rules won't be applied
+ // and will be shadowed by the rules of the globally
+ // existing bucket (but this is inevitable because
+ // OPTIONS calls are not auhtenticated).
+ if let Some(bn) = bucket_name {
+ let helper = garage.bucket_helper();
+ let bucket_id = helper
+ .resolve_global_bucket_name(&bn)
+ .await
+ .map_err(helper_error_as_internal)?;
+ if let Some(id) = bucket_id {
+ let bucket = garage
+ .bucket_helper()
+ .get_existing_bucket(id)
+ .await
+ .map_err(helper_error_as_internal)?;
+ let bucket_params = bucket.state.into_option().unwrap();
+ handle_options_for_bucket(req, &bucket_params)
+ } else {
+ // If there is a bucket name in the request, but that name
+ // does not correspond to a global alias for a bucket,
+ // then it's either a non-existing bucket or a local bucket.
+ // We have no way of knowing, because the request is not
+ // authenticated and thus we can't resolve local aliases.
+ // We take the permissive approach of allowing everything,
+ // because we don't want to prevent web apps that use
+ // local bucket names from making API calls.
+ Ok(Response::builder()
+ .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
+ .header(ACCESS_CONTROL_ALLOW_METHODS, "*")
+ .status(StatusCode::OK)
+ .body(EmptyBody::new())?)
+ }
+ } else {
+ // If there is no bucket name in the request,
+ // we are doing a ListBuckets call, which we want to allow
+ // for all origins.
+ Ok(Response::builder()
+ .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
+ .header(ACCESS_CONTROL_ALLOW_METHODS, "GET")
+ .status(StatusCode::OK)
+ .body(EmptyBody::new())?)
+ }
+}
+
+pub fn handle_options_for_bucket(
+ req: &Request<IncomingBody>,
+ bucket_params: &BucketParams,
+) -> Result<Response<EmptyBody>, CommonError> {
+ let origin = req
+ .headers()
+ .get("Origin")
+ .ok_or_bad_request("Missing Origin header")?
+ .to_str()?;
+ let request_method = req
+ .headers()
+ .get(ACCESS_CONTROL_REQUEST_METHOD)
+ .ok_or_bad_request("Missing Access-Control-Request-Method header")?
+ .to_str()?;
+ let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
+ Some(h) => h.to_str()?.split(',').map(|h| h.trim()).collect::<Vec<_>>(),
+ None => vec![],
+ };
+
+ if let Some(cors_config) = bucket_params.cors_config.get() {
+ let matching_rule = cors_config
+ .iter()
+ .find(|rule| cors_rule_matches(rule, origin, request_method, request_headers.iter()));
+ if let Some(rule) = matching_rule {
+ let mut resp = Response::builder()
+ .status(StatusCode::OK)
+ .body(EmptyBody::new())?;
+ add_cors_headers(&mut resp, rule).ok_or_internal_error("Invalid CORS configuration")?;
+ return Ok(resp);
+ }
+ }
+
+ Err(CommonError::Forbidden(
+ "This CORS request is not allowed.".into(),
+ ))
+}