aboutsummaryrefslogtreecommitdiff
path: root/src/tls_util.rs
blob: c80dcf8517f585b0e6e1eaf1f9816d9384c20aeb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use core::future::Future;
use core::task::{Context, Poll};
use std::convert::TryFrom;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures_util::future::*;
use hyper::client::connect::Connection;
use hyper::client::HttpConnector;
use hyper::service::Service;
use hyper::Uri;
use hyper_rustls::MaybeHttpsStream;
use rustls::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;

/// Connector that layers TLS on top of an inner connector `T`, always using the
/// same, preconfigured server name for the TLS session instead of deriving it
/// from the destination URI.
#[derive(Clone)]
pub struct HttpsConnectorFixedDnsname<T> {
	/// Inner connector that establishes the underlying (non-TLS) connection.
	http: T,
	/// rustls client configuration, shared between all connections made by this connector.
	tls_config: Arc<rustls::ClientConfig>,
	/// Name converted to a `rustls::ServerName` and handed to the TLS connector
	/// for every connection, whatever the destination URI is.
	fixed_dnsname: &'static str,
}

/// Boxed, thread-safe dynamic error; used as the `Service::Error` type of the connector.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

impl HttpsConnectorFixedDnsname<HttpConnector> {
	/// Builds a connector from a rustls client configuration, the server name to
	/// use for every TLS session, and the hyper `HttpConnector` that opens the
	/// underlying connections.
	///
	/// Both `tls_config` and `http` are adjusted before being stored:
	/// - the ALPN protocol list of `tls_config` is overwritten so that only
	///   `http/1.1` is offered;
	/// - `enforce_http` is switched off on `http` (see hyper's `HttpConnector`
	///   documentation), which is needed to hand it `https` URIs.
	///
	/// `fixed_dnsname` is stored as-is; it is not validated here.
	pub fn new(
		mut tls_config: rustls::ClientConfig,
		fixed_dnsname: &'static str,
		mut http: HttpConnector,
	) -> Self {
		// Replace whatever ALPN list the caller configured.
		tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
		http.enforce_http(false);

		// The configuration is shared by every connection, hence the `Arc`.
		let tls_config = Arc::new(tls_config);
		Self {
			http,
			tls_config,
			fixed_dnsname,
		}
	}
}

impl<T> Service<Uri> for HttpsConnectorFixedDnsname<T>
where
	T: Service<Uri>,
	T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
	T::Future: Send + 'static,
	T::Error: Into<BoxError>,
{
	type Response = MaybeHttpsStream<T::Response>;
	type Error = BoxError;
	#[allow(clippy::type_complexity)]
	type Future =
		Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>;

	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
		match self.http.poll_ready(cx) {
			Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
			Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
			Poll::Pending => Poll::Pending,
		}
	}

	fn call(&mut self, dst: Uri) -> Self::Future {
		let is_https = dst.scheme_str() == Some("https");
		assert!(is_https);

		let cfg = self.tls_config.clone();
		let connecting_future = self.http.call(dst);
		let dnsname = ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
		let f = async move {
			let tcp = connecting_future.await.map_err(Into::into)?;
			let connector = TlsConnector::from(cfg);
			let tls = connector
				.connect(dnsname, tcp)
				.await
				.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
			Ok(MaybeHttpsStream::Https(tls))
		};
		f.boxed()
	}
}