perf: optimize DNS query hot path (#15)
* perf: optimize hot path — RwLock, inline filtering, pre-allocated strings - Mutex → RwLock for cache, blocklist, and overrides (concurrent read access) - Make cache.lookup() and overrides.lookup() take &self (read-only) - Eliminate 3 Vec allocations per DnsPacket::write() via inline filtering - Pre-allocate domain strings with capacity 64 in parse path - Add criterion micro-benchmarks (hot_path + throughput) - Add bench README documenting both benchmark suites Measured improvement: ~14% faster parsing, ~9% pipeline throughput, round-trip cached 733ns → 698ns (~2.3M queries/sec). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * chore: simplify benchmark code after review - Remove redundant DnsHeader::new() (already set by DnsPacket::new()) - Remove unused DnsHeader import - Change simulate_cached_pipeline to take &DnsCache (lookup is &self now) - Remove unnecessary mut on cache in cache_lookup_miss bench Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit was merged in pull request #15.
This commit is contained in:
51
src/api.rs
51
src/api.rs
@@ -220,7 +220,7 @@ async fn create_overrides(
|
||||
})
|
||||
.collect::<Result<Vec<_>, (StatusCode, String)>>()?;
|
||||
|
||||
let mut store = ctx.overrides.lock().unwrap();
|
||||
let mut store = ctx.overrides.write().unwrap();
|
||||
let mut responses = Vec::with_capacity(parsed.len());
|
||||
|
||||
for (domain, target, ttl, duration_secs) in parsed {
|
||||
@@ -241,7 +241,7 @@ async fn create_overrides(
|
||||
}
|
||||
|
||||
async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> {
|
||||
let store = ctx.overrides.lock().unwrap();
|
||||
let store = ctx.overrides.read().unwrap();
|
||||
let entries: Vec<OverrideResponse> = store
|
||||
.list()
|
||||
.into_iter()
|
||||
@@ -254,7 +254,7 @@ async fn get_override(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Path(domain): Path<String>,
|
||||
) -> Result<Json<OverrideResponse>, StatusCode> {
|
||||
let store = ctx.overrides.lock().unwrap();
|
||||
let store = ctx.overrides.read().unwrap();
|
||||
let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(Json(OverrideResponse::from(entry)))
|
||||
}
|
||||
@@ -263,7 +263,7 @@ async fn remove_override(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Path(domain): Path<String>,
|
||||
) -> StatusCode {
|
||||
let mut store = ctx.overrides.lock().unwrap();
|
||||
let mut store = ctx.overrides.write().unwrap();
|
||||
if store.remove(&domain) {
|
||||
StatusCode::NO_CONTENT
|
||||
} else {
|
||||
@@ -272,7 +272,7 @@ async fn remove_override(
|
||||
}
|
||||
|
||||
async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
|
||||
ctx.overrides.lock().unwrap().clear();
|
||||
ctx.overrides.write().unwrap().clear();
|
||||
StatusCode::NO_CONTENT
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ async fn load_environment(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Json(req): Json<EnvironmentRequest>,
|
||||
) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> {
|
||||
let mut store = ctx.overrides.lock().unwrap();
|
||||
let mut store = ctx.overrides.write().unwrap();
|
||||
|
||||
for entry in &req.overrides {
|
||||
let duration = entry.duration_secs.or(req.duration_secs);
|
||||
@@ -307,7 +307,7 @@ async fn diagnose(
|
||||
|
||||
// Check overrides
|
||||
{
|
||||
let store = ctx.overrides.lock().unwrap();
|
||||
let store = ctx.overrides.read().unwrap();
|
||||
let entry = store.get(&domain_lower);
|
||||
steps.push(DiagnoseStep {
|
||||
source: "override".to_string(),
|
||||
@@ -319,7 +319,7 @@ async fn diagnose(
|
||||
|
||||
// Check blocklist
|
||||
{
|
||||
let bl = ctx.blocklist.lock().unwrap();
|
||||
let bl = ctx.blocklist.read().unwrap();
|
||||
let blocked = bl.is_blocked(&domain_lower);
|
||||
steps.push(DiagnoseStep {
|
||||
source: "blocklist".to_string(),
|
||||
@@ -345,7 +345,7 @@ async fn diagnose(
|
||||
|
||||
// Check cache
|
||||
{
|
||||
let mut cache = ctx.cache.lock().unwrap();
|
||||
let cache = ctx.cache.read().unwrap();
|
||||
let cached = cache.lookup(&domain_lower, qtype);
|
||||
steps.push(DiagnoseStep {
|
||||
source: "cache".to_string(),
|
||||
@@ -443,11 +443,11 @@ async fn query_log(
|
||||
async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
||||
let snap = ctx.stats.lock().unwrap().snapshot();
|
||||
let (cache_len, cache_max) = {
|
||||
let cache = ctx.cache.lock().unwrap();
|
||||
let cache = ctx.cache.read().unwrap();
|
||||
(cache.len(), cache.max_entries())
|
||||
};
|
||||
let override_count = ctx.overrides.lock().unwrap().active_count();
|
||||
let bl_stats = ctx.blocklist.lock().unwrap().stats();
|
||||
let override_count = ctx.overrides.read().unwrap().active_count();
|
||||
let bl_stats = ctx.blocklist.read().unwrap().stats();
|
||||
|
||||
let upstream = ctx.upstream.lock().unwrap().to_string();
|
||||
|
||||
@@ -486,7 +486,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
|
||||
}
|
||||
|
||||
async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> {
|
||||
let cache = ctx.cache.lock().unwrap();
|
||||
let cache = ctx.cache.read().unwrap();
|
||||
let entries: Vec<CacheEntryResponse> = cache
|
||||
.list()
|
||||
.into_iter()
|
||||
@@ -500,7 +500,7 @@ async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryRes
|
||||
}
|
||||
|
||||
async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
|
||||
ctx.cache.lock().unwrap().clear();
|
||||
ctx.cache.write().unwrap().clear();
|
||||
StatusCode::NO_CONTENT
|
||||
}
|
||||
|
||||
@@ -508,7 +508,7 @@ async fn flush_cache_domain(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Path(domain): Path<String>,
|
||||
) -> StatusCode {
|
||||
ctx.cache.lock().unwrap().remove(&domain);
|
||||
ctx.cache.write().unwrap().remove(&domain);
|
||||
StatusCode::NO_CONTENT
|
||||
}
|
||||
|
||||
@@ -519,7 +519,7 @@ async fn health() -> Json<serde_json::Value> {
|
||||
// --- Blocking handlers ---
|
||||
|
||||
async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
|
||||
let stats = ctx.blocklist.lock().unwrap().stats();
|
||||
let stats = ctx.blocklist.read().unwrap().stats();
|
||||
Json(serde_json::json!({
|
||||
"enabled": stats.enabled,
|
||||
"paused": stats.paused,
|
||||
@@ -539,7 +539,7 @@ async fn blocking_toggle(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Json(req): Json<BlockingToggleRequest>,
|
||||
) -> Json<serde_json::Value> {
|
||||
ctx.blocklist.lock().unwrap().set_enabled(req.enabled);
|
||||
ctx.blocklist.write().unwrap().set_enabled(req.enabled);
|
||||
Json(serde_json::json!({ "enabled": req.enabled }))
|
||||
}
|
||||
|
||||
@@ -557,12 +557,12 @@ async fn blocking_pause(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Json(req): Json<BlockingPauseRequest>,
|
||||
) -> Json<serde_json::Value> {
|
||||
ctx.blocklist.lock().unwrap().pause(req.minutes * 60);
|
||||
ctx.blocklist.write().unwrap().pause(req.minutes * 60);
|
||||
Json(serde_json::json!({ "paused_minutes": req.minutes }))
|
||||
}
|
||||
|
||||
async fn blocking_unpause(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
|
||||
ctx.blocklist.lock().unwrap().unpause();
|
||||
ctx.blocklist.write().unwrap().unpause();
|
||||
Json(serde_json::json!({ "paused": false }))
|
||||
}
|
||||
|
||||
@@ -570,12 +570,12 @@ async fn blocking_check(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Path(domain): Path<String>,
|
||||
) -> Json<crate::blocklist::BlockCheckResult> {
|
||||
let result = ctx.blocklist.lock().unwrap().check(&domain);
|
||||
let result = ctx.blocklist.read().unwrap().check(&domain);
|
||||
Json(result)
|
||||
}
|
||||
|
||||
async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> {
|
||||
let list = ctx.blocklist.lock().unwrap().allowlist();
|
||||
let list = ctx.blocklist.read().unwrap().allowlist();
|
||||
Json(list)
|
||||
}
|
||||
|
||||
@@ -588,7 +588,7 @@ async fn blocking_allowlist_add(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Json(req): Json<AllowlistRequest>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
ctx.blocklist.lock().unwrap().add_to_allowlist(&req.domain);
|
||||
ctx.blocklist.write().unwrap().add_to_allowlist(&req.domain);
|
||||
(
|
||||
StatusCode::CREATED,
|
||||
Json(serde_json::json!({ "allowed": req.domain })),
|
||||
@@ -599,7 +599,12 @@ async fn blocking_allowlist_remove(
|
||||
State(ctx): State<Arc<ServerCtx>>,
|
||||
Path(domain): Path<String>,
|
||||
) -> StatusCode {
|
||||
if ctx.blocklist.lock().unwrap().remove_from_allowlist(&domain) {
|
||||
if ctx
|
||||
.blocklist
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove_from_allowlist(&domain)
|
||||
{
|
||||
StatusCode::NO_CONTENT
|
||||
} else {
|
||||
StatusCode::NOT_FOUND
|
||||
|
||||
18
src/cache.rs
18
src/cache.rs
@@ -19,7 +19,6 @@ pub struct DnsCache {
|
||||
max_entries: usize,
|
||||
min_ttl: u32,
|
||||
max_ttl: u32,
|
||||
query_count: u64,
|
||||
}
|
||||
|
||||
impl DnsCache {
|
||||
@@ -30,29 +29,16 @@ impl DnsCache {
|
||||
max_entries,
|
||||
min_ttl,
|
||||
max_ttl,
|
||||
query_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
self.query_count += 1;
|
||||
|
||||
if self.query_count.is_multiple_of(1000) {
|
||||
self.evict_expired();
|
||||
}
|
||||
|
||||
/// Read-only lookup — expired entries are left in place (cleaned up on insert).
|
||||
pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
let type_map = self.entries.get(domain)?;
|
||||
let entry = type_map.get(&qtype)?;
|
||||
|
||||
let elapsed = entry.inserted_at.elapsed();
|
||||
if elapsed >= entry.ttl {
|
||||
// Expired: remove this entry
|
||||
let type_map = self.entries.get_mut(domain).unwrap();
|
||||
type_map.remove(&qtype);
|
||||
self.entry_count -= 1;
|
||||
if type_map.is_empty() {
|
||||
self.entries.remove(domain);
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
|
||||
16
src/ctx.rs
16
src/ctx.rs
@@ -1,6 +1,6 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::{Mutex, RwLock};
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
@@ -27,10 +27,10 @@ use crate::system_dns::ForwardingRule;
|
||||
pub struct ServerCtx {
|
||||
pub socket: UdpSocket,
|
||||
pub zone_map: ZoneMap,
|
||||
pub cache: Mutex<DnsCache>,
|
||||
pub cache: RwLock<DnsCache>,
|
||||
pub stats: Mutex<ServerStats>,
|
||||
pub overrides: Mutex<OverrideStore>,
|
||||
pub blocklist: Mutex<BlocklistStore>,
|
||||
pub overrides: RwLock<OverrideStore>,
|
||||
pub blocklist: RwLock<BlocklistStore>,
|
||||
pub query_log: Mutex<QueryLog>,
|
||||
pub services: Mutex<ServiceStore>,
|
||||
pub lan_peers: Mutex<PeerStore>,
|
||||
@@ -73,7 +73,7 @@ pub async fn handle_query(
|
||||
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
|
||||
// Each lock is scoped to avoid holding MutexGuard across await points.
|
||||
let (response, path) = {
|
||||
let override_record = ctx.overrides.lock().unwrap().lookup(&qname);
|
||||
let override_record = ctx.overrides.read().unwrap().lookup(&qname);
|
||||
if let Some(record) = override_record {
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
resp.answers.push(record);
|
||||
@@ -116,7 +116,7 @@ pub async fn handle_query(
|
||||
}),
|
||||
}
|
||||
(resp, QueryPath::Local)
|
||||
} else if ctx.blocklist.lock().unwrap().is_blocked(&qname) {
|
||||
} else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
|
||||
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
|
||||
match qtype {
|
||||
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
|
||||
@@ -136,7 +136,7 @@ pub async fn handle_query(
|
||||
resp.answers = records.clone();
|
||||
(resp, QueryPath::Local)
|
||||
} else {
|
||||
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype);
|
||||
let cached = ctx.cache.read().unwrap().lookup(&qname, qtype);
|
||||
if let Some(cached) = cached {
|
||||
let mut resp = cached;
|
||||
resp.header.id = query.header.id;
|
||||
@@ -149,7 +149,7 @@ pub async fn handle_query(
|
||||
};
|
||||
match forward_query(&query, &upstream, ctx.timeout).await {
|
||||
Ok(resp) => {
|
||||
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp);
|
||||
ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
|
||||
(resp, QueryPath::Forwarded)
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
10
src/main.rs
10
src/main.rs
@@ -1,5 +1,5 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
@@ -170,14 +170,14 @@ async fn main() -> numa::Result<()> {
|
||||
let ctx = Arc::new(ServerCtx {
|
||||
socket: UdpSocket::bind(&config.server.bind_addr).await?,
|
||||
zone_map: build_zone_map(&config.zones)?,
|
||||
cache: Mutex::new(DnsCache::new(
|
||||
cache: RwLock::new(DnsCache::new(
|
||||
config.cache.max_entries,
|
||||
config.cache.min_ttl,
|
||||
config.cache.max_ttl,
|
||||
)),
|
||||
stats: Mutex::new(ServerStats::new()),
|
||||
overrides: Mutex::new(OverrideStore::new()),
|
||||
blocklist: Mutex::new(blocklist),
|
||||
overrides: RwLock::new(OverrideStore::new()),
|
||||
blocklist: RwLock::new(blocklist),
|
||||
query_log: Mutex::new(QueryLog::new(1000)),
|
||||
services: Mutex::new(service_store),
|
||||
lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)),
|
||||
@@ -541,7 +541,7 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
|
||||
|
||||
// Swap under lock — sub-microsecond
|
||||
ctx.blocklist
|
||||
.lock()
|
||||
.write()
|
||||
.unwrap()
|
||||
.swap_domains(all_domains, sources);
|
||||
info!(
|
||||
|
||||
@@ -64,6 +64,9 @@ impl OverrideStore {
|
||||
ttl: u32,
|
||||
duration_secs: Option<u64>,
|
||||
) -> Result<QueryType> {
|
||||
// Clean up expired entries on write
|
||||
self.entries.retain(|_, e| !e.is_expired());
|
||||
|
||||
let domain_lower = domain.to_lowercase();
|
||||
let (qtype, record) = parse_target(&domain_lower, target, ttl)?;
|
||||
|
||||
@@ -84,10 +87,10 @@ impl OverrideStore {
|
||||
}
|
||||
|
||||
/// Hot path: assumes `domain` is already lowercased (the parser does this).
|
||||
pub fn lookup(&mut self, domain: &str) -> Option<DnsRecord> {
|
||||
/// Read-only — expired entries are left in place (cleaned up on write operations).
|
||||
pub fn lookup(&self, domain: &str) -> Option<DnsRecord> {
|
||||
let entry = self.entries.get(domain)?;
|
||||
if entry.is_expired() {
|
||||
self.entries.remove(domain);
|
||||
return None;
|
||||
}
|
||||
Some(entry.record.clone())
|
||||
|
||||
@@ -46,7 +46,7 @@ impl DnsPacket {
|
||||
result.header.read(buffer)?;
|
||||
|
||||
for _ in 0..result.header.questions {
|
||||
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
|
||||
let mut question = DnsQuestion::new(String::with_capacity(64), QueryType::UNKNOWN(0));
|
||||
question.read(buffer)?;
|
||||
result.questions.push(question);
|
||||
}
|
||||
@@ -68,34 +68,36 @@ impl DnsPacket {
|
||||
}
|
||||
|
||||
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> {
|
||||
// Filter out UNKNOWN records (e.g. EDNS OPT) that we can't re-serialize
|
||||
let answers: Vec<_> = self.answers.iter().filter(|r| !r.is_unknown()).collect();
|
||||
let authorities: Vec<_> = self
|
||||
.authorities
|
||||
.iter()
|
||||
.filter(|r| !r.is_unknown())
|
||||
.collect();
|
||||
let resources: Vec<_> = self.resources.iter().filter(|r| !r.is_unknown()).collect();
|
||||
// Count known records without allocating filter Vecs
|
||||
let answer_count = self.answers.iter().filter(|r| !r.is_unknown()).count() as u16;
|
||||
let auth_count = self.authorities.iter().filter(|r| !r.is_unknown()).count() as u16;
|
||||
let res_count = self.resources.iter().filter(|r| !r.is_unknown()).count() as u16;
|
||||
|
||||
let mut header = self.header.clone();
|
||||
header.questions = self.questions.len() as u16;
|
||||
header.answers = answers.len() as u16;
|
||||
header.authoritative_entries = authorities.len() as u16;
|
||||
header.resource_entries = resources.len() as u16;
|
||||
header.answers = answer_count;
|
||||
header.authoritative_entries = auth_count;
|
||||
header.resource_entries = res_count;
|
||||
|
||||
header.write(buffer)?;
|
||||
|
||||
for question in &self.questions {
|
||||
question.write(buffer)?;
|
||||
}
|
||||
for rec in answers {
|
||||
rec.write(buffer)?;
|
||||
for rec in &self.answers {
|
||||
if !rec.is_unknown() {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
}
|
||||
for rec in authorities {
|
||||
rec.write(buffer)?;
|
||||
for rec in &self.authorities {
|
||||
if !rec.is_unknown() {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
}
|
||||
for rec in resources {
|
||||
rec.write(buffer)?;
|
||||
for rec in &self.resources {
|
||||
if !rec.is_unknown() {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -70,7 +70,7 @@ impl DnsRecord {
|
||||
}
|
||||
|
||||
pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> {
|
||||
let mut domain = String::new();
|
||||
let mut domain = String::with_capacity(64);
|
||||
buffer.read_qname(&mut domain)?;
|
||||
|
||||
let qtype_num = buffer.read_u16()?;
|
||||
@@ -110,7 +110,7 @@ impl DnsRecord {
|
||||
Ok(DnsRecord::AAAA { domain, addr, ttl })
|
||||
}
|
||||
QueryType::NS => {
|
||||
let mut ns = String::new();
|
||||
let mut ns = String::with_capacity(64);
|
||||
buffer.read_qname(&mut ns)?;
|
||||
|
||||
Ok(DnsRecord::NS {
|
||||
@@ -120,7 +120,7 @@ impl DnsRecord {
|
||||
})
|
||||
}
|
||||
QueryType::CNAME => {
|
||||
let mut cname = String::new();
|
||||
let mut cname = String::with_capacity(64);
|
||||
buffer.read_qname(&mut cname)?;
|
||||
|
||||
Ok(DnsRecord::CNAME {
|
||||
@@ -131,7 +131,7 @@ impl DnsRecord {
|
||||
}
|
||||
QueryType::MX => {
|
||||
let priority = buffer.read_u16()?;
|
||||
let mut mx = String::new();
|
||||
let mut mx = String::with_capacity(64);
|
||||
buffer.read_qname(&mut mx)?;
|
||||
|
||||
Ok(DnsRecord::MX {
|
||||
|
||||
Reference in New Issue
Block a user