diff --git a/Cargo.toml b/Cargo.toml index d7f6f9f..6ab0972 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ tower = { version = "0.5", features = ["util"] } http = "1" hickory-resolver = { version = "0.25", features = ["https-ring", "webpki-roots"] } hickory-proto = "0.25" +x509-parser = "0.18" [[bench]] name = "hot_path" diff --git a/src/doh.rs b/src/doh.rs index 7325688..f90b919 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -49,18 +49,43 @@ pub async fn doh_post(State(state): State, req: Request) } fn is_doh_host(host: Option<&str>, tld: &str) -> bool { - match host { - Some(h) if h == tld => true, - Some(h) => { - h.len() == 2 * tld.len() + 1 - && h.starts_with(tld) - && h.as_bytes().get(tld.len()) == Some(&b'.') - && h.ends_with(tld) + let h = match host { + Some(h) => h, + None => return false, + }; + let base = strip_port(h).unwrap_or(h); + is_loopback_host(base) || is_tld_match(base, tld) +} + +fn strip_port(h: &str) -> Option<&str> { + if h.starts_with('[') { + // [::1]:443 → [::1] + let (base, port) = h.rsplit_once("]:")?; + port.bytes() + .all(|b| b.is_ascii_digit()) + .then(|| &h[..base.len() + 1]) + } else { + let (base, port) = h.rsplit_once(':')?; + // Bare IPv6 like "::1" has multiple colons — not a port suffix + if base.contains(':') { + return None; } - None => false, + port.bytes().all(|b| b.is_ascii_digit()).then_some(base) } } +fn is_loopback_host(h: &str) -> bool { + matches!(h, "127.0.0.1" | "::1" | "[::1]" | "localhost") +} + +fn is_tld_match(h: &str, tld: &str) -> bool { + h == tld + || (h.len() == 2 * tld.len() + 1 + && h.starts_with(tld) + && h.as_bytes().get(tld.len()) == Some(&b'.') + && h.ends_with(tld)) +} + async fn resolve_doh( dns_bytes: &[u8], src: SocketAddr, @@ -148,6 +173,13 @@ mod tests { fn is_doh_host_matches_tld() { assert!(is_doh_host(Some("numa"), "numa")); assert!(is_doh_host(Some("numa.numa"), "numa")); + assert!(is_doh_host(Some("127.0.0.1"), "numa")); + assert!(is_doh_host(Some("127.0.0.1:443"), "numa")); + assert!(is_doh_host(Some("::1"), "numa")); + assert!(is_doh_host(Some("[::1]"), "numa")); + assert!(is_doh_host(Some("[::1]:443"), "numa")); + assert!(is_doh_host(Some("localhost"), "numa")); + assert!(is_doh_host(Some("localhost:443"), "numa")); assert!(!is_doh_host(Some("foo.numa"), "numa")); assert!(!is_doh_host(None, "numa")); } diff --git a/src/tls.rs b/src/tls.rs index e9e2f59..22a00a4 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -186,8 +186,19 @@ fn generate_service_cert( } } - if sans.is_empty() { - return Err("no valid service names for TLS cert".into()); + // Loopback IP SANs so browsers can reach DoH at https://127.0.0.1/dns-query + sans.push(SanType::IpAddress(std::net::IpAddr::V4( + std::net::Ipv4Addr::LOCALHOST, + ))); + sans.push(SanType::IpAddress(std::net::IpAddr::V6( + std::net::Ipv6Addr::LOCALHOST, + ))); + + for name in ["localhost", tld] { + match name.to_string().try_into() { + Ok(ia5) => sans.push(SanType::DnsName(ia5)), + Err(e) => warn!("invalid SAN {}: {}", name, e), + } } params.subject_alt_names = sans; @@ -240,4 +251,72 @@ mod tests { let err: crate::Error = "rcgen failure".into(); assert!(try_data_dir_advisory(&err, &PathBuf::from("/x")).is_none()); } + + #[test] + fn service_cert_contains_expected_sans() { + use x509_parser::prelude::GeneralName; + + let dir = std::env::temp_dir().join(format!("numa-test-san-{}", std::process::id())); + let _ = std::fs::remove_dir_all(&dir); + let (ca_der, issuer) = ensure_ca(&dir).unwrap(); + + let names = vec!["grafana".into(), "router".into()]; + let (chain, _) = generate_service_cert(&ca_der, &issuer, "numa", &names).unwrap(); + assert_eq!(chain.len(), 2, "chain should be [leaf, CA]"); + + let (_, cert) = x509_parser::parse_x509_certificate(chain[0].as_ref()).unwrap(); + let san = cert + .tbs_certificate + .subject_alternative_name() + .unwrap() + .unwrap(); + + let dns: Vec<&str> = san + .value + .general_names + .iter() + .filter_map(|gn| match gn { + GeneralName::DNSName(s) => Some(*s), + _ => None, + }) + .collect(); + + let ips: Vec = san + .value + .general_names + .iter() + .filter_map(|gn| match gn { + GeneralName::IPAddress(b) => match b.len() { + 4 => Some(std::net::IpAddr::V4(std::net::Ipv4Addr::new( + b[0], b[1], b[2], b[3], + ))), + 16 => { + let a: [u8; 16] = (*b).try_into().unwrap(); + Some(std::net::IpAddr::V6(std::net::Ipv6Addr::from(a))) + } + _ => None, + }, + _ => None, + }) + .collect(); + + // DNS SANs + assert!(dns.contains(&"*.numa"), "missing wildcard SAN"); + assert!(dns.contains(&"grafana.numa"), "missing service SAN"); + assert!(dns.contains(&"router.numa"), "missing service SAN"); + assert!(dns.contains(&"localhost"), "missing localhost SAN"); + assert!(dns.contains(&"numa"), "missing bare TLD SAN"); + + // IP SANs + assert!( + ips.contains(&std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)), + "missing 127.0.0.1 SAN" + ); + assert!( + ips.contains(&std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)), + "missing ::1 SAN" + ); + + let _ = std::fs::remove_dir_all(&dir); + } }