From ab9579bec474d6f61a450214d9472b16af7f6385 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Sun, 29 Mar 2026 10:31:03 +0300 Subject: [PATCH] refactor: extract acquire_inflight, rewrite tests against real code Move Disposition enum and inflight acquisition logic into a standalone acquire_inflight() function. Rewrite 4 tests that were exercising tokio primitives to call the real coalescing code path instead. Co-Authored-By: Claude Opus 4.6 --- src/ctx.rs | 137 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 84 insertions(+), 53 deletions(-) diff --git a/src/ctx.rs b/src/ctx.rs index 2f8bda0..5aee946 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -178,22 +178,7 @@ pub async fn handle_query( (resp, QueryPath::Cached, cached_dnssec) } else if ctx.upstream_mode == UpstreamMode::Recursive { let key = (qname.clone(), qtype); - - enum Disposition { - Leader(broadcast::Sender>), - Follower(broadcast::Receiver>), - } - - let disposition = { - let mut inflight = ctx.inflight.lock().unwrap(); - if let Some(tx) = inflight.get(&key) { - Disposition::Follower(tx.subscribe()) - } else { - let (tx, _) = broadcast::channel::>(1); - inflight.insert(key.clone(), tx.clone()); - Disposition::Leader(tx) - } - }; + let disposition = acquire_inflight(&ctx.inflight, key.clone()); match disposition { Disposition::Follower(mut rx) => { @@ -431,6 +416,22 @@ fn is_special_use_domain(qname: &str) -> bool { qname == "local" || qname.ends_with(".local") } +enum Disposition { + Leader(broadcast::Sender>), + Follower(broadcast::Receiver>), +} + +fn acquire_inflight(inflight: &Mutex, key: (String, QueryType)) -> Disposition { + let mut map = inflight.lock().unwrap(); + if let Some(tx) = map.get(&key) { + Disposition::Follower(tx.subscribe()) + } else { + let (tx, _) = broadcast::channel::>(1); + map.insert(key, tx.clone()); + Disposition::Leader(tx) + } +} + struct InflightGuard<'a> { inflight: &'a Mutex, key: (String, QueryType), @@ -559,39 +560,46 @@ mod tests { assert!(m.contains_key(&key_aaaa)); } - // ---- Coalescing disposition tests ---- + // ---- Coalescing disposition tests (via acquire_inflight) ---- #[test] - fn leader_follower_disposition() { - // First caller becomes leader, second becomes follower + fn first_caller_becomes_leader() { let map: Mutex = Mutex::new(HashMap::new()); let key = ("test.com".to_string(), QueryType::A); - // First: no entry → insert and become leader - let is_leader = { - let mut m = map.lock().unwrap(); - if m.get(&key).is_some() { - false - } else { - let (tx, _) = broadcast::channel::>(1); - m.insert(key.clone(), tx); - true - } - }; - assert!(is_leader); + let d = acquire_inflight(&map, key.clone()); + assert!(matches!(d, Disposition::Leader(_))); + assert_eq!(map.lock().unwrap().len(), 1); + } - // Second: entry exists → become follower - let is_follower = { - let m = map.lock().unwrap(); - m.get(&key).is_some() - }; - assert!(is_follower); + #[test] + fn second_caller_becomes_follower() { + let map: Mutex = Mutex::new(HashMap::new()); + let key = ("test.com".to_string(), QueryType::A); + + let _leader = acquire_inflight(&map, key.clone()); + let follower = acquire_inflight(&map, key); + assert!(matches!(follower, Disposition::Follower(_))); + // Map still has exactly 1 entry — follower subscribes, doesn't insert + assert_eq!(map.lock().unwrap().len(), 1); } #[tokio::test] - async fn broadcast_delivers_result_to_follower() { - let (tx, _) = broadcast::channel::>(1); - let mut rx = tx.subscribe(); + async fn leader_broadcast_reaches_follower() { + let map: Mutex = Mutex::new(HashMap::new()); + let key = ("test.com".to_string(), QueryType::A); + + let leader = acquire_inflight(&map, key.clone()); + let follower = acquire_inflight(&map, key); + + let tx = match leader { + Disposition::Leader(tx) => tx, + _ => panic!("expected leader"), + }; + let mut rx = match follower { + Disposition::Follower(rx) => rx, + _ => panic!("expected follower"), + }; let mut resp = DnsPacket::new(); resp.header.id = 42; @@ -600,29 +608,48 @@ mod tests { addr: Ipv4Addr::new(1, 2, 3, 4), ttl: 300, }); - let _ = tx.send(Some(resp)); + let received = rx.recv().await.unwrap().unwrap(); assert_eq!(received.header.id, 42); assert_eq!(received.answers.len(), 1); } #[tokio::test] - async fn broadcast_none_signals_failure() { - let (tx, _) = broadcast::channel::>(1); - let mut rx = tx.subscribe(); - let _ = tx.send(None); + async fn leader_none_signals_failure_to_follower() { + let map: Mutex = Mutex::new(HashMap::new()); + let key = ("test.com".to_string(), QueryType::A); - let received = rx.recv().await.unwrap(); - assert!(received.is_none()); + let leader = acquire_inflight(&map, key.clone()); + let follower = acquire_inflight(&map, key); + + let tx = match leader { + Disposition::Leader(tx) => tx, + _ => panic!("expected leader"), + }; + let mut rx = match follower { + Disposition::Follower(rx) => rx, + _ => panic!("expected follower"), + }; + + let _ = tx.send(None); + assert!(rx.recv().await.unwrap().is_none()); } #[tokio::test] - async fn multiple_followers_all_receive_result() { - let (tx, _) = broadcast::channel::>(1); - let mut rx1 = tx.subscribe(); - let mut rx2 = tx.subscribe(); - let mut rx3 = tx.subscribe(); + async fn multiple_followers_all_receive_via_acquire() { + let map: Mutex = Mutex::new(HashMap::new()); + let key = ("multi.com".to_string(), QueryType::A); + + let leader = acquire_inflight(&map, key.clone()); + let f1 = acquire_inflight(&map, key.clone()); + let f2 = acquire_inflight(&map, key.clone()); + let f3 = acquire_inflight(&map, key); + + let tx = match leader { + Disposition::Leader(tx) => tx, + _ => panic!("expected leader"), + }; let mut resp = DnsPacket::new(); resp.answers.push(DnsRecord::A { @@ -632,7 +659,11 @@ mod tests { }); let _ = tx.send(Some(resp)); - for rx in [&mut rx1, &mut rx2, &mut rx3] { + for f in [f1, f2, f3] { + let mut rx = match f { + Disposition::Follower(rx) => rx, + _ => panic!("expected follower"), + }; let r = rx.recv().await.unwrap().unwrap(); assert_eq!(r.answers.len(), 1); }