use std::borrow::Cow;
use std::convert::TryInto;
use std::pin::Pin;
use aes_gcm::{
aead::stream::{DecryptorLE31, EncryptorLE31, StreamLE31},
aead::{Aead, AeadCore, KeyInit, OsRng},
aes::cipher::crypto_common::rand_core::RngCore,
aes::cipher::typenum::Unsigned,
Aes256Gcm, Key, Nonce,
};
use base64::prelude::*;
use bytes::Bytes;
use futures::stream::Stream;
use futures::task;
use tokio::io::BufReader;
use http::header::{HeaderName, HeaderValue};
use hyper::{body::Body, Request};
use garage_net::bytes_buf::BytesBuf;
use garage_net::stream::{stream_asyncread, ByteStream};
use garage_rpc::rpc_helper::OrderTag;
use garage_util::data::Hash;
use garage_util::error::Error as GarageError;
use garage_util::migrate::Migrate;
use garage_model::garage::Garage;
use garage_model::s3::object_table::{ObjectVersionEncryption, ObjectVersionHeaders};
use crate::common_error::*;
use crate::s3::error::Error;
// ---- request headers carrying the customer-provided encryption parameters ----
// Each group below is passed as (algorithm, key, key-MD5) to
// `parse_request_headers`.

// Headers describing the key for the object targeted by the request itself.
const X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM: HeaderName =
HeaderName::from_static("x-amz-server-side-encryption-customer-algorithm");
const X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY: HeaderName =
HeaderName::from_static("x-amz-server-side-encryption-customer-key");
const X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5: HeaderName =
HeaderName::from_static("x-amz-server-side-encryption-customer-key-md5");
// Headers describing the key of the *source* object of a copy operation.
const X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM: HeaderName =
HeaderName::from_static("x-amz-copy-source-server-side-encryption-customer-algorithm");
const X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY: HeaderName =
HeaderName::from_static("x-amz-copy-source-server-side-encryption-customer-key");
const X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5: HeaderName =
HeaderName::from_static("x-amz-copy-source-server-side-encryption-customer-key-md5");
// The only algorithm header value accepted by `parse_request_headers`.
const CUSTOMER_ALGORITHM_AES256: HeaderValue = HeaderValue::from_static("AES256");
// Size (as a type-level integer) of the nonce used by the LE31 stream
// construction over AES-256-GCM; this nonce is written at the start of every
// encrypted block.
type StreamNonceSize = aes_gcm::aead::stream::NonceSize<Aes256Gcm, StreamLE31<Aes256Gcm>>;
// Blocks are encrypted in plaintext chunks of this size.
const STREAM_ENC_PLAIN_CHUNK_SIZE: usize = 0x1000; // 4096 bytes
// Size of the ciphertext produced for one full plaintext chunk: 16 extra bytes
// per chunk (presumably the AES-GCM authentication tag -- TODO confirm).
const STREAM_ENC_CYPER_CHUNK_SIZE: usize = STREAM_ENC_PLAIN_CHUNK_SIZE + 16;
/// Encryption settings that apply to one object for the duration of a request.
#[derive(Clone, Copy)]
pub enum EncryptionParams {
/// The object is stored as-is: blobs and blocks pass through untouched.
Plaintext,
/// The object is encrypted with a key supplied by the client (SSE-C).
SseC {
/// AES-256-GCM key used for both metadata blobs and data blocks.
client_key: Key<Aes256Gcm>,
/// When `Some`, data blocks are zstd-compressed (at this level) before
/// being encrypted, and zstd-decompressed after being decrypted.
/// When `None`, blocks are not compressed.
compression_level: Option<i32>,
},
}
impl EncryptionParams {
/// Returns `true` for the `SseC` variant and `false` for `Plaintext`.
pub fn is_encrypted(&self) -> bool {
    match self {
        Self::SseC { .. } => true,
        Self::Plaintext => false,
    }
}
/// Builds the encryption parameters requested by `req` through its
/// `x-amz-server-side-encryption-customer-*` headers.
///
/// Returns `SseC` (with the compression level taken from the node
/// configuration) when a customer key was parsed from the headers, and
/// `Plaintext` when no key was given. Header parsing errors are propagated.
pub fn new_from_req(
    garage: &Garage,
    req: &Request<impl Body>,
) -> Result<EncryptionParams, Error> {
    let maybe_key = parse_request_headers(
        req,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5,
    )?;
    let params = maybe_key.map_or(EncryptionParams::Plaintext, |client_key| {
        EncryptionParams::SseC {
            client_key,
            compression_level: garage.config.compression_level,
        }
    });
    Ok(params)
}
/// Validates that `req` may read an object whose stored encryption state is
/// `obj_enc`, using the key from the regular
/// `x-amz-server-side-encryption-customer-*` headers.
///
/// On success, returns the encryption parameters to use for reading the
/// object together with its (decrypted, if needed) headers.
pub fn check_decrypt_for_get<'a>(
    garage: &Garage,
    req: &Request<impl Body>,
    obj_enc: &'a ObjectVersionEncryption,
) -> Result<(Self, Cow<'a, ObjectVersionHeaders>), Error> {
    parse_request_headers(
        req,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY,
        &X_AMZ_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5,
    )
    .and_then(|key| Self::check_decrypt(garage, key, obj_enc))
}
/// Same check as for a regular read, but the key is taken from the
/// `x-amz-copy-source-server-side-encryption-customer-*` headers, which
/// describe the source object of a copy.
///
/// On success, returns the encryption parameters to use for reading the
/// source object together with its (decrypted, if needed) headers.
pub fn check_decrypt_for_copy_source<'a>(
    garage: &Garage,
    req: &Request<impl Body>,
    obj_enc: &'a ObjectVersionEncryption,
) -> Result<(Self, Cow<'a, ObjectVersionHeaders>), Error> {
    parse_request_headers(
        req,
        &X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM,
        &X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY,
        &X_AMZ_COPY_SOURCE_SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5,
    )
    .and_then(|key| Self::check_decrypt(garage, key, obj_enc))
}
fn check_decrypt<'a>(
garage: &Garage,
key: Option<Key<Aes256Gcm>>,
obj_enc: &'a ObjectVersionEncryption,
) -> Result<(Self, Cow<'a, ObjectVersionHeaders>), Error> {
match (key, &obj_enc) {
(
Some(client_key),
ObjectVersionEncryption::SseC {
headers,
compressed,
},
) => {
let enc = Self::SseC {
client_key,
compression_level: if *compressed {
Some(garage.config.compression_level.unwrap_or(1))
} else {
None
},
};
let plaintext = enc.decrypt_blob(&headers)?;
let headers = ObjectVersionHeaders::decode(&plaintext)
.ok_or_internal_error("Could not decode encrypted headers")?;
Ok((enc, Cow::Owned(headers)))
}
(None, ObjectVersionEncryption::Plaintext { headers }) => {
Ok((Self::Plaintext, Cow::Borrowed(headers)))
}
(_, ObjectVersionEncryption::SseC { .. }) => {
Err(Error::bad_request("Object is encrypted"))
}
(Some(_), _) => {
// TODO: should this be an OK scenario?
Err(Error::bad_request("Trying to decrypt a plaintext object"))
}
}
}
/// Converts object headers into the form in which they are stored.
///
/// For `Plaintext` the headers are stored unchanged. For `SseC` they are
/// serialized and encrypted with `encrypt_blob`, and the result records
/// whether data blocks are compressed.
pub fn encrypt_headers(
    &self,
    h: ObjectVersionHeaders,
) -> Result<ObjectVersionEncryption, Error> {
    let compression_level = match self {
        Self::Plaintext => return Ok(ObjectVersionEncryption::Plaintext { headers: h }),
        Self::SseC {
            compression_level, ..
        } => compression_level,
    };

    let serialized = h.encode().map_err(GarageError::from)?;
    let sealed = self.encrypt_blob(&serialized)?;
    Ok(ObjectVersionEncryption::SseC {
        headers: sealed.into_owned(),
        compressed: compression_level.is_some(),
    })
}
// ---- generic function for encrypting / decrypting blobs ----
// prepends a randomly-generated nonce to the encrypted value
/// Encrypts a small in-memory blob.
///
/// For `SseC`, a fresh random nonce is generated and the returned buffer
/// is `nonce || ciphertext`. For `Plaintext`, the input is returned
/// borrowed, without copying.
pub fn encrypt_blob<'a>(&self, blob: &'a [u8]) -> Result<Cow<'a, [u8]>, Error> {
    let client_key = match self {
        Self::Plaintext => return Ok(Cow::Borrowed(blob)),
        Self::SseC { client_key, .. } => client_key,
    };

    let cipher = Aes256Gcm::new(client_key);
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
    let ciphertext = cipher
        .encrypt(&nonce, blob)
        .ok_or_internal_error("Encryption failed")?;

    let mut sealed = Vec::with_capacity(nonce.len() + ciphertext.len());
    sealed.extend_from_slice(nonce.as_slice());
    sealed.extend_from_slice(&ciphertext);
    Ok(Cow::Owned(sealed))
}
/// Decrypts a blob laid out as `nonce || ciphertext`.
///
/// For `Plaintext`, the input is returned borrowed. For `SseC`, a blob
/// shorter than the nonce is reported as an internal error, and a
/// decryption failure is reported as a bad request.
pub fn decrypt_blob<'a>(&self, blob: &'a [u8]) -> Result<Cow<'a, [u8]>, Error> {
    let client_key = match self {
        Self::Plaintext => return Ok(Cow::Borrowed(blob)),
        Self::SseC { client_key, .. } => client_key,
    };

    let nonce_size = <Aes256Gcm as AeadCore>::NonceSize::to_usize();
    let nonce_bytes = blob
        .get(..nonce_size)
        .ok_or_internal_error("invalid encrypted data")?;
    // Cannot panic: the `get` above proved the blob holds at least
    // `nonce_size` bytes.
    let ciphertext = &blob[nonce_size..];

    let cipher = Aes256Gcm::new(client_key);
    let decrypted = cipher
        .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
        .ok_or_bad_request("Invalid encryption key, could not decrypt object metadata.")?;
    Ok(Cow::Owned(decrypted))
}
// ---- function for encrypting / decrypting byte streams ----
/// Fetches a data block as a stream through the block manager and passes it
/// through `decrypt_block_stream`, so the caller receives the block content
/// decrypted and decompressed as required by these parameters (plaintext
/// objects are returned without any processing).
pub async fn get_block(
    &self,
    garage: &Garage,
    hash: &Hash,
    order: Option<OrderTag>,
) -> Result<ByteStream, Error> {
    let block_manager = &garage.block_manager;
    let stored = block_manager.rpc_get_block_streaming(hash, order).await?;
    Ok(self.decrypt_block_stream(stored))
}
/// Wraps a stream of stored block bytes so that it yields the original
/// block content.
///
/// `Plaintext` returns the stream unchanged. `SseC` decrypts it, and
/// additionally zstd-decompresses it when a compression level is set.
pub fn decrypt_block_stream(&self, stream: ByteStream) -> ByteStream {
    let (client_key, compression_level) = match self {
        Self::Plaintext => return stream,
        Self::SseC {
            client_key,
            compression_level,
        } => (client_key, compression_level),
    };

    let decrypted = DecryptStream::new(stream, *client_key);
    if compression_level.is_none() {
        return Box::pin(decrypted);
    }

    // The zstd decoder works on buffered async readers, so go
    // stream -> reader -> decoder -> stream.
    let reader = BufReader::new(stream_asyncread(Box::pin(decrypted)));
    let decoder = async_compression::tokio::bufread::ZstdDecoder::new(reader);
    Box::pin(tokio_util::io::ReaderStream::new(decoder))
}
/// Encrypt a data block if encryption is set, for use before
/// putting the data blocks into storage
pub fn encrypt_block(&self, block: Bytes) -> Result<Bytes, Error> {
match self {
Self::Plaintext => Ok(block),
Self::SseC {
client_key,
compression_level,
} => {
let block = if let Some(level) = compression_level {
Cow::Owned(
garage_block::zstd_encode(block.as_ref(), *level)
.ok_or_internal_error("failed to compress data block")?,
)
} else {
Cow::Borrowed(block.as_ref())
};
let mut ret = Vec::with_capacity(block.len() + 32 + block.len() / 64);
let mut nonce: Nonce<StreamNonceSize> = Default::default();
OsRng.fill_bytes(&mut nonce);
ret.extend_from_slice(nonce.as_slice());
let mut cipher = EncryptorLE31::<Aes256Gcm>::new(&client_key, &nonce);
for chunk in block.chunks(STREAM_ENC_PLAIN_CHUNK_SIZE) {
// TODO: use encrypt_last for last chunk
let chunk_enc = cipher
.encrypt_next(chunk)
.ok_or_internal_error("failed to encrypt chunk")?;
if chunk.len() == STREAM_ENC_PLAIN_CHUNK_SIZE {
assert_eq!(chunk_enc.len(), STREAM_ENC_CYPER_CHUNK_SIZE);
}
ret.extend_from_slice(&chunk_enc);
}
Ok(ret.into())
}
}
}
}
fn parse_request_headers(
req: &Request<impl Body>,
alg_header: &HeaderName,
key_header: &HeaderName,
md5_header: &HeaderName,
) -> Result<Option<Key<Aes256Gcm>>, Error> {
match req.headers().get(alg_header) {
Some(alg) if *alg == CUSTOMER_ALGORITHM_AES256 => {
use md5::{Digest, Md5};
let key_b64 = req
.headers()
.get(key_header)
.ok_or_bad_request(format!("Missing {} header", key_header))?;
let key_bytes: [u8; 32] = BASE64_STANDARD
.decode(&key_b64)
.ok_or_bad_request(format!("Invalid {} header", key_header))?
.try_into()
.ok()
.ok_or_bad_request(format!("Invalid {} header", key_header))?;
let md5_b64 = req
.headers()
.get(md5_header)
.ok_or_bad_request(format!("Missing {} header", md5_header))?;
let md5_bytes = BASE64_STANDARD
.decode(&md5_b64)
.ok_or_bad_request(format!("Invalid {} header", md5_header))?;
let mut hasher = Md5::new();
hasher.update(&key_bytes[..]);
if hasher.finalize().as_slice() != md5_bytes.as_slice() {
return Err(Error::bad_request(
"Encryption key MD5 checksum does not match",
));
}
Ok(Some(key_bytes.into()))
}
Some(alg) => Err(Error::InvalidEncryptionAlgorithm(
alg.to_str().unwrap_or("??").to_string(),
)),
None => Ok(None),
}
}
// ---- encrypt & decrypt streams ----
/// Stream adapter that decrypts the output of `EncryptionParams::encrypt_block`.
///
/// The input is expected to be a nonce of `StreamNonceSize` bytes followed by
/// encrypted chunks of `STREAM_ENC_CYPER_CHUNK_SIZE` bytes each (the last one
/// may be shorter). One decrypted chunk is yielded per item.
#[pin_project::pin_project]
struct DecryptStream {
    /// Source of encrypted bytes.
    #[pin]
    stream: ByteStream,
    /// Set once `stream` has returned `None`. The `Stream` contract does not
    /// define what happens when a finished stream is polled again, so the
    /// inner stream is never polled after this becomes `true`.
    done_reading: bool,
    /// Encrypted bytes received but not yet consumed.
    buf: BytesBuf,
    key: Key<Aes256Gcm>,
    /// Created lazily, once the nonce has been read from the stream.
    cipher: Option<DecryptorLE31<Aes256Gcm>>,
}

