From b8ddc16027453beec62888e9ee062b105f362543 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Mon, 13 Apr 2026 07:51:14 +0300 Subject: [PATCH] refactor: return QueryPath from resolve_query, add mock upstream to tests resolve_query now returns (BytePacketBuffer, QueryPath) so callers and tests can inspect the resolution path without reading the query log. Production call sites (UDP, DoT, DoH) destructure and ignore it. The forwarding test now uses a mock UDP upstream that replies with a canned response, asserting QueryPath::Forwarded instead of != Local. --- src/ctx.rs | 57 ++++++++++++++++++++++++++++++++---------------------- src/doh.rs | 2 +- src/dot.rs | 2 +- 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/src/ctx.rs b/src/ctx.rs index e440c2d..3f1370a 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -88,7 +88,7 @@ pub async fn resolve_query( src_addr: SocketAddr, ctx: &Arc, transport: Transport, -) -> crate::Result { +) -> crate::Result<(BytePacketBuffer, QueryPath)> { let start = Instant::now(); let (qname, qtype) = match query.questions.first() { @@ -377,7 +377,7 @@ pub async fn resolve_query( dnssec, }); - Ok(resp_buffer) + Ok((resp_buffer, path)) } fn cache_and_parse( @@ -461,7 +461,7 @@ pub async fn handle_query( } }; match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx, transport).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } Err(e) => { @@ -1048,7 +1048,7 @@ mod tests { test_ctx_with_forwarding(Vec::new()).await } - /// Helper: send a query through the full resolve_query pipeline and return + /// Send a query through the full resolve_query pipeline and return /// the parsed response + query path. async fn resolve_in_test( ctx: &Arc, @@ -1061,21 +1061,10 @@ mod tests { let raw = &buf.buf[..buf.pos]; let src: SocketAddr = "127.0.0.1:1234".parse().unwrap(); - let resp_buf = resolve_query(query, raw, src, ctx, Transport::Udp) + let (resp_buf, path) = resolve_query(query, raw, src, ctx, Transport::Udp) .await .unwrap(); - let log = ctx.query_log.lock().unwrap(); - let entry = log.query(&crate::query_log::QueryLogFilter { - domain: None, - query_type: None, - path: None, - since: None, - limit: Some(1), - }); - let path = entry.first().unwrap().path; - drop(log); - let mut resp_parse_buf = BytePacketBuffer::from_bytes(resp_buf.filled()); let resp = DnsPacket::from_buffer(&mut resp_parse_buf).unwrap(); (resp, path) @@ -1134,23 +1123,45 @@ mod tests { }) } + /// Spawn a UDP socket that replies to the first DNS query with the given + /// response packet (patching the query ID). Returns the socket address. + async fn mock_upstream(response: DnsPacket) -> SocketAddr { + let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = sock.local_addr().unwrap(); + tokio::spawn(async move { + let mut buf = [0u8; 512]; + let (_, src) = sock.recv_from(&mut buf).await.unwrap(); + let query_id = u16::from_be_bytes([buf[0], buf[1]]); + let mut resp = response; + resp.header.id = query_id; + let mut out = BytePacketBuffer::new(); + resp.write(&mut out).unwrap(); + sock.send_to(out.filled(), src).await.unwrap(); + }); + addr + } + #[tokio::test] async fn forwarding_rule_overrides_special_use_domain() { + let mut resp = DnsPacket::new(); + resp.header.response = true; + resp.header.rescode = ResultCode::NOERROR; + let upstream_addr = mock_upstream(resp).await; + let rules = vec![ForwardingRule::new( "168.192.in-addr.arpa".to_string(), - "192.168.88.1:53".parse().unwrap(), + upstream_addr, )]; let ctx = test_ctx_with_forwarding(rules).await; - let (_, path) = resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; + let (resp, path) = + resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; - // Should attempt forwarding, not return local NXDOMAIN. - // The forwarding will fail (no real upstream at 192.168.88.1), so we - // expect UpstreamError — but critically NOT QueryPath::Local. - assert_ne!( + assert_eq!( path, - QueryPath::Local, + QueryPath::Forwarded, "forwarding rule must take precedence over special-use NXDOMAIN" ); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); } } diff --git a/src/doh.rs b/src/doh.rs index f90b919..900edb4 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -113,7 +113,7 @@ async fn resolve_doh( let questions = query.questions.clone(); match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) } diff --git a/src/dot.rs b/src/dot.rs index e883e0b..db8257d 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -211,7 +211,7 @@ async fn handle_dot_connection( ) .await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { if write_framed(&mut stream, resp_buffer.filled()) .await .is_err()