diff --git a/src/dot.rs b/src/dot.rs index 487c25f..0a917dd 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -275,8 +275,9 @@ mod tests { use crate::question::QueryType; use crate::record::DnsRecord; - /// Generate a self-signed cert + key in memory, return (ServerConfig, ClientConfig). - fn test_tls_configs() -> (Arc, Arc) { + /// Generate a self-signed DoT server config and return its leaf cert DER + /// so callers can build matching client configs with arbitrary ALPN. + fn test_tls_configs() -> (Arc, CertificateDer<'static>) { let _ = rustls::crypto::ring::default_provider().install_default(); // Mirror production self_signed_tls SAN shape: *.numa wildcard plus @@ -301,22 +302,31 @@ mod tests { .unwrap(); server_config.alpn_protocols = dot_alpn(); - let mut root_store = rustls::RootCertStore::empty(); - root_store.add(cert_der).unwrap(); - let mut client_config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - client_config.alpn_protocols = dot_alpn(); - - (Arc::new(server_config), Arc::new(client_config)) + (Arc::new(server_config), cert_der) } - /// Spin up a DoT listener with a test TLS config. Returns (addr, client_config). + /// Build a TLS client config that trusts `cert_der` and advertises the + /// given ALPN protocols. Used by tests to vary ALPN per test case. + fn dot_client( + cert_der: &CertificateDer<'static>, + alpn: Vec>, + ) -> Arc { + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(cert_der.clone()).unwrap(); + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + config.alpn_protocols = alpn; + Arc::new(config) + } + + /// Spin up a DoT listener with a test TLS config. Returns the bind addr + /// and the leaf cert DER so callers can build clients with arbitrary ALPN. /// The upstream is pointed at a bound-but-unresponsive UDP socket we own, so /// any query that escapes to the upstream path times out deterministically /// (SERVFAIL) regardless of what the host has running on port 53. - async fn spawn_dot_server() -> (SocketAddr, Arc) { - let (server_tls, client_tls) = test_tls_configs(); + async fn spawn_dot_server() -> (SocketAddr, CertificateDer<'static>) { + let (server_tls, cert_der) = test_tls_configs(); let socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap(); // Bind an unresponsive upstream and leak it so it lives for the test duration. @@ -375,7 +385,7 @@ mod tests { tokio::spawn(accept_loop(listener, acceptor, ctx)); - (addr, client_tls) + (addr, cert_der) } /// Open a TLS connection to the DoT server and return the stream. @@ -419,7 +429,8 @@ mod tests { #[tokio::test] async fn dot_resolves_local_zone() { - let (addr, client_config) = spawn_dot_server().await; + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, dot_alpn()); let mut stream = dot_connect(addr, &client_config).await; let query = DnsPacket::query(0x1234, "dot-test.example", QueryType::A); @@ -441,7 +452,8 @@ mod tests { #[tokio::test] async fn dot_multiple_queries_on_persistent_connection() { - let (addr, client_config) = spawn_dot_server().await; + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, dot_alpn()); let mut stream = dot_connect(addr, &client_config).await; for i in 0..3u16 { @@ -455,7 +467,8 @@ mod tests { #[tokio::test] async fn dot_nxdomain_for_unknown() { - let (addr, client_config) = spawn_dot_server().await; + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, dot_alpn()); let mut stream = dot_connect(addr, &client_config).await; let query = DnsPacket::query(0xBEEF, "nonexistent.test", QueryType::A); @@ -472,15 +485,36 @@ mod tests { #[tokio::test] async fn dot_negotiates_alpn() { - let (addr, client_config) = spawn_dot_server().await; + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, dot_alpn()); let stream = dot_connect(addr, &client_config).await; let (_io, conn) = stream.get_ref(); assert_eq!(conn.alpn_protocol(), Some(&b"dot"[..])); } + #[tokio::test] + async fn dot_rejects_non_dot_alpn() { + // Cross-protocol confusion defense: a client that only offers "h2" + // (e.g. an HTTP/2 client mistakenly hitting :853) must not complete + // a TLS handshake with the DoT server. Verifies the rustls server + // sends `no_application_protocol` rather than silently negotiating. + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, vec![b"h2".to_vec()]); + let connector = tokio_rustls::TlsConnector::from(client_config); + let tcp = tokio::net::TcpStream::connect(addr).await.unwrap(); + let result = connector + .connect(ServerName::try_from("numa.numa").unwrap(), tcp) + .await; + assert!( + result.is_err(), + "DoT server must reject ALPN that doesn't include \"dot\"" + ); + } + #[tokio::test] async fn dot_concurrent_connections() { - let (addr, client_config) = spawn_dot_server().await; + let (addr, cert_der) = spawn_dot_server().await; + let client_config = dot_client(&cert_der, dot_alpn()); let mut handles = Vec::new(); for i in 0..5u16 {