refactor: extract acquire_inflight, rewrite tests against real code

Move Disposition enum and inflight acquisition logic into a standalone
acquire_inflight() function. Rewrite 4 tests that were exercising tokio
primitives to call the real coalescing code path instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-29 10:31:03 +03:00
parent b5ef76dd65
commit ab9579bec4

View File

@@ -178,22 +178,7 @@ pub async fn handle_query(
(resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive {
let key = (qname.clone(), qtype);
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
let disposition = {
let mut inflight = ctx.inflight.lock().unwrap();
if let Some(tx) = inflight.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
inflight.insert(key.clone(), tx.clone());
Disposition::Leader(tx)
}
};
let disposition = acquire_inflight(&ctx.inflight, key.clone());
match disposition {
Disposition::Follower(mut rx) => {
@@ -431,6 +416,22 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local")
}
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
let mut map = inflight.lock().unwrap();
if let Some(tx) = map.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
map.insert(key, tx.clone());
Disposition::Leader(tx)
}
}
struct InflightGuard<'a> {
inflight: &'a Mutex<InflightMap>,
key: (String, QueryType),
@@ -559,39 +560,46 @@ mod tests {
assert!(m.contains_key(&key_aaaa));
}
// ---- Coalescing disposition tests ----
// ---- Coalescing disposition tests (via acquire_inflight) ----
#[test]
fn leader_follower_disposition() {
// First caller becomes leader, second becomes follower
fn first_caller_becomes_leader() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("test.com".to_string(), QueryType::A);
// First: no entry → insert and become leader
let is_leader = {
let mut m = map.lock().unwrap();
if m.get(&key).is_some() {
false
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
m.insert(key.clone(), tx);
true
}
};
assert!(is_leader);
let d = acquire_inflight(&map, key.clone());
assert!(matches!(d, Disposition::Leader(_)));
assert_eq!(map.lock().unwrap().len(), 1);
}
// Second: entry exists → become follower
let is_follower = {
let m = map.lock().unwrap();
m.get(&key).is_some()
};
assert!(is_follower);
#[test]
fn second_caller_becomes_follower() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("test.com".to_string(), QueryType::A);
let _leader = acquire_inflight(&map, key.clone());
let follower = acquire_inflight(&map, key);
assert!(matches!(follower, Disposition::Follower(_)));
// Map still has exactly 1 entry — follower subscribes, doesn't insert
assert_eq!(map.lock().unwrap().len(), 1);
}
#[tokio::test]
async fn broadcast_delivers_result_to_follower() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
let mut rx = tx.subscribe();
async fn leader_broadcast_reaches_follower() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("test.com".to_string(), QueryType::A);
let leader = acquire_inflight(&map, key.clone());
let follower = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut rx = match follower {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let mut resp = DnsPacket::new();
resp.header.id = 42;
@@ -600,29 +608,48 @@ mod tests {
addr: Ipv4Addr::new(1, 2, 3, 4),
ttl: 300,
});
let _ = tx.send(Some(resp));
let received = rx.recv().await.unwrap().unwrap();
assert_eq!(received.header.id, 42);
assert_eq!(received.answers.len(), 1);
}
#[tokio::test]
async fn broadcast_none_signals_failure() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
let mut rx = tx.subscribe();
let _ = tx.send(None);
async fn leader_none_signals_failure_to_follower() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("test.com".to_string(), QueryType::A);
let received = rx.recv().await.unwrap();
assert!(received.is_none());
let leader = acquire_inflight(&map, key.clone());
let follower = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut rx = match follower {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let _ = tx.send(None);
assert!(rx.recv().await.unwrap().is_none());
}
#[tokio::test]
async fn multiple_followers_all_receive_result() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
let mut rx1 = tx.subscribe();
let mut rx2 = tx.subscribe();
let mut rx3 = tx.subscribe();
async fn multiple_followers_all_receive_via_acquire() {
let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("multi.com".to_string(), QueryType::A);
let leader = acquire_inflight(&map, key.clone());
let f1 = acquire_inflight(&map, key.clone());
let f2 = acquire_inflight(&map, key.clone());
let f3 = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut resp = DnsPacket::new();
resp.answers.push(DnsRecord::A {
@@ -632,7 +659,11 @@ mod tests {
});
let _ = tx.send(Some(resp));
for rx in [&mut rx1, &mut rx2, &mut rx3] {
for f in [f1, f2, f3] {
let mut rx = match f {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let r = rx.recv().await.unwrap().unwrap();
assert_eq!(r.answers.len(), 1);
}