diff --git a/src/ctx.rs b/src/ctx.rs index e440c2d..3f1370a 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -88,7 +88,7 @@ pub async fn resolve_query( src_addr: SocketAddr, ctx: &Arc, transport: Transport, -) -> crate::Result { +) -> crate::Result<(BytePacketBuffer, QueryPath)> { let start = Instant::now(); let (qname, qtype) = match query.questions.first() { @@ -377,7 +377,7 @@ pub async fn resolve_query( dnssec, }); - Ok(resp_buffer) + Ok((resp_buffer, path)) } fn cache_and_parse( @@ -461,7 +461,7 @@ pub async fn handle_query( } }; match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx, transport).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; } Err(e) => { @@ -1048,7 +1048,7 @@ mod tests { test_ctx_with_forwarding(Vec::new()).await } - /// Helper: send a query through the full resolve_query pipeline and return + /// Send a query through the full resolve_query pipeline and return /// the parsed response + query path. async fn resolve_in_test( ctx: &Arc, @@ -1061,21 +1061,10 @@ mod tests { let raw = &buf.buf[..buf.pos]; let src: SocketAddr = "127.0.0.1:1234".parse().unwrap(); - let resp_buf = resolve_query(query, raw, src, ctx, Transport::Udp) + let (resp_buf, path) = resolve_query(query, raw, src, ctx, Transport::Udp) .await .unwrap(); - let log = ctx.query_log.lock().unwrap(); - let entry = log.query(&crate::query_log::QueryLogFilter { - domain: None, - query_type: None, - path: None, - since: None, - limit: Some(1), - }); - let path = entry.first().unwrap().path; - drop(log); - let mut resp_parse_buf = BytePacketBuffer::from_bytes(resp_buf.filled()); let resp = DnsPacket::from_buffer(&mut resp_parse_buf).unwrap(); (resp, path) @@ -1134,23 +1123,45 @@ mod tests { }) } + /// Spawn a UDP socket that replies to the first DNS query with the given + /// response packet (patching the query ID). Returns the socket address. + async fn mock_upstream(response: DnsPacket) -> SocketAddr { + let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = sock.local_addr().unwrap(); + tokio::spawn(async move { + let mut buf = [0u8; 512]; + let (_, src) = sock.recv_from(&mut buf).await.unwrap(); + let query_id = u16::from_be_bytes([buf[0], buf[1]]); + let mut resp = response; + resp.header.id = query_id; + let mut out = BytePacketBuffer::new(); + resp.write(&mut out).unwrap(); + sock.send_to(out.filled(), src).await.unwrap(); + }); + addr + } + #[tokio::test] async fn forwarding_rule_overrides_special_use_domain() { + let mut resp = DnsPacket::new(); + resp.header.response = true; + resp.header.rescode = ResultCode::NOERROR; + let upstream_addr = mock_upstream(resp).await; + let rules = vec![ForwardingRule::new( "168.192.in-addr.arpa".to_string(), - "192.168.88.1:53".parse().unwrap(), + upstream_addr, )]; let ctx = test_ctx_with_forwarding(rules).await; - let (_, path) = resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; + let (resp, path) = + resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; - // Should attempt forwarding, not return local NXDOMAIN. - // The forwarding will fail (no real upstream at 192.168.88.1), so we - // expect UpstreamError — but critically NOT QueryPath::Local. - assert_ne!( + assert_eq!( path, - QueryPath::Local, + QueryPath::Forwarded, "forwarding rule must take precedence over special-use NXDOMAIN" ); + assert_eq!(resp.header.rescode, ResultCode::NOERROR); } } diff --git a/src/doh.rs b/src/doh.rs index f90b919..900edb4 100644 --- a/src/doh.rs +++ b/src/doh.rs @@ -113,7 +113,7 @@ async fn resolve_doh( let questions = query.questions.clone(); match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { let min_ttl = extract_min_ttl(resp_buffer.filled()); dns_response(resp_buffer.filled(), min_ttl) } diff --git a/src/dot.rs b/src/dot.rs index e883e0b..db8257d 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -211,7 +211,7 @@ async fn handle_dot_connection( ) .await { - Ok(resp_buffer) => { + Ok((resp_buffer, _)) => { if write_framed(&mut stream, resp_buffer.filled()) .await .is_err()