935 lines
32 KiB
Rust
935 lines
32 KiB
Rust
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::path::PathBuf;
|
|
use std::sync::{Mutex, RwLock};
|
|
use std::time::{Duration, Instant, SystemTime};
|
|
|
|
use arc_swap::ArcSwap;
|
|
use log::{debug, error, info, warn};
|
|
use rustls::ServerConfig;
|
|
use tokio::net::UdpSocket;
|
|
use tokio::sync::broadcast;
|
|
|
|
/// Recursive resolutions currently in flight, keyed by `(qname, qtype)`.
/// Each value is the broadcast sender owned by the leader resolving that key; followers
/// subscribe to it and receive `Some(response)` on success or `None` on failure
/// (see `resolve_coalesced`).
type InflightMap = HashMap<(String, QueryType), broadcast::Sender<Option<DnsPacket>>>;
|
|
|
|
use crate::blocklist::BlocklistStore;
|
|
use crate::buffer::BytePacketBuffer;
|
|
use crate::cache::{DnsCache, DnssecStatus};
|
|
use crate::config::{UpstreamMode, ZoneMap};
|
|
use crate::forward::{forward_query, Upstream};
|
|
use crate::header::ResultCode;
|
|
use crate::lan::PeerStore;
|
|
use crate::override_store::OverrideStore;
|
|
use crate::packet::DnsPacket;
|
|
use crate::query_log::{QueryLog, QueryLogEntry};
|
|
use crate::question::QueryType;
|
|
use crate::record::DnsRecord;
|
|
use crate::service_store::ServiceStore;
|
|
use crate::srtt::SrttCache;
|
|
use crate::stats::{QueryPath, ServerStats};
|
|
use crate::system_dns::ForwardingRule;
|
|
|
|
/// Shared server state, passed by reference to `resolve_query` / `handle_query` for
/// every query.
///
/// Mutable state sits behind `std::sync` locks (see the note on `cache`): guards are
/// taken for a single statement and never held across an `.await`.
pub struct ServerCtx {
    /// Socket that `handle_query` uses to send UDP responses back to clients.
    pub socket: UdpSocket,
    /// Local zone records, looked up by query name and then by query type.
    pub zone_map: ZoneMap,
    /// std::sync::RwLock (not tokio) — locks must never be held across .await points.
    pub cache: RwLock<DnsCache>,
    /// Per-path query counters; a summary is logged every 1000 recorded queries.
    pub stats: Mutex<ServerStats>,
    /// Manual record overrides; consulted first in the resolution pipeline.
    pub overrides: RwLock<OverrideStore>,
    /// Blocked domains; matching queries are answered with `0.0.0.0` / `::`.
    pub blocklist: RwLock<BlocklistStore>,
    /// Log that receives one `QueryLogEntry` per resolved query.
    pub query_log: Mutex<QueryLog>,
    /// Services registered on this host, reachable under the proxy TLD.
    pub services: Mutex<ServiceStore>,
    /// Services advertised by LAN peers; consulted when a proxy-TLD name is not a
    /// local service.
    pub lan_peers: Mutex<PeerStore>,
    /// Conditional forwarding rules; a match forwards the query over UDP to the rule's
    /// address, taking priority over recursive mode.
    pub forwarding_rules: Vec<ForwardingRule>,
    /// Default upstream used in forwarding mode when no forwarding rule matches.
    pub upstream: Mutex<Upstream>,
    /// NOTE(review): not read in this file; presumably marks an auto-detected upstream.
    pub upstream_auto: bool,
    /// NOTE(review): not read in this file.
    pub upstream_port: u16,
    /// LAN address returned to non-loopback clients for local proxy-TLD services.
    pub lan_ip: Mutex<std::net::Ipv4Addr>,
    /// Timeout passed to `forward_query`.
    pub timeout: Duration,
    /// Proxy TLD without a leading dot (e.g. `numa`).
    pub proxy_tld: String,
    pub proxy_tld_suffix: String, // pre-computed ".{tld}" to avoid per-query allocation
    /// NOTE(review): not read in this file.
    pub lan_enabled: bool,
    /// NOTE(review): the config_* / data_dir fields are not read in this file.
    pub config_path: String,
    pub config_found: bool,
    pub config_dir: PathBuf,
    pub data_dir: PathBuf,
    /// Hot-swappable TLS configuration; presumably for DNS-over-TLS — TODO confirm.
    pub tls_config: Option<ArcSwap<ServerConfig>>,
    /// Recursive resolution vs. forwarding to `upstream`.
    pub upstream_mode: UpstreamMode,
    /// Root server addresses passed to recursive resolution and DNSSEC validation.
    pub root_hints: Vec<SocketAddr>,
    /// Smoothed RTT cache passed to recursive resolution and DNSSEC validation.
    pub srtt: RwLock<SrttCache>,
    /// In-flight recursive resolutions used for query coalescing.
    pub inflight: Mutex<InflightMap>,
    /// Enables DNSSEC validation of recursively resolved answers.
    pub dnssec_enabled: bool,
    /// When set, Bogus validation results are replaced with SERVFAIL.
    pub dnssec_strict: bool,
}
|
|
|
|
/// Transport-agnostic DNS resolution. Runs the full pipeline (overrides, blocklist,
|
|
/// cache, upstream, DNSSEC) and returns the serialized response in a buffer.
|
|
/// Callers use `.filled()` to get the response bytes without heap allocation.
|
|
pub async fn resolve_query(
|
|
mut buffer: BytePacketBuffer,
|
|
src_addr: SocketAddr,
|
|
ctx: &ServerCtx,
|
|
) -> crate::Result<BytePacketBuffer> {
|
|
let start = Instant::now();
|
|
|
|
let query = match DnsPacket::from_buffer(&mut buffer) {
|
|
Ok(packet) => packet,
|
|
Err(e) => {
|
|
warn!("{} | PARSE ERROR | {}", src_addr, e);
|
|
return Err(e);
|
|
}
|
|
};
|
|
|
|
let (qname, qtype) = match query.questions.first() {
|
|
Some(q) => (q.name.clone(), q.qtype),
|
|
None => return Err("empty question section".into()),
|
|
};
|
|
|
|
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
|
|
// Each lock is scoped to avoid holding MutexGuard across await points.
|
|
let (response, path, dnssec) = {
|
|
let override_record = ctx.overrides.read().unwrap().lookup(&qname);
|
|
if let Some(record) = override_record {
|
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
|
resp.answers.push(record);
|
|
(resp, QueryPath::Overridden, DnssecStatus::Indeterminate)
|
|
} else if qname == "localhost" || qname.ends_with(".localhost") {
|
|
// RFC 6761: .localhost always resolves to loopback
|
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
|
resp.answers.push(sinkhole_record(
|
|
&qname,
|
|
qtype,
|
|
std::net::Ipv4Addr::LOCALHOST,
|
|
std::net::Ipv6Addr::LOCALHOST,
|
|
300,
|
|
));
|
|
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
|
} else if is_special_use_domain(&qname) {
|
|
// RFC 6761/8880: private PTR, DDR, NAT64 — answer locally
|
|
let resp = special_use_response(&query, &qname, qtype);
|
|
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
|
} else if !ctx.proxy_tld_suffix.is_empty()
|
|
&& (qname.ends_with(&ctx.proxy_tld_suffix) || qname == ctx.proxy_tld)
|
|
{
|
|
// Resolve .numa: remote clients get LAN IP (can't reach 127.0.0.1), local get loopback
|
|
let service_name = qname.strip_suffix(&ctx.proxy_tld_suffix).unwrap_or(&qname);
|
|
let is_remote = !src_addr.ip().is_loopback();
|
|
let resolve_ip = {
|
|
let local = ctx.services.lock().unwrap();
|
|
if local.lookup(service_name).is_some() {
|
|
if is_remote {
|
|
*ctx.lan_ip.lock().unwrap()
|
|
} else {
|
|
std::net::Ipv4Addr::LOCALHOST
|
|
}
|
|
} else {
|
|
let mut peers = ctx.lan_peers.lock().unwrap();
|
|
peers
|
|
.lookup(service_name)
|
|
.and_then(|(ip, _)| match ip {
|
|
std::net::IpAddr::V4(v4) => Some(v4),
|
|
_ => None,
|
|
})
|
|
.unwrap_or(std::net::Ipv4Addr::LOCALHOST)
|
|
}
|
|
};
|
|
let v6 = if resolve_ip == std::net::Ipv4Addr::LOCALHOST {
|
|
std::net::Ipv6Addr::LOCALHOST
|
|
} else {
|
|
resolve_ip.to_ipv6_mapped()
|
|
};
|
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
|
resp.answers
|
|
.push(sinkhole_record(&qname, qtype, resolve_ip, v6, 300));
|
|
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
|
} else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
|
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
|
resp.answers.push(sinkhole_record(
|
|
&qname,
|
|
qtype,
|
|
std::net::Ipv4Addr::UNSPECIFIED,
|
|
std::net::Ipv6Addr::UNSPECIFIED,
|
|
60,
|
|
));
|
|
(resp, QueryPath::Blocked, DnssecStatus::Indeterminate)
|
|
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
|
|
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
|
resp.answers = records.clone();
|
|
(resp, QueryPath::Local, DnssecStatus::Indeterminate)
|
|
} else {
|
|
let cached = ctx.cache.read().unwrap().lookup_with_status(&qname, qtype);
|
|
if let Some((cached, cached_dnssec)) = cached {
|
|
let mut resp = cached;
|
|
resp.header.id = query.header.id;
|
|
if cached_dnssec == DnssecStatus::Secure {
|
|
resp.header.authed_data = true;
|
|
}
|
|
(resp, QueryPath::Cached, cached_dnssec)
|
|
} else if let Some(fwd_addr) =
|
|
crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules)
|
|
{
|
|
// Conditional forwarding takes priority over recursive mode
|
|
// (e.g. Tailscale .ts.net, VPC private zones)
|
|
let upstream = Upstream::Udp(fwd_addr);
|
|
match forward_query(&query, &upstream, ctx.timeout).await {
|
|
Ok(resp) => {
|
|
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
|
|
(resp, QueryPath::Forwarded, DnssecStatus::Indeterminate)
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"{} | {:?} {} | FORWARD ERROR | {}",
|
|
src_addr, qtype, qname, e
|
|
);
|
|
(
|
|
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
|
QueryPath::UpstreamError,
|
|
DnssecStatus::Indeterminate,
|
|
)
|
|
}
|
|
}
|
|
} else if ctx.upstream_mode == UpstreamMode::Recursive {
|
|
let key = (qname.clone(), qtype);
|
|
let (resp, path, err) = resolve_coalesced(&ctx.inflight, key, &query, || {
|
|
crate::recursive::resolve_recursive(
|
|
&qname,
|
|
qtype,
|
|
&ctx.cache,
|
|
&query,
|
|
&ctx.root_hints,
|
|
&ctx.srtt,
|
|
)
|
|
})
|
|
.await;
|
|
if path == QueryPath::Coalesced {
|
|
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
|
|
} else if path == QueryPath::UpstreamError {
|
|
error!(
|
|
"{} | {:?} {} | RECURSIVE ERROR | {}",
|
|
src_addr,
|
|
qtype,
|
|
qname,
|
|
err.as_deref().unwrap_or("leader failed")
|
|
);
|
|
}
|
|
(resp, path, DnssecStatus::Indeterminate)
|
|
} else {
|
|
let upstream =
|
|
match crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules) {
|
|
Some(addr) => Upstream::Udp(addr),
|
|
None => ctx.upstream.lock().unwrap().clone(),
|
|
};
|
|
match forward_query(&query, &upstream, ctx.timeout).await {
|
|
Ok(resp) => {
|
|
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
|
|
(resp, QueryPath::Forwarded, DnssecStatus::Indeterminate)
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"{} | {:?} {} | UPSTREAM ERROR | {}",
|
|
src_addr, qtype, qname, e
|
|
);
|
|
(
|
|
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
|
|
QueryPath::UpstreamError,
|
|
DnssecStatus::Indeterminate,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let client_do = query.edns.as_ref().is_some_and(|e| e.do_bit);
|
|
let mut response = response;
|
|
|
|
// DNSSEC validation (recursive/forwarded responses only)
|
|
let mut dnssec = dnssec;
|
|
if ctx.dnssec_enabled && path == QueryPath::Recursive {
|
|
let (status, vstats) =
|
|
crate::dnssec::validate_response(&response, &ctx.cache, &ctx.root_hints, &ctx.srtt)
|
|
.await;
|
|
|
|
debug!(
|
|
"DNSSEC | {} | {:?} | {}ms | dnskey_hit={} dnskey_fetch={} ds_hit={} ds_fetch={}",
|
|
qname,
|
|
status,
|
|
vstats.elapsed_ms,
|
|
vstats.dnskey_cache_hits,
|
|
vstats.dnskey_fetches,
|
|
vstats.ds_cache_hits,
|
|
vstats.ds_fetches,
|
|
);
|
|
|
|
dnssec = status;
|
|
|
|
if status == DnssecStatus::Secure {
|
|
response.header.authed_data = true;
|
|
}
|
|
|
|
if status == DnssecStatus::Bogus && ctx.dnssec_strict {
|
|
response = DnsPacket::response_from(&query, ResultCode::SERVFAIL);
|
|
}
|
|
|
|
ctx.cache
|
|
.write()
|
|
.unwrap()
|
|
.insert_with_status(&qname, qtype, &response, status);
|
|
}
|
|
|
|
// Strip DNSSEC records if client didn't set DO bit
|
|
if !client_do {
|
|
strip_dnssec_records(&mut response);
|
|
}
|
|
|
|
// Echo EDNS back if client sent it
|
|
if query.edns.is_some() {
|
|
response.edns = Some(crate::packet::EdnsOpt {
|
|
do_bit: client_do,
|
|
..Default::default()
|
|
});
|
|
}
|
|
|
|
let elapsed = start.elapsed();
|
|
|
|
info!(
|
|
"{} | {:?} {} | {} | {} | {}ms",
|
|
src_addr,
|
|
qtype,
|
|
qname,
|
|
path.as_str(),
|
|
response.header.rescode.as_str(),
|
|
elapsed.as_millis(),
|
|
);
|
|
|
|
debug!(
|
|
"response: {} answers, {} authorities, {} resources",
|
|
response.answers.len(),
|
|
response.authorities.len(),
|
|
response.resources.len(),
|
|
);
|
|
|
|
// Serialize response
|
|
// TODO: TC bit is UDP-specific; DoT connections could carry up to 65535 bytes.
|
|
// Once BytePacketBuffer supports larger buffers, skip truncation for TCP/TLS.
|
|
let mut resp_buffer = BytePacketBuffer::new();
|
|
if response.write(&mut resp_buffer).is_err() {
|
|
// Response too large — set TC bit and send header + question only
|
|
debug!("response too large, setting TC bit for {}", qname);
|
|
let mut tc_response = DnsPacket::response_from(&query, response.header.rescode);
|
|
tc_response.header.truncated_message = true;
|
|
resp_buffer = BytePacketBuffer::new();
|
|
tc_response.write(&mut resp_buffer)?;
|
|
}
|
|
|
|
// Record stats and query log
|
|
{
|
|
let mut s = ctx.stats.lock().unwrap();
|
|
let total = s.record(path);
|
|
if total.is_multiple_of(1000) {
|
|
s.log_summary();
|
|
}
|
|
}
|
|
|
|
ctx.query_log.lock().unwrap().push(QueryLogEntry {
|
|
timestamp: SystemTime::now(),
|
|
src_addr,
|
|
domain: qname,
|
|
query_type: qtype,
|
|
path,
|
|
rescode: response.header.rescode,
|
|
latency_us: elapsed.as_micros() as u64,
|
|
dnssec,
|
|
});
|
|
|
|
Ok(resp_buffer)
|
|
}
|
|
|
|
/// Handle a DNS query received over UDP. Thin wrapper around resolve_query.
|
|
pub async fn handle_query(
|
|
buffer: BytePacketBuffer,
|
|
src_addr: SocketAddr,
|
|
ctx: &ServerCtx,
|
|
) -> crate::Result<()> {
|
|
match resolve_query(buffer, src_addr, ctx).await {
|
|
Ok(resp_buffer) => {
|
|
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
|
|
}
|
|
Err(e) => {
|
|
warn!("{} | RESOLVE ERROR | {}", src_addr, e);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn is_dnssec_record(r: &DnsRecord) -> bool {
|
|
matches!(
|
|
r.query_type(),
|
|
QueryType::RRSIG | QueryType::DNSKEY | QueryType::DS | QueryType::NSEC | QueryType::NSEC3
|
|
)
|
|
}
|
|
|
|
fn strip_dnssec_records(pkt: &mut DnsPacket) {
|
|
pkt.answers.retain(|r| !is_dnssec_record(r));
|
|
pkt.authorities.retain(|r| !is_dnssec_record(r));
|
|
pkt.resources.retain(|r| !is_dnssec_record(r));
|
|
}
|
|
|
|
/// Reports whether `qname` must be answered locally instead of being sent upstream.
///
/// Covered names:
/// - reverse names under the private, loopback, link-local and `0/8` `in-addr.arpa`
///   zones, including 172.16.0.0/12 (RFC 6303), plus DNS-SD browse names
///   (`_dns-sd._udp`) under `in-addr.arpa`;
/// - DDR discovery names (`_dns.resolver.arpa`, RFC 9462);
/// - `ipv4only.arpa` (RFC 8880);
/// - the mDNS-reserved `local` TLD (RFC 6762).
fn is_special_use_domain(qname: &str) -> bool {
    const LOCAL_REVERSE_SUFFIXES: [&str; 5] = [
        ".10.in-addr.arpa",
        ".168.192.in-addr.arpa",
        ".127.in-addr.arpa",
        ".254.169.in-addr.arpa",
        ".0.in-addr.arpa",
    ];

    if !qname.ends_with(".in-addr.arpa") {
        return qname == "_dns.resolver.arpa"
            || qname.ends_with("._dns.resolver.arpa")
            || qname == "ipv4only.arpa"
            || qname == "local"
            || qname.ends_with(".local");
    }

    if qname.contains("_dns-sd._udp")
        || LOCAL_REVERSE_SUFFIXES
            .iter()
            .any(|&suffix| qname.ends_with(suffix))
    {
        return true;
    }

    // 172.16.0.0/12: the label right before ".172.in-addr.arpa" is the second octet.
    qname
        .strip_suffix(".172.in-addr.arpa")
        .and_then(|rest| rest.rsplit('.').next())
        .and_then(|octet| octet.parse::<u8>().ok())
        .is_some_and(|octet| (16..=31).contains(&octet))
}
|
|
|
|
fn sinkhole_record(
|
|
domain: &str,
|
|
qtype: QueryType,
|
|
v4: std::net::Ipv4Addr,
|
|
v6: std::net::Ipv6Addr,
|
|
ttl: u32,
|
|
) -> DnsRecord {
|
|
match qtype {
|
|
QueryType::AAAA => DnsRecord::AAAA {
|
|
domain: domain.to_string(),
|
|
addr: v6,
|
|
ttl,
|
|
},
|
|
_ => DnsRecord::A {
|
|
domain: domain.to_string(),
|
|
addr: v4,
|
|
ttl,
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Role assigned to a caller by `acquire_inflight` for a given `(qname, qtype)` key.
enum Disposition {
    /// First caller for the key: it performs the resolution and broadcasts the
    /// outcome (`Some(response)` on success, `None` on failure) through this sender.
    Leader(broadcast::Sender<Option<DnsPacket>>),
    /// A resolution for the key is already pending: wait on this receiver for the
    /// leader's outcome.
    Follower(broadcast::Receiver<Option<DnsPacket>>),
}
|
|
|
|
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
|
|
let mut map = inflight.lock().unwrap();
|
|
if let Some(tx) = map.get(&key) {
|
|
Disposition::Follower(tx.subscribe())
|
|
} else {
|
|
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
|
|
map.insert(key, tx.clone());
|
|
Disposition::Leader(tx)
|
|
}
|
|
}
|
|
|
|
/// Run a resolve function with in-flight coalescing. Multiple concurrent calls
|
|
/// for the same key share a single resolution — the first caller (leader)
|
|
/// executes `resolve_fn`, and followers wait for the broadcast result.
|
|
async fn resolve_coalesced<F, Fut>(
|
|
inflight: &Mutex<InflightMap>,
|
|
key: (String, QueryType),
|
|
query: &DnsPacket,
|
|
resolve_fn: F,
|
|
) -> (DnsPacket, QueryPath, Option<String>)
|
|
where
|
|
F: FnOnce() -> Fut,
|
|
Fut: std::future::Future<Output = crate::Result<DnsPacket>>,
|
|
{
|
|
let disposition = acquire_inflight(inflight, key.clone());
|
|
|
|
match disposition {
|
|
Disposition::Follower(mut rx) => match rx.recv().await {
|
|
Ok(Some(mut resp)) => {
|
|
resp.header.id = query.header.id;
|
|
(resp, QueryPath::Coalesced, None)
|
|
}
|
|
_ => (
|
|
DnsPacket::response_from(query, ResultCode::SERVFAIL),
|
|
QueryPath::UpstreamError,
|
|
None,
|
|
),
|
|
},
|
|
Disposition::Leader(tx) => {
|
|
let guard = InflightGuard { inflight, key };
|
|
let result = resolve_fn().await;
|
|
drop(guard);
|
|
|
|
match result {
|
|
Ok(resp) => {
|
|
let _ = tx.send(Some(resp.clone()));
|
|
(resp, QueryPath::Recursive, None)
|
|
}
|
|
Err(e) => {
|
|
let _ = tx.send(None);
|
|
let err_msg = e.to_string();
|
|
(
|
|
DnsPacket::response_from(query, ResultCode::SERVFAIL),
|
|
QueryPath::UpstreamError,
|
|
Some(err_msg),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Removes `key` from the in-flight map when dropped, so a leader's entry is cleaned up
/// whether its resolution succeeds, fails, or its future is dropped mid-await.
struct InflightGuard<'a> {
    /// Map the leader's entry was registered in.
    inflight: &'a Mutex<InflightMap>,
    /// `(qname, qtype)` key to remove on drop.
    key: (String, QueryType),
}
|
|
|
|
impl Drop for InflightGuard<'_> {
|
|
fn drop(&mut self) {
|
|
self.inflight.lock().unwrap().remove(&self.key);
|
|
}
|
|
}
|
|
|
|
fn special_use_response(query: &DnsPacket, qname: &str, qtype: QueryType) -> DnsPacket {
|
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
|
if qname == "ipv4only.arpa" {
|
|
// RFC 8880: well-known NAT64 addresses
|
|
let mut resp = DnsPacket::response_from(query, ResultCode::NOERROR);
|
|
let domain = qname.to_string();
|
|
match qtype {
|
|
QueryType::A => {
|
|
resp.answers.push(DnsRecord::A {
|
|
domain: domain.clone(),
|
|
addr: Ipv4Addr::new(192, 0, 0, 170),
|
|
ttl: 300,
|
|
});
|
|
resp.answers.push(DnsRecord::A {
|
|
domain,
|
|
addr: Ipv4Addr::new(192, 0, 0, 171),
|
|
ttl: 300,
|
|
});
|
|
}
|
|
QueryType::AAAA => {
|
|
resp.answers.push(DnsRecord::AAAA {
|
|
domain,
|
|
addr: Ipv6Addr::new(0x0064, 0xff9b, 0, 0, 0, 0, 0xc000, 0x00aa),
|
|
ttl: 300,
|
|
});
|
|
}
|
|
_ => {}
|
|
}
|
|
resp
|
|
} else {
|
|
DnsPacket::response_from(query, ResultCode::NXDOMAIN)
|
|
}
|
|
}
|
|
|
|
// Unit tests for in-flight query coalescing: `InflightGuard` cleanup,
// `acquire_inflight` leader/follower assignment, broadcast delivery, and end-to-end
// behavior of `resolve_coalesced` driven by mock resolve futures.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::net::Ipv4Addr;
    use std::sync::{Arc, Mutex};
    use tokio::sync::broadcast;

    // ---- InflightGuard unit tests ----

    // Dropping the guard removes its key from the map.
    #[test]
    fn inflight_guard_removes_key_on_drop() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("example.com".to_string(), QueryType::A);
        let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
        map.lock().unwrap().insert(key.clone(), tx);

        assert_eq!(map.lock().unwrap().len(), 1);
        {
            let _guard = InflightGuard {
                inflight: &map,
                key: key.clone(),
            };
        } // guard dropped here
        assert!(map.lock().unwrap().is_empty());
    }

    // Dropping the guard leaves entries for other names untouched.
    #[test]
    fn inflight_guard_only_removes_own_key() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key_a = ("a.com".to_string(), QueryType::A);
        let key_b = ("b.com".to_string(), QueryType::A);
        let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
        let (tx_b, _) = broadcast::channel::<Option<DnsPacket>>(1);
        map.lock().unwrap().insert(key_a.clone(), tx_a);
        map.lock().unwrap().insert(key_b.clone(), tx_b);

        {
            let _guard = InflightGuard {
                inflight: &map,
                key: key_a,
            };
        }
        let m = map.lock().unwrap();
        assert_eq!(m.len(), 1);
        assert!(m.contains_key(&key_b));
    }

    // Keys differ by query type: removing (name, A) keeps (name, AAAA).
    #[test]
    fn inflight_guard_same_domain_different_qtype_independent() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key_a = ("example.com".to_string(), QueryType::A);
        let key_aaaa = ("example.com".to_string(), QueryType::AAAA);
        let (tx_a, _) = broadcast::channel::<Option<DnsPacket>>(1);
        let (tx_aaaa, _) = broadcast::channel::<Option<DnsPacket>>(1);
        map.lock().unwrap().insert(key_a.clone(), tx_a);
        map.lock().unwrap().insert(key_aaaa.clone(), tx_aaaa);

        {
            let _guard = InflightGuard {
                inflight: &map,
                key: key_a,
            };
        }
        let m = map.lock().unwrap();
        assert_eq!(m.len(), 1);
        assert!(m.contains_key(&key_aaaa));
    }

    // ---- Coalescing disposition tests (via acquire_inflight) ----

    // An empty map makes the first caller the leader and registers its key.
    #[test]
    fn first_caller_becomes_leader() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("test.com".to_string(), QueryType::A);

        let d = acquire_inflight(&map, key.clone());
        assert!(matches!(d, Disposition::Leader(_)));
        assert_eq!(map.lock().unwrap().len(), 1);
    }

    // A second caller for a pending key becomes a follower.
    #[test]
    fn second_caller_becomes_follower() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("test.com".to_string(), QueryType::A);

        let _leader = acquire_inflight(&map, key.clone());
        let follower = acquire_inflight(&map, key);
        assert!(matches!(follower, Disposition::Follower(_)));
        // Map still has exactly 1 entry — follower subscribes, doesn't insert
        assert_eq!(map.lock().unwrap().len(), 1);
    }

    // A response sent on the leader's channel reaches the follower unchanged.
    #[tokio::test]
    async fn leader_broadcast_reaches_follower() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("test.com".to_string(), QueryType::A);

        let leader = acquire_inflight(&map, key.clone());
        let follower = acquire_inflight(&map, key);

        let tx = match leader {
            Disposition::Leader(tx) => tx,
            _ => panic!("expected leader"),
        };
        let mut rx = match follower {
            Disposition::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };

        let mut resp = DnsPacket::new();
        resp.header.id = 42;
        resp.answers.push(DnsRecord::A {
            domain: "test.com".into(),
            addr: Ipv4Addr::new(1, 2, 3, 4),
            ttl: 300,
        });
        let _ = tx.send(Some(resp));

        let received = rx.recv().await.unwrap().unwrap();
        assert_eq!(received.header.id, 42);
        assert_eq!(received.answers.len(), 1);
    }

    // A `None` broadcast (leader failure) is delivered to the follower as `None`.
    #[tokio::test]
    async fn leader_none_signals_failure_to_follower() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("test.com".to_string(), QueryType::A);

        let leader = acquire_inflight(&map, key.clone());
        let follower = acquire_inflight(&map, key);

        let tx = match leader {
            Disposition::Leader(tx) => tx,
            _ => panic!("expected leader"),
        };
        let mut rx = match follower {
            Disposition::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };

        let _ = tx.send(None);
        assert!(rx.recv().await.unwrap().is_none());
    }

    // Every follower subscribed before the send receives the leader's response.
    #[tokio::test]
    async fn multiple_followers_all_receive_via_acquire() {
        let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let key = ("multi.com".to_string(), QueryType::A);

        let leader = acquire_inflight(&map, key.clone());
        let f1 = acquire_inflight(&map, key.clone());
        let f2 = acquire_inflight(&map, key.clone());
        let f3 = acquire_inflight(&map, key);

        let tx = match leader {
            Disposition::Leader(tx) => tx,
            _ => panic!("expected leader"),
        };

        let mut resp = DnsPacket::new();
        resp.answers.push(DnsRecord::A {
            domain: "multi.com".into(),
            addr: Ipv4Addr::new(10, 0, 0, 1),
            ttl: 60,
        });
        let _ = tx.send(Some(resp));

        for f in [f1, f2, f3] {
            let mut rx = match f {
                Disposition::Follower(rx) => rx,
                _ => panic!("expected follower"),
            };
            let r = rx.recv().await.unwrap().unwrap();
            assert_eq!(r.answers.len(), 1);
        }
    }

    // ---- Integration: resolve_coalesced with mock futures ----

    // NOERROR response with a single A record (10.0.0.1, TTL 300) for `domain`.
    fn mock_response(domain: &str) -> DnsPacket {
        let mut resp = DnsPacket::new();
        resp.header.response = true;
        resp.header.rescode = ResultCode::NOERROR;
        resp.answers.push(DnsRecord::A {
            domain: domain.to_string(),
            addr: Ipv4Addr::new(10, 0, 0, 1),
            ttl: 300,
        });
        resp
    }

    // Five concurrent queries for one key: exactly one resolve_fn call, one RECURSIVE
    // and four COALESCED results, and the map is empty afterwards. The 200 ms sleep
    // keeps the leader pending while the other tasks subscribe.
    #[tokio::test]
    async fn concurrent_queries_coalesce_to_single_resolution() {
        let inflight = Arc::new(Mutex::new(HashMap::new()));
        let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let mut handles = Vec::new();
        for i in 0..5u16 {
            let count = resolve_count.clone();
            let inf = inflight.clone();
            let key = ("coalesce.test".to_string(), QueryType::A);
            let query = DnsPacket::query(100 + i, "coalesce.test", QueryType::A);
            handles.push(tokio::spawn(async move {
                resolve_coalesced(&inf, key, &query, || async {
                    count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok(mock_response("coalesce.test"))
                })
                .await
            }));
        }

        let mut paths = Vec::new();
        for h in handles {
            let (_, path, _) = h.await.unwrap();
            paths.push(path);
        }

        let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed);
        assert_eq!(actual, 1, "expected 1 resolution, got {}", actual);

        let recursive = paths.iter().filter(|p| **p == QueryPath::Recursive).count();
        let coalesced = paths.iter().filter(|p| **p == QueryPath::Coalesced).count();
        assert_eq!(recursive, 1, "expected 1 RECURSIVE, got {}", recursive);
        assert_eq!(coalesced, 4, "expected 4 COALESCED, got {}", coalesced);

        assert!(inflight.lock().unwrap().is_empty());
    }

    // A and AAAA queries for the same name are resolved independently.
    #[tokio::test]
    async fn different_qtypes_not_coalesced() {
        let inflight = Arc::new(Mutex::new(HashMap::new()));
        let resolve_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let inf1 = inflight.clone();
        let inf2 = inflight.clone();
        let count1 = resolve_count.clone();
        let count2 = resolve_count.clone();

        let query_a = DnsPacket::query(200, "same.domain", QueryType::A);
        let query_aaaa = DnsPacket::query(201, "same.domain", QueryType::AAAA);

        let h1 = tokio::spawn(async move {
            resolve_coalesced(
                &inf1,
                ("same.domain".to_string(), QueryType::A),
                &query_a,
                || async {
                    count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    Ok(mock_response("same.domain"))
                },
            )
            .await
        });
        let h2 = tokio::spawn(async move {
            resolve_coalesced(
                &inf2,
                ("same.domain".to_string(), QueryType::AAAA),
                &query_aaaa,
                || async {
                    count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    Ok(mock_response("same.domain"))
                },
            )
            .await
        });

        let (_, path1, _) = h1.await.unwrap();
        let (_, path2, _) = h2.await.unwrap();

        let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed);
        assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual);
        assert_eq!(path1, QueryPath::Recursive);
        assert_eq!(path2, QueryPath::Recursive);

        assert!(inflight.lock().unwrap().is_empty());
    }

    // A failing leader still removes its map entry.
    #[tokio::test]
    async fn inflight_map_cleaned_after_error() {
        let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let query = DnsPacket::query(300, "will-fail.test", QueryType::A);

        let (_, path, _) = resolve_coalesced(
            &inflight,
            ("will-fail.test".to_string(), QueryType::A),
            &query,
            || async { Err::<DnsPacket, _>("upstream timeout".into()) },
        )
        .await;

        assert_eq!(path, QueryPath::UpstreamError);
        assert!(inflight.lock().unwrap().is_empty());
    }

    // When the leader fails, the leader and every follower get SERVFAIL that echoes
    // the question section.
    #[tokio::test]
    async fn follower_gets_servfail_when_leader_fails() {
        let inflight = Arc::new(Mutex::new(HashMap::new()));

        let mut handles = Vec::new();
        for i in 0..3u16 {
            let inf = inflight.clone();
            let query = DnsPacket::query(400 + i, "fail.test", QueryType::A);
            handles.push(tokio::spawn(async move {
                resolve_coalesced(
                    &inf,
                    ("fail.test".to_string(), QueryType::A),
                    &query,
                    || async {
                        tokio::time::sleep(Duration::from_millis(200)).await;
                        Err::<DnsPacket, _>("upstream error".into())
                    },
                )
                .await
            }));
        }

        let mut paths = Vec::new();
        for h in handles {
            let (resp, path, _) = h.await.unwrap();
            assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
            assert_eq!(
                resp.questions.len(),
                1,
                "SERVFAIL must echo question section"
            );
            assert_eq!(resp.questions[0].name, "fail.test");
            paths.push(path);
        }

        let errors = paths
            .iter()
            .filter(|p| **p == QueryPath::UpstreamError)
            .count();
        assert_eq!(errors, 3, "all 3 should be UpstreamError, got {}", errors);

        assert!(inflight.lock().unwrap().is_empty());
    }

    // The leader's SERVFAIL keeps the query id and echoes the question.
    #[tokio::test]
    async fn servfail_leader_includes_question_section() {
        let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let query = DnsPacket::query(500, "question.test", QueryType::A);

        let (resp, _, _) = resolve_coalesced(
            &inflight,
            ("question.test".to_string(), QueryType::A),
            &query,
            || async { Err::<DnsPacket, _>("fail".into()) },
        )
        .await;

        assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
        assert_eq!(
            resp.questions.len(),
            1,
            "SERVFAIL must echo question section"
        );
        assert_eq!(resp.questions[0].name, "question.test");
        assert_eq!(resp.questions[0].qtype, QueryType::A);
        assert_eq!(resp.header.id, 500);
    }

    // The leader returns the resolve error's text so the caller can log it.
    #[tokio::test]
    async fn leader_error_preserves_message() {
        let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
        let query = DnsPacket::query(700, "err-msg.test", QueryType::A);

        let (_, path, err) = resolve_coalesced(
            &inflight,
            ("err-msg.test".to_string(), QueryType::A),
            &query,
            || async { Err::<DnsPacket, _>("connection refused by upstream".into()) },
        )
        .await;

        assert_eq!(path, QueryPath::UpstreamError);
        assert_eq!(
            err.as_deref(),
            Some("connection refused by upstream"),
            "error message must be preserved for logging"
        );
    }
}
|