diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/block_ref_table.rs | 4 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/main.rs | 3 | ||||
-rw-r--r-- | src/membership.rs | 4 | ||||
-rw-r--r-- | src/object_table.rs | 2 | ||||
-rw-r--r-- | src/rpc_client.rs | 53 | ||||
-rw-r--r-- | src/rpc_server.rs | 98 | ||||
-rw-r--r-- | src/server.rs | 9 | ||||
-rw-r--r-- | src/tls_util.rs | 46 | ||||
-rw-r--r-- | src/version_table.rs | 6 |
10 files changed, 189 insertions, 42 deletions
diff --git a/src/block_ref_table.rs b/src/block_ref_table.rs index 3e5fb0a1..0511ea25 100644 --- a/src/block_ref_table.rs +++ b/src/block_ref_table.rs @@ -2,10 +2,10 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use crate::data::*; -use crate::table::*; use crate::background::*; use crate::block::*; +use crate::data::*; +use crate::table::*; #[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] pub struct BlockRef { diff --git a/src/error.rs b/src/error.rs index 661621c9..c9653f5a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,12 @@ pub enum Error { #[error(display = "Invalid HTTP header value: {}", _0)] HTTPHeader(#[error(source)] http::header::ToStrError), + #[error(display = "TLS error: {}", _0)] + TLS(#[error(source)] rustls::TLSError), + + #[error(display = "PKI error: {}", _0)] + PKI(#[error(source)] webpki::Error), + #[error(display = "Sled error: {}", _0)] Sled(#[error(source)] sled::Error), diff --git a/src/main.rs b/src/main.rs index 533afcc7..619f3422 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,7 @@ mod http_util; mod rpc_client; mod rpc_server; mod server; +mod tls_util; use std::collections::HashSet; use std::net::SocketAddr; @@ -76,7 +77,7 @@ pub struct ConfigureOpt { async fn main() { let opt = Opt::from_args(); - let rpc_cli = RpcClient::new(); + let rpc_cli = RpcClient::new(&None).expect("Could not create RPC client"); let resp = match opt.cmd { Command::Server(server_opt) => server::run_server(server_opt.config_file).await, diff --git a/src/membership.rs b/src/membership.rs index 22c13f64..89550b67 100644 --- a/src/membership.rs +++ b/src/membership.rs @@ -226,10 +226,12 @@ impl System { ring.rebuild_ring(); let (update_ring, ring) = watch::channel(Arc::new(ring)); + let rpc_client = RpcClient::new(&config.rpc_tls).expect("Could not create RPC client"); + System { config, id, - rpc_client: RpcClient::new(), + rpc_client, status, ring, update_lock: Mutex::new((update_status, update_ring)), diff --git a/src/object_table.rs b/src/object_table.rs index 8ce49565..a3a03372 100644 --- a/src/object_table.rs +++ b/src/object_table.rs @@ -2,9 +2,9 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use crate::background::BackgroundRunner; use crate::data::*; use crate::table::*; -use crate::background::BackgroundRunner; #[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] pub struct Object { diff --git a/src/rpc_client.rs b/src/rpc_client.rs index 8d8b724b..247f114e 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -6,13 +6,16 @@ use bytes::IntoBuf; use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; use futures_util::future::FutureExt; -use hyper::client::Client; +use hyper::client::{Client, HttpConnector}; use hyper::{Body, Method, Request, StatusCode}; +use hyper_rustls::HttpsConnector; use crate::data::*; use crate::error::Error; use crate::membership::System; use crate::proto::Message; +use crate::server::*; +use crate::tls_util; pub async fn rpc_call_many( sys: Arc<System>, @@ -88,14 +91,34 @@ pub async fn rpc_call( sys.rpc_client.call(&addr, msg, timeout).await } -pub struct RpcClient { - pub client: Client<hyper::client::HttpConnector, hyper::Body>, +pub enum RpcClient { + HTTP(Client<HttpConnector, hyper::Body>), + HTTPS(Client<HttpsConnector<HttpConnector>, hyper::Body>), } impl RpcClient { - pub fn new() -> Self { - RpcClient { - client: Client::new(), + pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> { + if let Some(cf) = tls_config { + let ca_certs = tls_util::load_certs(&cf.ca_cert)?; + let node_certs = tls_util::load_certs(&cf.node_cert)?; + let node_key = tls_util::load_private_key(&cf.node_key)?; + + let mut config = rustls::ClientConfig::new(); + + for crt in ca_certs.iter() { + config.root_store.add(crt)?; + } + + config.set_single_client_cert([&ca_certs[..], &node_certs[..]].concat(), node_key)?; + + let mut http_connector = HttpConnector::new(); + http_connector.enforce_http(false); + let connector = + HttpsConnector::<HttpConnector>::from((http_connector, Arc::new(config))); + + Ok(RpcClient::HTTPS(Client::builder().build(connector))) + } else { + Ok(RpcClient::HTTP(Client::new())) } } @@ -105,14 +128,26 @@ impl RpcClient { msg: &Message, timeout: Duration, ) -> Result<Message, Error> { - let uri = format!("http://{}/rpc", to_addr); + let uri = match self { + RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr), + RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr), + }; + let req = Request::builder() .method(Method::POST) .uri(uri) .body(Body::from(rmp_to_vec_all_named(msg)?))?; - let resp_fut = self.client.request(req).fuse(); - let resp = tokio::time::timeout(timeout, resp_fut).await??; + let resp_fut = match self { + RpcClient::HTTP(client) => client.request(req).fuse(), + RpcClient::HTTPS(client) => client.request(req).fuse(), + }; + let resp = tokio::time::timeout(timeout, resp_fut) + .await? + .map_err(|e| { + eprintln!("RPC client error: {}", e); + e + })?; if resp.status() == StatusCode::OK { let body = hyper::body::to_bytes(resp.into_body()).await?; diff --git a/src/rpc_server.rs b/src/rpc_server.rs index f54b5099..f42d54ac 100644 --- a/src/rpc_server.rs +++ b/src/rpc_server.rs @@ -4,15 +4,20 @@ use std::sync::Arc; use bytes::IntoBuf; use futures::future::Future; use futures_util::future::*; +use futures_util::stream::*; use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; use serde::Serialize; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::server::TlsStream; +use tokio_rustls::TlsAcceptor; use crate::data::rmp_to_vec_all_named; use crate::error::Error; use crate::proto::Message; use crate::server::Garage; +use crate::tls_util; fn debug_serialize<T: Serialize>(x: T) -> String { match serde_json::to_string(&x) { @@ -71,9 +76,7 @@ async fn handler( // and the request handler simply sits there waiting for the task to finish. // (if it's cancelled, that's not an issue) // (TODO FIXME except if garage happens to shut down at that point) - let write_fut = async move { - garage.block_manager.write_block(&m.hash, &m.data).await - }; + let write_fut = async move { garage.block_manager.write_block(&m.hash, &m.data).await }; tokio::spawn(write_fut).await? } Message::GetBlock(h) => garage.block_manager.read_block(&h).await, @@ -105,25 +108,82 @@ pub async fn run_rpc_server( ) -> Result<(), Error> { let bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], garage.system.config.rpc_port).into(); - let service = make_service_fn(|conn: &AddrStream| { - let client_addr = conn.remote_addr(); - let garage = garage.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request<Body>| { - let garage = garage.clone(); - handler(garage, req, client_addr).map_err(|e| { - eprintln!("RPC handler error: {}", e); - e - }) - })) + if let Some(tls_config) = &garage.system.config.rpc_tls { + let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; + let node_certs = tls_util::load_certs(&tls_config.node_cert)?; + let node_key = tls_util::load_private_key(&tls_config.node_key)?; + + let mut ca_store = rustls::RootCertStore::empty(); + for crt in ca_certs.iter() { + ca_store.add(crt)?; } - }); - let server = Server::bind(&bind_addr).serve(service); + let mut config = + rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); + config.set_single_cert([&ca_certs[..], &node_certs[..]].concat(), node_key)?; + let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); + + let mut listener = TcpListener::bind(&bind_addr).await?; + let incoming = listener.incoming().filter_map(|socket| async { + match socket { + Ok(stream) => match tls_acceptor.clone().accept(stream).await { + Ok(x) => Some(Ok::<_, hyper::Error>(x)), + Err(e) => { + eprintln!("RPC server TLS error: {}", e); + None + } + }, + Err(_) => None, + } + }); + let incoming = hyper::server::accept::from_stream(incoming); + + let service = make_service_fn(|conn: &TlsStream<TcpStream>| { + let client_addr = conn + .get_ref() + .0 + .peer_addr() + .unwrap_or(([0, 0, 0, 0], 0).into()); + let garage = garage.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request<Body>| { + let garage = garage.clone(); + handler(garage, req, client_addr).map_err(|e| { + eprintln!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::builder(incoming).serve(service); + + let graceful = server.with_graceful_shutdown(shutdown_signal); + println!("RPC server listening on http://{}", bind_addr); + + graceful.await?; + } else { + let service = make_service_fn(|conn: &AddrStream| { + let client_addr = conn.remote_addr(); + let garage = garage.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request<Body>| { + let garage = garage.clone(); + handler(garage, req, client_addr).map_err(|e| { + eprintln!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::bind(&bind_addr).serve(service); - let graceful = server.with_graceful_shutdown(shutdown_signal); - println!("RPC server listening on http://{}", bind_addr); + let graceful = server.with_graceful_shutdown(shutdown_signal); + println!("RPC server listening on http://{}", bind_addr); + + graceful.await?; + } - graceful.await?; Ok(()) } diff --git a/src/server.rs b/src/server.rs index 29a2dbcb..0123eb90 100644 --- a/src/server.rs +++ b/src/server.rs @@ -36,14 +36,14 @@ pub struct Config { #[serde(default = "default_replication_factor")] pub data_replication_factor: usize, - pub tls: TlsConfig, + pub rpc_tls: Option<TlsConfig>, } #[derive(Deserialize, Debug)] pub struct TlsConfig { - pub ca_cert: Option<String>, - pub node_cert: Option<String>, - pub node_key: Option<String>, + pub ca_cert: String, + pub node_cert: String, + pub node_key: String, } pub struct Garage { @@ -115,7 +115,6 @@ impl Garage { meta_rep_param.clone(), )); - let mut garage = Self { db, system: system.clone(), diff --git a/src/tls_util.rs b/src/tls_util.rs new file mode 100644 index 00000000..a9e16c53 --- /dev/null +++ b/src/tls_util.rs @@ -0,0 +1,46 @@ +use std::{fs, io}; + +use rustls::internal::pemfile; + +use crate::error::Error; + +pub fn load_certs(filename: &str) -> Result<Vec<rustls::Certificate>, Error> { + let certfile = fs::File::open(&filename)?; + let mut reader = io::BufReader::new(certfile); + + let certs = pemfile::certs(&mut reader).map_err(|_| { + Error::Message(format!( + "Could not deecode certificates from file: {}", + filename + )) + })?; + + if certs.is_empty() { + return Err(Error::Message(format!( + "Invalid certificate file: {}", + filename + ))); + } + Ok(certs) +} + +pub fn load_private_key(filename: &str) -> Result<rustls::PrivateKey, Error> { + let keyfile = fs::File::open(&filename)?; + let mut reader = io::BufReader::new(keyfile); + + let keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| { + Error::Message(format!( + "Could not decode private key from file: {}", + filename + )) + })?; + + if keys.len() != 1 { + return Err(Error::Message(format!( + "Invalid private key file: {} ({} private keys)", + filename, + keys.len() + ))); + } + Ok(keys[0].clone()) +} diff --git a/src/version_table.rs b/src/version_table.rs index cb70c645..106527b1 100644 --- a/src/version_table.rs +++ b/src/version_table.rs @@ -2,9 +2,9 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use crate::background::BackgroundRunner; use crate::data::*; use crate::table::*; -use crate::background::BackgroundRunner; #[derive(PartialEq, Clone, Debug, Serialize, Deserialize)] pub struct Version { @@ -78,9 +78,7 @@ impl TableFormat for VersionTable { deleted: true, }) .collect::<Vec<_>>(); - block_ref_table - .insert_many(&deleted_block_refs[..]) - .await?; + block_ref_table.insert_many(&deleted_block_refs[..]).await?; } } Ok(()) |