aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-04-12 15:51:19 +0200
committerAlex Auvolat <alex@adnab.me>2020-04-12 15:51:19 +0200
commitd1e8f78b2cd28f4514ad6f7d54aae6aaa4ef3f15 (patch)
tree74ac969472fad3baa8f5a3cdac6bfc6b3846d2e3 /src
parent5967c5a5af430855fbd73f380041d63bd82f5ce1 (diff)
downloadgarage-d1e8f78b2cd28f4514ad6f7d54aae6aaa4ef3f15.tar.gz
garage-d1e8f78b2cd28f4514ad6f7d54aae6aaa4ef3f15.zip
Trying to do TLS
Diffstat (limited to 'src')
-rw-r--r--src/block_ref_table.rs4
-rw-r--r--src/error.rs6
-rw-r--r--src/main.rs3
-rw-r--r--src/membership.rs4
-rw-r--r--src/object_table.rs2
-rw-r--r--src/rpc_client.rs53
-rw-r--r--src/rpc_server.rs98
-rw-r--r--src/server.rs9
-rw-r--r--src/tls_util.rs46
-rw-r--r--src/version_table.rs6
10 files changed, 189 insertions, 42 deletions
diff --git a/src/block_ref_table.rs b/src/block_ref_table.rs
index 3e5fb0a1..0511ea25 100644
--- a/src/block_ref_table.rs
+++ b/src/block_ref_table.rs
@@ -2,10 +2,10 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
-use crate::data::*;
-use crate::table::*;
use crate::background::*;
use crate::block::*;
+use crate::data::*;
+use crate::table::*;
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
pub struct BlockRef {
diff --git a/src/error.rs b/src/error.rs
index 661621c9..c9653f5a 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -16,6 +16,12 @@ pub enum Error {
#[error(display = "Invalid HTTP header value: {}", _0)]
HTTPHeader(#[error(source)] http::header::ToStrError),
+ #[error(display = "TLS error: {}", _0)]
+ TLS(#[error(source)] rustls::TLSError),
+
+ #[error(display = "PKI error: {}", _0)]
+ PKI(#[error(source)] webpki::Error),
+
#[error(display = "Sled error: {}", _0)]
Sled(#[error(source)] sled::Error),
diff --git a/src/main.rs b/src/main.rs
index 533afcc7..619f3422 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -16,6 +16,7 @@ mod http_util;
mod rpc_client;
mod rpc_server;
mod server;
+mod tls_util;
use std::collections::HashSet;
use std::net::SocketAddr;
@@ -76,7 +77,7 @@ pub struct ConfigureOpt {
async fn main() {
let opt = Opt::from_args();
- let rpc_cli = RpcClient::new();
+ let rpc_cli = RpcClient::new(&None).expect("Could not create RPC client");
let resp = match opt.cmd {
Command::Server(server_opt) => server::run_server(server_opt.config_file).await,
diff --git a/src/membership.rs b/src/membership.rs
index 22c13f64..89550b67 100644
--- a/src/membership.rs
+++ b/src/membership.rs
@@ -226,10 +226,12 @@ impl System {
ring.rebuild_ring();
let (update_ring, ring) = watch::channel(Arc::new(ring));
+ let rpc_client = RpcClient::new(&config.rpc_tls).expect("Could not create RPC client");
+
System {
config,
id,
- rpc_client: RpcClient::new(),
+ rpc_client,
status,
ring,
update_lock: Mutex::new((update_status, update_ring)),
diff --git a/src/object_table.rs b/src/object_table.rs
index 8ce49565..a3a03372 100644
--- a/src/object_table.rs
+++ b/src/object_table.rs
@@ -2,9 +2,9 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
+use crate::background::BackgroundRunner;
use crate::data::*;
use crate::table::*;
-use crate::background::BackgroundRunner;
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
pub struct Object {
diff --git a/src/rpc_client.rs b/src/rpc_client.rs
index 8d8b724b..247f114e 100644
--- a/src/rpc_client.rs
+++ b/src/rpc_client.rs
@@ -6,13 +6,16 @@ use bytes::IntoBuf;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use futures_util::future::FutureExt;
-use hyper::client::Client;
+use hyper::client::{Client, HttpConnector};
use hyper::{Body, Method, Request, StatusCode};
+use hyper_rustls::HttpsConnector;
use crate::data::*;
use crate::error::Error;
use crate::membership::System;
use crate::proto::Message;
+use crate::server::*;
+use crate::tls_util;
pub async fn rpc_call_many(
sys: Arc<System>,
@@ -88,14 +91,34 @@ pub async fn rpc_call(
sys.rpc_client.call(&addr, msg, timeout).await
}
-pub struct RpcClient {
- pub client: Client<hyper::client::HttpConnector, hyper::Body>,
+pub enum RpcClient {
+ HTTP(Client<HttpConnector, hyper::Body>),
+ HTTPS(Client<HttpsConnector<HttpConnector>, hyper::Body>),
}
impl RpcClient {
- pub fn new() -> Self {
- RpcClient {
- client: Client::new(),
+ pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> {
+ if let Some(cf) = tls_config {
+ let ca_certs = tls_util::load_certs(&cf.ca_cert)?;
+ let node_certs = tls_util::load_certs(&cf.node_cert)?;
+ let node_key = tls_util::load_private_key(&cf.node_key)?;
+
+ let mut config = rustls::ClientConfig::new();
+
+ for crt in ca_certs.iter() {
+ config.root_store.add(crt)?;
+ }
+
+ config.set_single_client_cert([&ca_certs[..], &node_certs[..]].concat(), node_key)?;
+
+ let mut http_connector = HttpConnector::new();
+ http_connector.enforce_http(false);
+ let connector =
+ HttpsConnector::<HttpConnector>::from((http_connector, Arc::new(config)));
+
+ Ok(RpcClient::HTTPS(Client::builder().build(connector)))
+ } else {
+ Ok(RpcClient::HTTP(Client::new()))
}
}
@@ -105,14 +128,26 @@ impl RpcClient {
msg: &Message,
timeout: Duration,
) -> Result<Message, Error> {
- let uri = format!("http://{}/rpc", to_addr);
+ let uri = match self {
+ RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr),
+ RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr),
+ };
+
let req = Request::builder()
.method(Method::POST)
.uri(uri)
.body(Body::from(rmp_to_vec_all_named(msg)?))?;
- let resp_fut = self.client.request(req).fuse();
- let resp = tokio::time::timeout(timeout, resp_fut).await??;
+ let resp_fut = match self {
+ RpcClient::HTTP(client) => client.request(req).fuse(),
+ RpcClient::HTTPS(client) => client.request(req).fuse(),
+ };
+ let resp = tokio::time::timeout(timeout, resp_fut)
+ .await?
+ .map_err(|e| {
+ eprintln!("RPC client error: {}", e);
+ e
+ })?;
if resp.status() == StatusCode::OK {
let body = hyper::body::to_bytes(resp.into_body()).await?;
diff --git a/src/rpc_server.rs b/src/rpc_server.rs
index f54b5099..f42d54ac 100644
--- a/src/rpc_server.rs
+++ b/src/rpc_server.rs
@@ -4,15 +4,20 @@ use std::sync::Arc;
use bytes::IntoBuf;
use futures::future::Future;
use futures_util::future::*;
+use futures_util::stream::*;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use serde::Serialize;
+use tokio::net::{TcpListener, TcpStream};
+use tokio_rustls::server::TlsStream;
+use tokio_rustls::TlsAcceptor;
use crate::data::rmp_to_vec_all_named;
use crate::error::Error;
use crate::proto::Message;
use crate::server::Garage;
+use crate::tls_util;
fn debug_serialize<T: Serialize>(x: T) -> String {
match serde_json::to_string(&x) {
@@ -71,9 +76,7 @@ async fn handler(
// and the request handler simply sits there waiting for the task to finish.
// (if it's cancelled, that's not an issue)
// (TODO FIXME except if garage happens to shut down at that point)
- let write_fut = async move {
- garage.block_manager.write_block(&m.hash, &m.data).await
- };
+ let write_fut = async move { garage.block_manager.write_block(&m.hash, &m.data).await };
tokio::spawn(write_fut).await?
}
Message::GetBlock(h) => garage.block_manager.read_block(&h).await,
@@ -105,25 +108,82 @@ pub async fn run_rpc_server(
) -> Result<(), Error> {
let bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], garage.system.config.rpc_port).into();
- let service = make_service_fn(|conn: &AddrStream| {
- let client_addr = conn.remote_addr();
- let garage = garage.clone();
- async move {
- Ok::<_, Error>(service_fn(move |req: Request<Body>| {
- let garage = garage.clone();
- handler(garage, req, client_addr).map_err(|e| {
- eprintln!("RPC handler error: {}", e);
- e
- })
- }))
+ if let Some(tls_config) = &garage.system.config.rpc_tls {
+ let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?;
+ let node_certs = tls_util::load_certs(&tls_config.node_cert)?;
+ let node_key = tls_util::load_private_key(&tls_config.node_key)?;
+
+ let mut ca_store = rustls::RootCertStore::empty();
+ for crt in ca_certs.iter() {
+ ca_store.add(crt)?;
}
- });
- let server = Server::bind(&bind_addr).serve(service);
+ let mut config =
+ rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store));
+ config.set_single_cert([&ca_certs[..], &node_certs[..]].concat(), node_key)?;
+ let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config)));
+
+ let mut listener = TcpListener::bind(&bind_addr).await?;
+ let incoming = listener.incoming().filter_map(|socket| async {
+ match socket {
+ Ok(stream) => match tls_acceptor.clone().accept(stream).await {
+ Ok(x) => Some(Ok::<_, hyper::Error>(x)),
+ Err(e) => {
+ eprintln!("RPC server TLS error: {}", e);
+ None
+ }
+ },
+ Err(_) => None,
+ }
+ });
+ let incoming = hyper::server::accept::from_stream(incoming);
+
+ let service = make_service_fn(|conn: &TlsStream<TcpStream>| {
+ let client_addr = conn
+ .get_ref()
+ .0
+ .peer_addr()
+ .unwrap_or(([0, 0, 0, 0], 0).into());
+ let garage = garage.clone();
+ async move {
+ Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+ let garage = garage.clone();
+ handler(garage, req, client_addr).map_err(|e| {
+ eprintln!("RPC handler error: {}", e);
+ e
+ })
+ }))
+ }
+ });
+
+ let server = Server::builder(incoming).serve(service);
+
+ let graceful = server.with_graceful_shutdown(shutdown_signal);
+ println!("RPC server listening on http://{}", bind_addr);
+
+ graceful.await?;
+ } else {
+ let service = make_service_fn(|conn: &AddrStream| {
+ let client_addr = conn.remote_addr();
+ let garage = garage.clone();
+ async move {
+ Ok::<_, Error>(service_fn(move |req: Request<Body>| {
+ let garage = garage.clone();
+ handler(garage, req, client_addr).map_err(|e| {
+ eprintln!("RPC handler error: {}", e);
+ e
+ })
+ }))
+ }
+ });
+
+ let server = Server::bind(&bind_addr).serve(service);
- let graceful = server.with_graceful_shutdown(shutdown_signal);
- println!("RPC server listening on http://{}", bind_addr);
+ let graceful = server.with_graceful_shutdown(shutdown_signal);
+ println!("RPC server listening on http://{}", bind_addr);
+
+ graceful.await?;
+ }
- graceful.await?;
Ok(())
}
diff --git a/src/server.rs b/src/server.rs
index 29a2dbcb..0123eb90 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -36,14 +36,14 @@ pub struct Config {
#[serde(default = "default_replication_factor")]
pub data_replication_factor: usize,
- pub tls: TlsConfig,
+ pub rpc_tls: Option<TlsConfig>,
}
#[derive(Deserialize, Debug)]
pub struct TlsConfig {
- pub ca_cert: Option<String>,
- pub node_cert: Option<String>,
- pub node_key: Option<String>,
+ pub ca_cert: String,
+ pub node_cert: String,
+ pub node_key: String,
}
pub struct Garage {
@@ -115,7 +115,6 @@ impl Garage {
meta_rep_param.clone(),
));
-
let mut garage = Self {
db,
system: system.clone(),
diff --git a/src/tls_util.rs b/src/tls_util.rs
new file mode 100644
index 00000000..a9e16c53
--- /dev/null
+++ b/src/tls_util.rs
@@ -0,0 +1,46 @@
+use std::{fs, io};
+
+use rustls::internal::pemfile;
+
+use crate::error::Error;
+
+pub fn load_certs(filename: &str) -> Result<Vec<rustls::Certificate>, Error> {
+ let certfile = fs::File::open(&filename)?;
+ let mut reader = io::BufReader::new(certfile);
+
+ let certs = pemfile::certs(&mut reader).map_err(|_| {
+ Error::Message(format!(
+ "Could not deecode certificates from file: {}",
+ filename
+ ))
+ })?;
+
+ if certs.is_empty() {
+ return Err(Error::Message(format!(
+ "Invalid certificate file: {}",
+ filename
+ )));
+ }
+ Ok(certs)
+}
+
+pub fn load_private_key(filename: &str) -> Result<rustls::PrivateKey, Error> {
+ let keyfile = fs::File::open(&filename)?;
+ let mut reader = io::BufReader::new(keyfile);
+
+ let keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| {
+ Error::Message(format!(
+ "Could not decode private key from file: {}",
+ filename
+ ))
+ })?;
+
+ if keys.len() != 1 {
+ return Err(Error::Message(format!(
+ "Invalid private key file: {} ({} private keys)",
+ filename,
+ keys.len()
+ )));
+ }
+ Ok(keys[0].clone())
+}
diff --git a/src/version_table.rs b/src/version_table.rs
index cb70c645..106527b1 100644
--- a/src/version_table.rs
+++ b/src/version_table.rs
@@ -2,9 +2,9 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
+use crate::background::BackgroundRunner;
use crate::data::*;
use crate::table::*;
-use crate::background::BackgroundRunner;
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
pub struct Version {
@@ -78,9 +78,7 @@ impl TableFormat for VersionTable {
deleted: true,
})
.collect::<Vec<_>>();
- block_ref_table
- .insert_many(&deleted_block_refs[..])
- .await?;
+ block_ref_table.insert_many(&deleted_block_refs[..]).await?;
}
}
Ok(())