diff --git a/src/ctx.rs b/src/ctx.rs index 4c28c98..fbddb15 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -178,7 +178,7 @@ pub async fn handle_query( (resp, QueryPath::Cached, cached_dnssec) } else if ctx.upstream_mode == UpstreamMode::Recursive { let key = (qname.clone(), qtype); - let (resp, path) = resolve_coalesced(&ctx.inflight, key, query.header.id, || { + let (resp, path, err) = resolve_coalesced(&ctx.inflight, key, &query, || { crate::recursive::resolve_recursive( &qname, qtype, @@ -192,7 +192,13 @@ pub async fn handle_query( if path == QueryPath::Coalesced { debug!("{} | {:?} {} | COALESCED", src_addr, qtype, qname); } else if path == QueryPath::UpstreamError { - error!("{} | {:?} {} | RECURSIVE ERROR", src_addr, qtype, qname); + error!( + "{} | {:?} {} | RECURSIVE ERROR | {}", + src_addr, + qtype, + qname, + err.as_deref().unwrap_or("leader failed") + ); } (resp, path, DnssecStatus::Indeterminate) } else { @@ -399,9 +405,9 @@ fn acquire_inflight(inflight: &Mutex, key: (String, QueryType)) -> async fn resolve_coalesced( inflight: &Mutex, key: (String, QueryType), - query_id: u16, + query: &DnsPacket, resolve_fn: F, -) -> (DnsPacket, QueryPath) +) -> (DnsPacket, QueryPath, Option) where F: FnOnce() -> Fut, Fut: std::future::Future>, @@ -411,16 +417,14 @@ where match disposition { Disposition::Follower(mut rx) => match rx.recv().await { Ok(Some(mut resp)) => { - resp.header.id = query_id; - (resp, QueryPath::Coalesced) - } - _ => { - let mut resp = DnsPacket::new(); - resp.header.id = query_id; - resp.header.response = true; - resp.header.rescode = ResultCode::SERVFAIL; - (resp, QueryPath::UpstreamError) + resp.header.id = query.header.id; + (resp, QueryPath::Coalesced, None) } + _ => ( + DnsPacket::response_from(query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + None, + ), }, Disposition::Leader(tx) => { let guard = InflightGuard { inflight, key }; @@ -430,15 +434,16 @@ where match result { Ok(resp) => { let _ = tx.send(Some(resp.clone())); - (resp, QueryPath::Recursive) + (resp, QueryPath::Recursive, None) } - Err(_) => { + Err(e) => { let _ = tx.send(None); - let mut resp = DnsPacket::new(); - resp.header.id = query_id; - resp.header.response = true; - resp.header.rescode = ResultCode::SERVFAIL; - (resp, QueryPath::UpstreamError) + let err_msg = e.to_string(); + ( + DnsPacket::response_from(query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + Some(err_msg), + ) } } } @@ -670,7 +675,14 @@ mod tests { // ---- Integration: resolve_coalesced with mock futures ---- - // ---- Integration: resolve_coalesced with mock futures ---- + fn mock_query(id: u16, domain: &str, qtype: QueryType) -> DnsPacket { + let mut pkt = DnsPacket::new(); + pkt.header.id = id; + pkt.header.recursion_desired = true; + pkt.questions + .push(crate::question::DnsQuestion::new(domain.to_string(), qtype)); + pkt + } fn mock_response(domain: &str) -> DnsPacket { let mut resp = DnsPacket::new(); @@ -694,8 +706,9 @@ mod tests { let count = resolve_count.clone(); let inf = inflight.clone(); let key = ("coalesce.test".to_string(), QueryType::A); + let query = mock_query(100 + i, "coalesce.test", QueryType::A); handles.push(tokio::spawn(async move { - resolve_coalesced(&inf, key, 100 + i, || async { + resolve_coalesced(&inf, key, &query, || async { count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tokio::time::sleep(Duration::from_millis(200)).await; Ok(mock_response("coalesce.test")) @@ -706,7 +719,7 @@ mod tests { let mut paths = Vec::new(); for h in handles { - let (_, path) = h.await.unwrap(); + let (_, path, _) = h.await.unwrap(); paths.push(path); } @@ -731,11 +744,14 @@ mod tests { let count1 = resolve_count.clone(); let count2 = resolve_count.clone(); + let query_a = mock_query(200, "same.domain", QueryType::A); + let query_aaaa = mock_query(201, "same.domain", QueryType::AAAA); + let h1 = tokio::spawn(async move { resolve_coalesced( &inf1, ("same.domain".to_string(), QueryType::A), - 200, + &query_a, || async { count1.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tokio::time::sleep(Duration::from_millis(100)).await; @@ -748,7 +764,7 @@ mod tests { resolve_coalesced( &inf2, ("same.domain".to_string(), QueryType::AAAA), - 201, + &query_aaaa, || async { count2.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tokio::time::sleep(Duration::from_millis(100)).await; @@ -758,8 +774,8 @@ mod tests { .await }); - let (_, path1) = h1.await.unwrap(); - let (_, path2) = h2.await.unwrap(); + let (_, path1, _) = h1.await.unwrap(); + let (_, path2, _) = h2.await.unwrap(); let actual = resolve_count.load(std::sync::atomic::Ordering::Relaxed); assert_eq!(actual, 2, "A and AAAA should each resolve, got {}", actual); @@ -772,11 +788,12 @@ mod tests { #[tokio::test] async fn inflight_map_cleaned_after_error() { let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(300, "will-fail.test", QueryType::A); - let (_, path) = resolve_coalesced( + let (_, path, _) = resolve_coalesced( &inflight, ("will-fail.test".to_string(), QueryType::A), - 300, + &query, || async { Err::("upstream timeout".into()) }, ) .await; @@ -792,11 +809,12 @@ mod tests { let mut handles = Vec::new(); for i in 0..3u16 { let inf = inflight.clone(); + let query = mock_query(400 + i, "fail.test", QueryType::A); handles.push(tokio::spawn(async move { resolve_coalesced( &inf, ("fail.test".to_string(), QueryType::A), - 400 + i, + &query, || async { tokio::time::sleep(Duration::from_millis(200)).await; Err::("upstream error".into()) @@ -808,8 +826,14 @@ mod tests { let mut paths = Vec::new(); for h in handles { - let (resp, path) = h.await.unwrap(); + let (resp, path, _) = h.await.unwrap(); assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); + assert_eq!( + resp.questions.len(), + 1, + "SERVFAIL must echo question section" + ); + assert_eq!(resp.questions[0].name, "fail.test"); paths.push(path); } @@ -821,4 +845,49 @@ mod tests { assert!(inflight.lock().unwrap().is_empty()); } + + #[tokio::test] + async fn servfail_leader_includes_question_section() { + let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(500, "question.test", QueryType::A); + + let (resp, _, _) = resolve_coalesced( + &inflight, + ("question.test".to_string(), QueryType::A), + &query, + || async { Err::("fail".into()) }, + ) + .await; + + assert_eq!(resp.header.rescode, ResultCode::SERVFAIL); + assert_eq!( + resp.questions.len(), + 1, + "SERVFAIL must echo question section" + ); + assert_eq!(resp.questions[0].name, "question.test"); + assert_eq!(resp.questions[0].qtype, QueryType::A); + assert_eq!(resp.header.id, 500); + } + + #[tokio::test] + async fn leader_error_preserves_message() { + let inflight: Mutex = Mutex::new(HashMap::new()); + let query = mock_query(700, "err-msg.test", QueryType::A); + + let (_, path, err) = resolve_coalesced( + &inflight, + ("err-msg.test".to_string(), QueryType::A), + &query, + || async { Err::("connection refused by upstream".into()) }, + ) + .await; + + assert_eq!(path, QueryPath::UpstreamError); + assert_eq!( + err.as_deref(), + Some("connection refused by upstream"), + "error message must be preserved for logging" + ); + } }