fix: forwarding rules override special-use NXDOMAIN #95

Merged
razvandimescu merged 7 commits from fix/forwarding-precedes-special-use into main 2026-04-13 14:37:19 +08:00
3 changed files with 36 additions and 25 deletions
Showing only changes of commit b8ddc16027 - Show all commits

View File

@@ -88,7 +88,7 @@ pub async fn resolve_query(
src_addr: SocketAddr, src_addr: SocketAddr,
ctx: &Arc<ServerCtx>, ctx: &Arc<ServerCtx>,
transport: Transport, transport: Transport,
) -> crate::Result<BytePacketBuffer> { ) -> crate::Result<(BytePacketBuffer, QueryPath)> {
let start = Instant::now(); let start = Instant::now();
let (qname, qtype) = match query.questions.first() { let (qname, qtype) = match query.questions.first() {
@@ -377,7 +377,7 @@ pub async fn resolve_query(
dnssec, dnssec,
}); });
Ok(resp_buffer) Ok((resp_buffer, path))
} }
fn cache_and_parse( fn cache_and_parse(
@@ -461,7 +461,7 @@ pub async fn handle_query(
} }
}; };
match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx, transport).await { match resolve_query(query, &buffer.buf[..raw_len], src_addr, ctx, transport).await {
Ok(resp_buffer) => { Ok((resp_buffer, _)) => {
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
} }
Err(e) => { Err(e) => {
@@ -1048,7 +1048,7 @@ mod tests {
test_ctx_with_forwarding(Vec::new()).await test_ctx_with_forwarding(Vec::new()).await
} }
/// Helper: send a query through the full resolve_query pipeline and return /// Send a query through the full resolve_query pipeline and return
/// the parsed response + query path. /// the parsed response + query path.
async fn resolve_in_test( async fn resolve_in_test(
ctx: &Arc<ServerCtx>, ctx: &Arc<ServerCtx>,
@@ -1061,21 +1061,10 @@ mod tests {
let raw = &buf.buf[..buf.pos]; let raw = &buf.buf[..buf.pos];
let src: SocketAddr = "127.0.0.1:1234".parse().unwrap(); let src: SocketAddr = "127.0.0.1:1234".parse().unwrap();
let resp_buf = resolve_query(query, raw, src, ctx, Transport::Udp) let (resp_buf, path) = resolve_query(query, raw, src, ctx, Transport::Udp)
.await .await
.unwrap(); .unwrap();
let log = ctx.query_log.lock().unwrap();
let entry = log.query(&crate::query_log::QueryLogFilter {
domain: None,
query_type: None,
path: None,
since: None,
limit: Some(1),
});
let path = entry.first().unwrap().path;
drop(log);
let mut resp_parse_buf = BytePacketBuffer::from_bytes(resp_buf.filled()); let mut resp_parse_buf = BytePacketBuffer::from_bytes(resp_buf.filled());
let resp = DnsPacket::from_buffer(&mut resp_parse_buf).unwrap(); let resp = DnsPacket::from_buffer(&mut resp_parse_buf).unwrap();
(resp, path) (resp, path)
@@ -1134,23 +1123,45 @@ mod tests {
}) })
} }
/// Spawn a UDP socket that replies to the first DNS query with the given
/// response packet (patching the query ID). Returns the socket address.
async fn mock_upstream(response: DnsPacket) -> SocketAddr {
let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let addr = sock.local_addr().unwrap();
tokio::spawn(async move {
let mut buf = [0u8; 512];
let (_, src) = sock.recv_from(&mut buf).await.unwrap();
let query_id = u16::from_be_bytes([buf[0], buf[1]]);
let mut resp = response;
resp.header.id = query_id;
let mut out = BytePacketBuffer::new();
resp.write(&mut out).unwrap();
sock.send_to(out.filled(), src).await.unwrap();
});
addr
}
#[tokio::test] #[tokio::test]
async fn forwarding_rule_overrides_special_use_domain() { async fn forwarding_rule_overrides_special_use_domain() {
let mut resp = DnsPacket::new();
resp.header.response = true;
resp.header.rescode = ResultCode::NOERROR;
let upstream_addr = mock_upstream(resp).await;
let rules = vec![ForwardingRule::new( let rules = vec![ForwardingRule::new(
"168.192.in-addr.arpa".to_string(), "168.192.in-addr.arpa".to_string(),
"192.168.88.1:53".parse().unwrap(), upstream_addr,
)]; )];
let ctx = test_ctx_with_forwarding(rules).await; let ctx = test_ctx_with_forwarding(rules).await;
let (_, path) = resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await; let (resp, path) =
resolve_in_test(&ctx, "153.188.168.192.in-addr.arpa", QueryType::PTR).await;
// Should attempt forwarding, not return local NXDOMAIN. assert_eq!(
// The forwarding will fail (no real upstream at 192.168.88.1), so we
// expect UpstreamError — but critically NOT QueryPath::Local.
assert_ne!(
path, path,
QueryPath::Local, QueryPath::Forwarded,
"forwarding rule must take precedence over special-use NXDOMAIN" "forwarding rule must take precedence over special-use NXDOMAIN"
); );
assert_eq!(resp.header.rescode, ResultCode::NOERROR);
} }
} }

View File

@@ -113,7 +113,7 @@ async fn resolve_doh(
let questions = query.questions.clone(); let questions = query.questions.clone();
match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await { match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await {
Ok(resp_buffer) => { Ok((resp_buffer, _)) => {
let min_ttl = extract_min_ttl(resp_buffer.filled()); let min_ttl = extract_min_ttl(resp_buffer.filled());
dns_response(resp_buffer.filled(), min_ttl) dns_response(resp_buffer.filled(), min_ttl)
} }

View File

@@ -211,7 +211,7 @@ async fn handle_dot_connection<S>(
) )
.await .await
{ {
Ok(resp_buffer) => { Ok((resp_buffer, _)) => {
if write_framed(&mut stream, resp_buffer.filled()) if write_framed(&mut stream, resp_buffer.filled())
.await .await
.is_err() .is_err()