refactor: extract resolve_coalesced, test real code #21

Merged
razvandimescu merged 2 commits from refactor/extract-resolve-coalesced into main 2026-03-29 16:14:25 +08:00
Showing only changes of commit 850a0c6ab4 - Show all commits

View File

@@ -178,7 +178,7 @@ pub async fn handle_query(
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
let key = (qname.clone(), qtype); let key = (qname.clone(), qtype);
let (resp, path) = resolve_coalesced(&ctx.inflight, key, query.header.id, || { let (resp, path, err) = resolve_coalesced(&ctx.inflight, key, &query, || {
crate::recursive::resolve_recursive( crate::recursive::resolve_recursive(
&qname, &qname,
qtype, qtype,
@@ -192,7 +192,13 @@ pub async fn handle_query(
if path == QueryPath::Coalesced { if path == QueryPath::Coalesced {
debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname);
} else if path == QueryPath::UpstreamError { } else if path == QueryPath::UpstreamError {
error!("{} | {:?} {} | RECURSIVE ERROR", src_addr, qtype, qname); error!(
"{} | {:?} {} | RECURSIVE ERROR | {}",
src_addr,
qtype,
qname,
err.as_deref().unwrap_or("leader failed")
);
} }
(resp, path, DnssecStatus::Indeterminate) (resp, path, DnssecStatus::Indeterminate)
} else { } else {
@@ -399,9 +405,9 @@ fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) ->
async fn resolve_coalesced<F, Fut>( async fn resolve_coalesced<F, Fut>(
inflight: &Mutex<InflightMap>, inflight: &Mutex<InflightMap>,
key: (String, QueryType), key: (String, QueryType),
query_id: u16, query: &DnsPacket,
resolve_fn: F, resolve_fn: F,
) -> (DnsPacket, QueryPath) ) -> (DnsPacket, QueryPath, Option<String>)
where where
F: FnOnce() -> Fut, F: FnOnce() -> Fut,
Fut: std::future::Future<Output = crate::Result<DnsPacket>>, Fut: std::future::Future<Output = crate::Result<DnsPacket>>,
@@ -411,16 +417,14 @@ where
match disposition { match disposition {
Disposition::Follower(mut rx) => match rx.recv().await { Disposition::Follower(mut rx) => match rx.recv().await {
Ok(Some(mut resp)) => { Ok(Some(mut resp)) => {
resp.header.id = query_id; resp.header.id = query.header.id;
(resp, QueryPath::Coalesced) (resp, QueryPath::Coalesced, None)
}
_ => {
let mut resp = DnsPacket::new();
resp.header.id = query_id;
resp.header.response = true;
resp.header.rescode = ResultCode::SERVFAIL;
(resp, QueryPath::UpstreamError)
} }
_ => (
DnsPacket::response_from(query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
None,
),
}, },
Disposition::Leader(tx) => { Disposition::Leader(tx) => {
let guard = InflightGuard { inflight, key }; let guard = InflightGuard { inflight, key };
@@ -430,15 +434,16 @@ where
match result { match result {
Ok(resp) => { Ok(resp) => {
let _ = tx.send(Some(resp.clone())); let _ = tx.send(Some(resp.clone()));
(resp, QueryPath::Recursive) (resp, QueryPath::Recursive, None)
} }
Err(_) => { Err(e) => {
let _ = tx.send(None); let _ = tx.send(None);
let mut resp = DnsPacket::new(); let err_msg = e.to_string();
resp.header.id = query_id; (
resp.header.response = true; DnsPacket::response_from(query, ResultCode::SERVFAIL),
resp.header.rescode = ResultCode::SERVFAIL; QueryPath::UpstreamError,
(resp, QueryPath::UpstreamError) Some(err_msg),
)
} }
} }
} }
@@ -670,7 +675,14 @@ mod tests {
// ---- Integration: resolve_coalesced with mock futures ---- // ---- Integration: resolve_coalesced with mock futures ----
// ---- Integration: resolve_coalesced with mock futures ---- fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket {
let mut pkt = DnsPacket::new();
pkt.header.id = id;
pkt.header.recursion_desired = true;
pkt.questions
.push(crate::question::DnsQuestion::new(domain.to_string(), qtype));
pkt
}
fn mock_response(domain: &str) -> DnsPacket { fn mock_response(domain: &str) -> DnsPacket {
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
@@ -694,8 +706,9 @@ mod tests {
let count = resolve_count.clone(); let count = resolve_count.clone();
let inf = inflight.clone(); let inf = inflight.clone();
let key = ("coalesce.test".to_string(), QueryType::A); let key = ("coalesce.test".to_string(), QueryType::A);
let query = mock_query(100 + i, "coalesce.test", QueryType::A);
handles.push(tokio::spawn(async move { handles.push(tokio::spawn(async move {
resolve_coalesced(&inf, key, 100 + i, || async { resolve_coalesced(&inf, key, &query, || async {
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(200)).await; tokio::time::sleep(Duration::from_millis(200)).await;
Ok(mock_response("coalesce.test")) Ok(mock_response("coalesce.test"))
@@ -706,7 +719,7 @@ mod tests {
let mut paths = Vec::new(); let mut paths = Vec::new();
for h in handles { for h in handles {
let (_, path) = h.await.unwrap(); let (_, path, _) = h.await.unwrap();
paths.push(path); paths.push(path);
} }
@@ -731,11 +744,14 @@ mod tests {
let count1 = resolve_count.clone(); let count1 = resolve_count.clone();
let count2 = resolve_count.clone(); let count2 = resolve_count.clone();
let query_a = mock_query(200, "same.domain", QueryType::A);
let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA);
let h1 = tokio::spawn(async move { let h1 = tokio::spawn(async move {
resolve_coalesced( resolve_coalesced(
&inf1, &inf1,
("same.domain".to_string(), QueryType::A), ("same.domain".to_string(), QueryType::A),
200, &query_a,
|| async { || async {
count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed); count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
@@ -748,7 +764,7 @@ mod tests {
resolve_coalesced( resolve_coalesced(
&inf2, &inf2,
("same.domain".to_string(), QueryType::AAAA), ("same.domain".to_string(), QueryType::AAAA),
201, &query_aaaa,
|| async { || async {
count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed); count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
@@ -758,8 +774,8 @@ mod tests {
.await .await
}); });
let (_, path1) = h1.await.unwrap(); let (_, path1, _) = h1.await.unwrap();
let (_, path2) = h2.await.unwrap(); let (_, path2, _) = h2.await.unwrap();
let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed); let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed);
assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual); assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual);
@@ -772,11 +788,12 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn inflight_map_cleaned_after_error() { async fn inflight_map_cleaned_after_error() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new()); let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(300, "will-fail.test", QueryType::A);
let (_, path) = resolve_coalesced( let (_, path, _) = resolve_coalesced(
&inflight, &inflight,
("will-fail.test".to_string(), QueryType::A), ("will-fail.test".to_string(), QueryType::A),
300, &query,
|| async { Err::<DnsPacket, _>("upstream timeout".into()) }, || async { Err::<DnsPacket, _>("upstream timeout".into()) },
) )
.await; .await;
@@ -792,11 +809,12 @@ mod tests {
let mut handles = Vec::new(); let mut handles = Vec::new();
for i in 0..3u16 { for i in 0..3u16 {
let inf = inflight.clone(); let inf = inflight.clone();
let query = mock_query(400 + i, "fail.test", QueryType::A);
handles.push(tokio::spawn(async move { handles.push(tokio::spawn(async move {
resolve_coalesced( resolve_coalesced(
&inf, &inf,
("fail.test".to_string(), QueryType::A), ("fail.test".to_string(), QueryType::A),
400 + i, &query,
|| async { || async {
tokio::time::sleep(Duration::from_millis(200)).await; tokio::time::sleep(Duration::from_millis(200)).await;
Err::<DnsPacket, _>("upstream error".into()) Err::<DnsPacket, _>("upstream error".into())
@@ -808,8 +826,14 @@ mod tests {
let mut paths = Vec::new(); let mut paths = Vec::new();
for h in handles { for h in handles {
let (resp, path) = h.await.unwrap(); let (resp, path, _) = h.await.unwrap();
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
assert_eq!(
resp.questions.len(),
1,
"SERVFAIL must echo question section"
);
assert_eq!(resp.questions[0].name, "fail.test");
paths.push(path); paths.push(path);
} }
@@ -821,4 +845,49 @@ mod tests {
assert!(inflight.lock().unwrap().is_empty()); assert!(inflight.lock().unwrap().is_empty());
} }
#[tokio::test]
async fn servfail_leader_includes_question_section() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(500, "question.test", QueryType::A);
let (resp, _, _) = resolve_coalesced(
&inflight,
("question.test".to_string(), QueryType::A),
&query,
|| async { Err::<DnsPacket, _>("fail".into()) },
)
.await;
assert_eq!(resp.header.rescode, ResultCode::SERVFAIL);
assert_eq!(
resp.questions.len(),
1,
"SERVFAIL must echo question section"
);
assert_eq!(resp.questions[0].name, "question.test");
assert_eq!(resp.questions[0].qtype, QueryType::A);
assert_eq!(resp.header.id, 500);
}
#[tokio::test]
async fn leader_error_preserves_message() {
let inflight: Mutex<InflightMap> = Mutex::new(HashMap::new());
let query = mock_query(700, "err-msg.test", QueryType::A);
let (_, path, err) = resolve_coalesced(
&inflight,
("err-msg.test".to_string(), QueryType::A),
&query,
|| async { Err::<DnsPacket, _>("connection refused by upstream".into()) },
)
.await;
assert_eq!(path, QueryPath::UpstreamError);
assert_eq!(
err.as_deref(),
Some("connection refused by upstream"),
"error message must be preserved for logging"
);
}
} }