impl DecryptStream {
    /// Creates a decrypting adapter over `stream` using `key`.
    fn new(stream: ByteStream, key: Key<Aes256Gcm>) -> Self {
        Self {
            stream,
            done_reading: false,
            buf: BytesBuf::new(),
            key,
            cipher: None,
        }
    }
}

impl Stream for DecryptStream {
    type Item = Result<Bytes, std::io::Error>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        let mut this = self.project();

        // Phase 1: accumulate enough bytes to read the nonce and set up the cipher.
        while this.cipher.is_none() {
            let nonce_size = StreamNonceSize::to_usize();
            if this.buf.len() >= nonce_size {
                let nonce = this
                    .buf
                    .take_exact(nonce_size)
                    .expect("buffer holds at least nonce_size bytes");
                let nonce = Nonce::from_slice(nonce.as_ref());
                *this.cipher = Some(DecryptorLE31::new(&this.key, nonce));
                break;
            }

            if *this.done_reading {
                // The EOF error below was already reported on a previous poll.
                return Poll::Ready(None);
            }

            match futures::ready!(this.stream.as_mut().poll_next(cx)) {
                Some(Ok(bytes)) => {
                    this.buf.extend(bytes);
                }
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
                None => {
                    *this.done_reading = true;
                    return Poll::Ready(Some(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "Decrypt: unexpected EOF, could not read nonce",
                    ))));
                }
            }
        }

        // Phase 2: fill the buffer up to one full encrypted chunk, or until
        // the inner stream ends.
        while !*this.done_reading && this.buf.len() < STREAM_ENC_CYPER_CHUNK_SIZE {
            match futures::ready!(this.stream.as_mut().poll_next(cx)) {
                Some(Ok(bytes)) => {
                    this.buf.extend(bytes);
                }
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
                None => {
                    *this.done_reading = true;
                }
            }
        }

        if this.buf.is_empty() {
            return Poll::Ready(None);
        }

        // Either a full chunk, or the shorter final chunk once the input is done.
        let chunk = this.buf.take_max(STREAM_ENC_CYPER_CHUNK_SIZE);
        // TODO: use decrypt_last for last chunk
        let res = this
            .cipher
            .as_mut()
            .expect("cipher is initialized in phase 1")
            .decrypt_next(chunk.as_ref());
        match res {
            Ok(bytes) => Poll::Ready(Some(Ok(bytes.into()))),
            Err(_) => Poll::Ready(Some(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Decryption failed",
            )))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::prelude::*;
    use garage_net::stream::read_stream_to_end;

    /// Deterministic input made of many buffers of varying length
    /// (`(i * 37) % 1024` bytes, each filled with the byte `i % 256`).
    fn stream() -> ByteStream {
        let pieces = futures::stream::iter(16usize..1024).map(|i| {
            let fill = (i % 256) as u8;
            let len = (i * 37) % 1024;
            Ok(Bytes::from(vec![fill; len]))
        });
        Box::pin(pieces)
    }

    /// Encrypts a block with a random key, decrypts it through the
    /// streaming path, and checks that the original content comes back.
    async fn test_block_enc(compression_level: Option<i32>) {
        let client_key = Aes256Gcm::generate_key(&mut OsRng);
        let params = EncryptionParams::SseC {
            client_key,
            compression_level,
        };

        let original = read_stream_to_end(stream()).await.unwrap().into_bytes();

        let sealed = params.encrypt_block(original.clone()).unwrap();

        let sealed_stream: ByteStream = Box::pin(futures::stream::once(async { Ok(sealed) }));
        let recovered = read_stream_to_end(params.decrypt_block_stream(sealed_stream))
            .await
            .unwrap()
            .into_bytes();

        assert_eq!(original, recovered);
        // Make sure the input really spans many encryption chunks.
        assert!(recovered.len() > 128000);
    }

    #[tokio::test]
    async fn test_encrypt_block() {
        test_block_enc(None).await
    }

    #[tokio::test]
    async fn test_encrypt_block_compressed() {
        test_block_enc(Some(1)).await
    }
}