perf: optimize DNS query hot path (#15)

* perf: optimize hot path — RwLock, inline filtering, pre-allocated strings

- Mutex → RwLock for cache, blocklist, and overrides (concurrent read access)
- Make cache.lookup() and overrides.lookup() take &self (read-only)
- Eliminate 3 Vec allocations per DnsPacket::write() via inline filtering
- Pre-allocate domain strings with capacity 64 in parse path
- Add criterion micro-benchmarks (hot_path + throughput)
- Add bench README documenting both benchmark suites

Measured improvement: ~14% faster parsing, ~9% higher pipeline throughput,
round-trip cached 733ns → 698ns (~2.3M queries/sec).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore: simplify benchmark code after review

- Remove redundant DnsHeader::new() (already set by DnsPacket::new())
- Remove unused DnsHeader import
- Change simulate_cached_pipeline to take &DnsCache (lookup is &self now)
- Remove unnecessary mut on cache in cache_lookup_miss bench

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-27 02:01:08 +02:00
committed by GitHub
parent 1f4063d5db
commit 962b400f4c
13 changed files with 728 additions and 77 deletions

View File

@@ -220,7 +220,7 @@ async fn create_overrides(
})
.collect::<Result<Vec<_>, (StatusCode, String)>>()?;
let mut store = ctx.overrides.lock().unwrap();
let mut store = ctx.overrides.write().unwrap();
let mut responses = Vec::with_capacity(parsed.len());
for (domain, target, ttl, duration_secs) in parsed {
@@ -241,7 +241,7 @@ async fn create_overrides(
}
async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> {
let store = ctx.overrides.lock().unwrap();
let store = ctx.overrides.read().unwrap();
let entries: Vec<OverrideResponse> = store
.list()
.into_iter()
@@ -254,7 +254,7 @@ async fn get_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> Result<Json<OverrideResponse>, StatusCode> {
let store = ctx.overrides.lock().unwrap();
let store = ctx.overrides.read().unwrap();
let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(OverrideResponse::from(entry)))
}
@@ -263,7 +263,7 @@ async fn remove_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
let mut store = ctx.overrides.lock().unwrap();
let mut store = ctx.overrides.write().unwrap();
if store.remove(&domain) {
StatusCode::NO_CONTENT
} else {
@@ -272,7 +272,7 @@ async fn remove_override(
}
async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.overrides.lock().unwrap().clear();
ctx.overrides.write().unwrap().clear();
StatusCode::NO_CONTENT
}
@@ -280,7 +280,7 @@ async fn load_environment(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<EnvironmentRequest>,
) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> {
let mut store = ctx.overrides.lock().unwrap();
let mut store = ctx.overrides.write().unwrap();
for entry in &req.overrides {
let duration = entry.duration_secs.or(req.duration_secs);
@@ -307,7 +307,7 @@ async fn diagnose(
// Check overrides
{
let store = ctx.overrides.lock().unwrap();
let store = ctx.overrides.read().unwrap();
let entry = store.get(&domain_lower);
steps.push(DiagnoseStep {
source: "override".to_string(),
@@ -319,7 +319,7 @@ async fn diagnose(
// Check blocklist
{
let bl = ctx.blocklist.lock().unwrap();
let bl = ctx.blocklist.read().unwrap();
let blocked = bl.is_blocked(&domain_lower);
steps.push(DiagnoseStep {
source: "blocklist".to_string(),
@@ -345,7 +345,7 @@ async fn diagnose(
// Check cache
{
let mut cache = ctx.cache.lock().unwrap();
let cache = ctx.cache.read().unwrap();
let cached = cache.lookup(&domain_lower, qtype);
steps.push(DiagnoseStep {
source: "cache".to_string(),
@@ -443,11 +443,11 @@ async fn query_log(
async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
let snap = ctx.stats.lock().unwrap().snapshot();
let (cache_len, cache_max) = {
let cache = ctx.cache.lock().unwrap();
let cache = ctx.cache.read().unwrap();
(cache.len(), cache.max_entries())
};
let override_count = ctx.overrides.lock().unwrap().active_count();
let bl_stats = ctx.blocklist.lock().unwrap().stats();
let override_count = ctx.overrides.read().unwrap().active_count();
let bl_stats = ctx.blocklist.read().unwrap().stats();
let upstream = ctx.upstream.lock().unwrap().to_string();
@@ -486,7 +486,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
}
async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> {
let cache = ctx.cache.lock().unwrap();
let cache = ctx.cache.read().unwrap();
let entries: Vec<CacheEntryResponse> = cache
.list()
.into_iter()
@@ -500,7 +500,7 @@ async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryRes
}
async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.cache.lock().unwrap().clear();
ctx.cache.write().unwrap().clear();
StatusCode::NO_CONTENT
}
@@ -508,7 +508,7 @@ async fn flush_cache_domain(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
ctx.cache.lock().unwrap().remove(&domain);
ctx.cache.write().unwrap().remove(&domain);
StatusCode::NO_CONTENT
}
@@ -519,7 +519,7 @@ async fn health() -> Json<serde_json::Value> {
// --- Blocking handlers ---
async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
let stats = ctx.blocklist.lock().unwrap().stats();
let stats = ctx.blocklist.read().unwrap().stats();
Json(serde_json::json!({
"enabled": stats.enabled,
"paused": stats.paused,
@@ -539,7 +539,7 @@ async fn blocking_toggle(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingToggleRequest>,
) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().set_enabled(req.enabled);
ctx.blocklist.write().unwrap().set_enabled(req.enabled);
Json(serde_json::json!({ "enabled": req.enabled }))
}
@@ -557,12 +557,12 @@ async fn blocking_pause(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingPauseRequest>,
) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().pause(req.minutes * 60);
ctx.blocklist.write().unwrap().pause(req.minutes * 60);
Json(serde_json::json!({ "paused_minutes": req.minutes }))
}
async fn blocking_unpause(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().unpause();
ctx.blocklist.write().unwrap().unpause();
Json(serde_json::json!({ "paused": false }))
}
@@ -570,12 +570,12 @@ async fn blocking_check(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> Json<crate::blocklist::BlockCheckResult> {
let result = ctx.blocklist.lock().unwrap().check(&domain);
let result = ctx.blocklist.read().unwrap().check(&domain);
Json(result)
}
async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> {
let list = ctx.blocklist.lock().unwrap().allowlist();
let list = ctx.blocklist.read().unwrap().allowlist();
Json(list)
}
@@ -588,7 +588,7 @@ async fn blocking_allowlist_add(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<AllowlistRequest>,
) -> (StatusCode, Json<serde_json::Value>) {
ctx.blocklist.lock().unwrap().add_to_allowlist(&req.domain);
ctx.blocklist.write().unwrap().add_to_allowlist(&req.domain);
(
StatusCode::CREATED,
Json(serde_json::json!({ "allowed": req.domain })),
@@ -599,7 +599,12 @@ async fn blocking_allowlist_remove(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
if ctx.blocklist.lock().unwrap().remove_from_allowlist(&domain) {
if ctx
.blocklist
.write()
.unwrap()
.remove_from_allowlist(&domain)
{
StatusCode::NO_CONTENT
} else {
StatusCode::NOT_FOUND

View File

@@ -19,7 +19,6 @@ pub struct DnsCache {
max_entries: usize,
min_ttl: u32,
max_ttl: u32,
query_count: u64,
}
impl DnsCache {
@@ -30,29 +29,16 @@ impl DnsCache {
max_entries,
min_ttl,
max_ttl,
query_count: 0,
}
}
pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
self.query_count += 1;
if self.query_count.is_multiple_of(1000) {
self.evict_expired();
}
/// Read-only lookup — expired entries are left in place (cleaned up on insert).
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?;
let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl {
// Expired: remove this entry
let type_map = self.entries.get_mut(domain).unwrap();
type_map.remove(&qtype);
self.entry_count -= 1;
if type_map.is_empty() {
self.entries.remove(domain);
}
return None;
}

View File

@@ -1,6 +1,6 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Mutex;
use std::sync::{Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime};
use arc_swap::ArcSwap;
@@ -27,10 +27,10 @@ use crate::system_dns::ForwardingRule;
pub struct ServerCtx {
pub socket: UdpSocket,
pub zone_map: ZoneMap,
pub cache: Mutex<DnsCache>,
pub cache: RwLock<DnsCache>,
pub stats: Mutex<ServerStats>,
pub overrides: Mutex<OverrideStore>,
pub blocklist: Mutex<BlocklistStore>,
pub overrides: RwLock<OverrideStore>,
pub blocklist: RwLock<BlocklistStore>,
pub query_log: Mutex<QueryLog>,
pub services: Mutex<ServiceStore>,
pub lan_peers: Mutex<PeerStore>,
@@ -73,7 +73,7 @@ pub async fn handle_query(
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
// Each lock is scoped to avoid holding MutexGuard across await points.
let (response, path) = {
let override_record = ctx.overrides.lock().unwrap().lookup(&qname);
let override_record = ctx.overrides.read().unwrap().lookup(&qname);
if let Some(record) = override_record {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers.push(record);
@@ -116,7 +116,7 @@ pub async fn handle_query(
}),
}
(resp, QueryPath::Local)
} else if ctx.blocklist.lock().unwrap().is_blocked(&qname) {
} else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
match qtype {
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
@@ -136,7 +136,7 @@ pub async fn handle_query(
resp.answers = records.clone();
(resp, QueryPath::Local)
} else {
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype);
let cached = ctx.cache.read().unwrap().lookup(&qname, qtype);
if let Some(cached) = cached {
let mut resp = cached;
resp.header.id = query.header.id;
@@ -149,7 +149,7 @@ pub async fn handle_query(
};
match forward_query(&query, &upstream, ctx.timeout).await {
Ok(resp) => {
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp);
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded)
}
Err(e) => {

View File

@@ -1,5 +1,5 @@
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
use arc_swap::ArcSwap;
@@ -170,14 +170,14 @@ async fn main() -> numa::Result<()> {
let ctx = Arc::new(ServerCtx {
socket: UdpSocket::bind(&config.server.bind_addr).await?,
zone_map: build_zone_map(&config.zones)?,
cache: Mutex::new(DnsCache::new(
cache: RwLock::new(DnsCache::new(
config.cache.max_entries,
config.cache.min_ttl,
config.cache.max_ttl,
)),
stats: Mutex::new(ServerStats::new()),
overrides: Mutex::new(OverrideStore::new()),
blocklist: Mutex::new(blocklist),
overrides: RwLock::new(OverrideStore::new()),
blocklist: RwLock::new(blocklist),
query_log: Mutex::new(QueryLog::new(1000)),
services: Mutex::new(service_store),
lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)),
@@ -541,7 +541,7 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
// Swap under lock — sub-microsecond
ctx.blocklist
.lock()
.write()
.unwrap()
.swap_domains(all_domains, sources);
info!(

View File

@@ -64,6 +64,9 @@ impl OverrideStore {
ttl: u32,
duration_secs: Option<u64>,
) -> Result<QueryType> {
// Clean up expired entries on write
self.entries.retain(|_, e| !e.is_expired());
let domain_lower = domain.to_lowercase();
let (qtype, record) = parse_target(&domain_lower, target, ttl)?;
@@ -84,10 +87,10 @@ impl OverrideStore {
}
/// Hot path: assumes `domain` is already lowercased (the parser does this).
pub fn lookup(&mut self, domain: &str) -> Option<DnsRecord> {
/// Read-only — expired entries are left in place (cleaned up on write operations).
pub fn lookup(&self, domain: &str) -> Option<DnsRecord> {
let entry = self.entries.get(domain)?;
if entry.is_expired() {
self.entries.remove(domain);
return None;
}
Some(entry.record.clone())

View File

@@ -46,7 +46,7 @@ impl DnsPacket {
result.header.read(buffer)?;
for _ in 0..result.header.questions {
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
let mut question = DnsQuestion::new(String::with_capacity(64), QueryType::UNKNOWN(0));
question.read(buffer)?;
result.questions.push(question);
}
@@ -68,34 +68,36 @@ impl DnsPacket {
}
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> {
// Filter out UNKNOWN records (e.g. EDNS OPT) that we can't re-serialize
let answers: Vec<_> = self.answers.iter().filter(|r| !r.is_unknown()).collect();
let authorities: Vec<_> = self
.authorities
.iter()
.filter(|r| !r.is_unknown())
.collect();
let resources: Vec<_> = self.resources.iter().filter(|r| !r.is_unknown()).collect();
// Count known records without allocating filter Vecs
let answer_count = self.answers.iter().filter(|r| !r.is_unknown()).count() as u16;
let auth_count = self.authorities.iter().filter(|r| !r.is_unknown()).count() as u16;
let res_count = self.resources.iter().filter(|r| !r.is_unknown()).count() as u16;
let mut header = self.header.clone();
header.questions = self.questions.len() as u16;
header.answers = answers.len() as u16;
header.authoritative_entries = authorities.len() as u16;
header.resource_entries = resources.len() as u16;
header.answers = answer_count;
header.authoritative_entries = auth_count;
header.resource_entries = res_count;
header.write(buffer)?;
for question in &self.questions {
question.write(buffer)?;
}
for rec in answers {
rec.write(buffer)?;
for rec in &self.answers {
if !rec.is_unknown() {
rec.write(buffer)?;
}
}
for rec in authorities {
rec.write(buffer)?;
for rec in &self.authorities {
if !rec.is_unknown() {
rec.write(buffer)?;
}
}
for rec in resources {
rec.write(buffer)?;
for rec in &self.resources {
if !rec.is_unknown() {
rec.write(buffer)?;
}
}
Ok(())

View File

@@ -70,7 +70,7 @@ impl DnsRecord {
}
pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> {
let mut domain = String::new();
let mut domain = String::with_capacity(64);
buffer.read_qname(&mut domain)?;
let qtype_num = buffer.read_u16()?;
@@ -110,7 +110,7 @@ impl DnsRecord {
Ok(DnsRecord::AAAA { domain, addr, ttl })
}
QueryType::NS => {
let mut ns = String::new();
let mut ns = String::with_capacity(64);
buffer.read_qname(&mut ns)?;
Ok(DnsRecord::NS {
@@ -120,7 +120,7 @@ impl DnsRecord {
})
}
QueryType::CNAME => {
let mut cname = String::new();
let mut cname = String::with_capacity(64);
buffer.read_qname(&mut cname)?;
Ok(DnsRecord::CNAME {
@@ -131,7 +131,7 @@ impl DnsRecord {
}
QueryType::MX => {
let priority = buffer.read_u16()?;
let mut mx = String::new();
let mut mx = String::with_capacity(64);
buffer.read_qname(&mut mx)?;
Ok(DnsRecord::MX {