diff --git a/Cargo.lock b/Cargo.lock index 86f96da..c7cd38b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1170,6 +1170,7 @@ dependencies = [ "tokio-rustls", "toml", "tower", + "webpki-roots", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index aa67dd4..c5d5e1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ arc-swap = "1" ring = "0.17" rustls-pemfile = "2.2.0" qrcode = { version = "0.14", default-features = false, features = ["svg"] } +webpki-roots = "1" [dev-dependencies] criterion = { version = "0.8", features = ["html_reports"] } diff --git a/src/forward.rs b/src/forward.rs index 78efcb9..ea2f1e2 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -18,6 +18,11 @@ pub enum Upstream { url: String, client: reqwest::Client, }, + Dot { + addr: SocketAddr, + tls_name: Option, + connector: tokio_rustls::TlsConnector, + }, } impl PartialEq for Upstream { @@ -25,6 +30,7 @@ impl PartialEq for Upstream { match (self, other) { (Self::Udp(a), Self::Udp(b)) => a == b, (Self::Doh { url: a, .. }, Self::Doh { url: b, .. }) => a == b, + (Self::Dot { addr: a, .. }, Self::Dot { addr: b, .. }) => a == b, _ => false, } } @@ -35,6 +41,10 @@ impl fmt::Display for Upstream { match self { Upstream::Udp(addr) => write!(f, "{}", addr), Upstream::Doh { url, .. } => f.write_str(url), + Upstream::Dot { addr, tls_name, .. } => match tls_name { + Some(name) => write!(f, "tls://{}#{}", addr, name), + None => write!(f, "tls://{}", addr), + }, } } } @@ -62,10 +72,36 @@ pub fn parse_upstream(s: &str, default_port: u16) -> Result { client, }); } + // tls://IP:PORT#hostname or tls://IP#hostname (default port 853) + if let Some(rest) = s.strip_prefix("tls://") { + let (addr_part, tls_name) = match rest.find('#') { + Some(i) => (&rest[..i], Some(rest[i + 1..].to_string())), + None => (rest, None), + }; + let addr = parse_upstream_addr(addr_part, 853)?; + let connector = build_dot_connector()?; + return Ok(Upstream::Dot { + addr, + tls_name, + connector, + }); + } let addr = parse_upstream_addr(s, default_port)?; Ok(Upstream::Udp(addr)) } +fn build_dot_connector() -> Result { + let _ = rustls::crypto::ring::default_provider().install_default(); + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + Ok(tokio_rustls::TlsConnector::from(std::sync::Arc::new( + config, + ))) +} + #[derive(Clone)] pub struct UpstreamPool { primary: Vec, @@ -174,6 +210,11 @@ pub async fn forward_query( match upstream { Upstream::Udp(addr) => forward_udp(query, *addr, timeout_duration).await, Upstream::Doh { url, client } => forward_doh(query, url, client, timeout_duration).await, + Upstream::Dot { + addr, + tls_name, + connector, + } => forward_dot(query, *addr, tls_name, connector, timeout_duration).await, } } @@ -236,6 +277,45 @@ pub(crate) async fn forward_tcp( DnsPacket::from_buffer(&mut recv_buffer) } +async fn forward_dot( + query: &DnsPacket, + addr: SocketAddr, + tls_name: &Option, + connector: &tokio_rustls::TlsConnector, + timeout_duration: Duration, +) -> Result { + use rustls::pki_types::ServerName; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + let server_name = match tls_name { + Some(name) => ServerName::try_from(name.clone())?, + None => ServerName::try_from(addr.ip().to_string())?, + }; + + let tcp = timeout(timeout_duration, TcpStream::connect(addr)).await??; + let mut tls = timeout(timeout_duration, connector.connect(server_name, tcp)).await??; + + let mut send_buffer = BytePacketBuffer::new(); + query.write(&mut send_buffer)?; + let wire = send_buffer.filled(); + + let mut outbuf = Vec::with_capacity(2 + wire.len()); + outbuf.extend_from_slice(&(wire.len() as u16).to_be_bytes()); + outbuf.extend_from_slice(wire); + timeout(timeout_duration, tls.write_all(&outbuf)).await??; + + let mut len_buf = [0u8; 2]; + timeout(timeout_duration, tls.read_exact(&mut len_buf)).await??; + let resp_len = u16::from_be_bytes(len_buf) as usize; + + let mut data = vec![0u8; resp_len]; + timeout(timeout_duration, tls.read_exact(&mut data)).await??; + + let mut recv_buffer = BytePacketBuffer::from_bytes(&data); + DnsPacket::from_buffer(&mut recv_buffer) +} + async fn forward_doh( query: &DnsPacket, url: &str,