diff --git a/src/doh.rs b/src/doh.rs index 7325688..917e039 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -49,16 +49,25 @@ pub async fn doh_post(State(state): State, req: Request) } fn is_doh_host(host: Option<&str>, tld: &str) -> bool { - match host { - Some(h) if h == tld => true, - Some(h) => { - h.len() == 2 * tld.len() + 1 - && h.starts_with(tld) - && h.as_bytes().get(tld.len()) == Some(&b'.') - && h.ends_with(tld) - } - None => false, - } + let h = match host { + Some(h) => h, + None => return false, + }; + is_doh_name(h, tld) + || h.rsplit_once(':').is_some_and(|(base, port)| { + port.bytes().all(|b| b.is_ascii_digit()) && is_doh_name(base, tld) + }) +} + +fn is_doh_name(h: &str, tld: &str) -> bool { + h == tld + || (h.len() == 2 * tld.len() + 1 + && h.starts_with(tld) + && h.as_bytes().get(tld.len()) == Some(&b'.') + && h.ends_with(tld)) + || h == "127.0.0.1" + || h == "::1" + || h == "localhost" } async fn resolve_doh( @@ -148,6 +157,10 @@ mod tests { fn is_doh_host_matches_tld() { assert!(is_doh_host(Some("numa"), "numa")); assert!(is_doh_host(Some("numa.numa"), "numa")); + assert!(is_doh_host(Some("127.0.0.1"), "numa")); + assert!(is_doh_host(Some("127.0.0.1:443"), "numa")); + assert!(is_doh_host(Some("::1"), "numa")); + assert!(is_doh_host(Some("localhost"), "numa")); assert!(!is_doh_host(Some("foo.numa"), "numa")); assert!(!is_doh_host(None, "numa")); } diff --git a/src/tls.rs b/src/tls.rs index e9e2f59..2443f4f 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -186,6 +186,20 @@ fn generate_service_cert( } } + // Loopback IP SANs so browsers can reach DoH at https://127.0.0.1/dns-query + sans.push(SanType::IpAddress(std::net::IpAddr::V4( + std::net::Ipv4Addr::LOCALHOST, + ))); + sans.push(SanType::IpAddress(std::net::IpAddr::V6( + std::net::Ipv6Addr::LOCALHOST, + ))); + + // Bare TLD (e.g. "numa") for DoH via https://numa/dns-query + match tld.to_string().try_into() { + Ok(ia5) => sans.push(SanType::DnsName(ia5)), + Err(e) => warn!("invalid SAN {}: {}", tld, e), + } + if sans.is_empty() { return Err("no valid service names for TLS cert".into()); }