refactor: extract acquire_inflight, rewrite tests against real code

Move Disposition enum and inflight acquisition logic into a standalone
acquire_inflight() function. Rewrite 4 tests that were exercising tokio
primitives to call the real coalescing code path instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-29 10:31:03 +03:00
parent b5ef76dd65
commit ab9579bec4

View File

@@ -178,22 +178,7 @@ pub async fn handle_query(
(resp, QueryPath::Cached, cached_dnssec) (resp, QueryPath::Cached, cached_dnssec)
} else if ctx.upstream_mode == UpstreamMode::Recursive { } else if ctx.upstream_mode == UpstreamMode::Recursive {
let key = (qname.clone(), qtype); let key = (qname.clone(), qtype);
let disposition = acquire_inflight(&ctx.inflight, key.clone());
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
let disposition = {
let mut inflight = ctx.inflight.lock().unwrap();
if let Some(tx) = inflight.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
inflight.insert(key.clone(), tx.clone());
Disposition::Leader(tx)
}
};
match disposition { match disposition {
Disposition::Follower(mut rx) => { Disposition::Follower(mut rx) => {
@@ -431,6 +416,22 @@ fn is_special_use_domain(qname: &str) -> bool {
qname == "local" || qname.ends_with(".local") qname == "local" || qname.ends_with(".local")
} }
enum Disposition {
Leader(broadcast::Sender<Option<DnsPacket>>),
Follower(broadcast::Receiver<Option<DnsPacket>>),
}
fn acquire_inflight(inflight: &Mutex<InflightMap>, key: (String, QueryType)) -> Disposition {
let mut map = inflight.lock().unwrap();
if let Some(tx) = map.get(&key) {
Disposition::Follower(tx.subscribe())
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
map.insert(key, tx.clone());
Disposition::Leader(tx)
}
}
struct InflightGuard<'a> { struct InflightGuard<'a> {
inflight: &'a Mutex<InflightMap>, inflight: &'a Mutex<InflightMap>,
key: (String, QueryType), key: (String, QueryType),
@@ -559,39 +560,46 @@ mod tests {
assert!(m.contains_key(&key_aaaa)); assert!(m.contains_key(&key_aaaa));
} }
// ---- Coalescing disposition tests ---- // ---- Coalescing disposition tests (via acquire_inflight) ----
#[test] #[test]
fn leader_follower_disposition() { fn first_caller_becomes_leader() {
// First caller becomes leader, second becomes follower
let map: Mutex<InflightMap> = Mutex::new(HashMap::new()); let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let key = ("test.com".to_string(), QueryType::A); let key = ("test.com".to_string(), QueryType::A);
// First: no entry → insert and become leader let d = acquire_inflight(&map, key.clone());
let is_leader = { assert!(matches!(d, Disposition::Leader(_)));
let mut m = map.lock().unwrap(); assert_eq!(map.lock().unwrap().len(), 1);
if m.get(&key).is_some() { }
false
} else {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1);
m.insert(key.clone(), tx);
true
}
};
assert!(is_leader);
// Second: entry exists → become follower #[test]
let is_follower = { fn second_caller_becomes_follower() {
let m = map.lock().unwrap(); let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
m.get(&key).is_some() let key = ("test.com".to_string(), QueryType::A);
};
assert!(is_follower); let _leader = acquire_inflight(&map, key.clone());
let follower = acquire_inflight(&map, key);
assert!(matches!(follower, Disposition::Follower(_)));
// Map still has exactly 1 entry — follower subscribes, doesn't insert
assert_eq!(map.lock().unwrap().len(), 1);
} }
#[tokio::test] #[tokio::test]
async fn broadcast_delivers_result_to_follower() { async fn leader_broadcast_reaches_follower() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1); let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let mut rx = tx.subscribe(); let key = ("test.com".to_string(), QueryType::A);
let leader = acquire_inflight(&map, key.clone());
let follower = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut rx = match follower {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.header.id = 42; resp.header.id = 42;
@@ -600,29 +608,48 @@ mod tests {
addr: Ipv4Addr::new(1, 2, 3, 4), addr: Ipv4Addr::new(1, 2, 3, 4),
ttl: 300, ttl: 300,
}); });
let _ = tx.send(Some(resp)); let _ = tx.send(Some(resp));
let received = rx.recv().await.unwrap().unwrap(); let received = rx.recv().await.unwrap().unwrap();
assert_eq!(received.header.id, 42); assert_eq!(received.header.id, 42);
assert_eq!(received.answers.len(), 1); assert_eq!(received.answers.len(), 1);
} }
#[tokio::test] #[tokio::test]
async fn broadcast_none_signals_failure() { async fn leader_none_signals_failure_to_follower() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1); let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let mut rx = tx.subscribe(); let key = ("test.com".to_string(), QueryType::A);
let _ = tx.send(None);
let received = rx.recv().await.unwrap(); let leader = acquire_inflight(&map, key.clone());
assert!(received.is_none()); let follower = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut rx = match follower {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let _ = tx.send(None);
assert!(rx.recv().await.unwrap().is_none());
} }
#[tokio::test] #[tokio::test]
async fn multiple_followers_all_receive_result() { async fn multiple_followers_all_receive_via_acquire() {
let (tx, _) = broadcast::channel::<Option<DnsPacket>>(1); let map: Mutex<InflightMap> = Mutex::new(HashMap::new());
let mut rx1 = tx.subscribe(); let key = ("multi.com".to_string(), QueryType::A);
let mut rx2 = tx.subscribe();
let mut rx3 = tx.subscribe(); let leader = acquire_inflight(&map, key.clone());
let f1 = acquire_inflight(&map, key.clone());
let f2 = acquire_inflight(&map, key.clone());
let f3 = acquire_inflight(&map, key);
let tx = match leader {
Disposition::Leader(tx) => tx,
_ => panic!("expected leader"),
};
let mut resp = DnsPacket::new(); let mut resp = DnsPacket::new();
resp.answers.push(DnsRecord::A { resp.answers.push(DnsRecord::A {
@@ -632,7 +659,11 @@ mod tests {
}); });
let _ = tx.send(Some(resp)); let _ = tx.send(Some(resp));
for rx in [&mut rx1, &mut rx2, &mut rx3] { for f in [f1, f2, f3] {
let mut rx = match f {
Disposition::Follower(rx) => rx,
_ => panic!("expected follower"),
};
let r = rx.recv().await.unwrap().unwrap(); let r = rx.recv().await.unwrap().unwrap();
assert_eq!(r.answers.len(), 1); assert_eq!(r.answers.len(), 1);
} }