add ad blocking, live dashboard, system DNS auto-discovery

- DNS-level ad blocking: 385K+ domains via Hagezi Pro blocklist, subdomain
  matching, one-click allowlist, pause/toggle, background refresh every 24h
- Live dashboard at :5380 with real-time stats, query log, override
  management (create/edit/delete), blocking controls
- System DNS auto-discovery: parses scutil --dns on macOS to find
  conditional forwarding rules (Tailscale, VPN split-DNS)
- REST API expanded to 18 endpoints (blocking, overrides, diagnostics)
- Startup banner with colored system info
- Performance benchmarks (bench/dns-bench.sh)
- Landing page updated with new positioning and comparison table
- CI, Dockerfile, LICENSE, development plan docs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-20 10:54:23 +02:00
parent e31188fb88
commit 4dc5b94c7a
23 changed files with 5494 additions and 226 deletions

565
src/api.rs Normal file
View File

@@ -0,0 +1,565 @@
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{delete, get, post, put};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use crate::ctx::ServerCtx;
use crate::forward::forward_query;
use crate::query_log::QueryLogFilter;
use crate::question::QueryType;
use crate::stats::QueryPath;
const DASHBOARD_HTML: &str = include_str!("../site/dashboard.html");
/// Build the HTTP API router: dashboard at `/` plus every REST endpoint,
/// all sharing the server context as axum state.
pub fn router(ctx: Arc<ServerCtx>) -> Router {
    Router::new()
        .route("/", get(dashboard))
        // Handlers for the same path are chained on one MethodRouter instead
        // of being registered one `.route` call at a time.
        .route(
            "/overrides",
            get(list_overrides)
                .post(create_overrides)
                .delete(clear_overrides),
        )
        .route("/overrides/environment", post(load_environment))
        .route(
            "/overrides/{domain}",
            get(get_override).delete(remove_override),
        )
        .route("/diagnose/{domain}", get(diagnose))
        .route("/query-log", get(query_log))
        .route("/stats", get(stats))
        .route("/cache", get(list_cache).delete(flush_cache))
        .route("/cache/{domain}", delete(flush_cache_domain))
        .route("/health", get(health))
        .route("/blocking/stats", get(blocking_stats))
        .route("/blocking/toggle", put(blocking_toggle))
        .route("/blocking/pause", post(blocking_pause))
        .route(
            "/blocking/allowlist",
            get(blocking_allowlist).post(blocking_allowlist_add),
        )
        .route(
            "/blocking/allowlist/{domain}",
            delete(blocking_allowlist_remove),
        )
        .with_state(ctx)
}
/// Serve the dashboard single-page app embedded at compile time.
async fn dashboard() -> impl IntoResponse {
    let content_type = [(header::CONTENT_TYPE, "text/html; charset=utf-8")];
    (content_type, DASHBOARD_HTML)
}
// --- Request/Response DTOs ---

/// Body for POST /overrides — one override to create (also used as the
/// element type when the endpoint receives a JSON array).
#[derive(Deserialize)]
struct CreateOverrideRequest {
    /// Domain to override (lowercased before insertion).
    domain: String,
    /// Value the domain should resolve to (record type is inferred by the store).
    target: String,
    /// Record TTL in seconds; defaults to 60 when omitted.
    #[serde(default = "default_ttl")]
    ttl: u32,
    /// Optional lifetime of the override itself, in seconds (None = permanent).
    duration_secs: Option<u64>,
}

/// Serde default for `CreateOverrideRequest::ttl`.
fn default_ttl() -> u32 {
    60
}

/// Wire representation of one active override.
#[derive(Serialize)]
struct OverrideResponse {
    domain: String,
    target: String,
    /// Record type string as reported by the store (e.g. "A") — presumably
    /// inferred from the target; confirm against `OverrideStore::insert`.
    record_type: String,
    ttl: u32,
    /// Seconds until the override expires, if it has a duration.
    remaining_secs: Option<u64>,
}
impl From<&crate::override_store::OverrideEntry> for OverrideResponse {
    /// Project a store entry into its wire representation.
    fn from(entry: &crate::override_store::OverrideEntry) -> Self {
        Self {
            domain: entry.domain.clone(),
            target: entry.target.clone(),
            record_type: entry.query_type.as_str().to_owned(),
            ttl: entry.ttl,
            remaining_secs: entry.remaining_secs(),
        }
    }
}
/// Body for POST /overrides/environment — a batch of overrides sharing an
/// optional default duration.
#[derive(Deserialize)]
struct EnvironmentRequest {
    /// Fallback expiry applied to entries that don't set their own.
    #[serde(default)]
    duration_secs: Option<u64>,
    overrides: Vec<CreateOverrideRequest>,
}

/// Response for POST /overrides/environment.
#[derive(Serialize)]
struct EnvironmentResponse {
    /// Number of override entries in the request (all inserted on success).
    created: usize,
}

/// Query string for GET /query-log; every filter is optional.
#[derive(Deserialize)]
struct QueryLogParams {
    domain: Option<String>,
    /// Record type filter ("A", "AAAA", ...); raw ident since `type` is a keyword.
    r#type: Option<String>,
    /// Resolution path filter (see `QueryPath::parse_str`).
    path: Option<String>,
    limit: Option<usize>,
}

/// One query-log row as returned to the dashboard.
#[derive(Serialize)]
struct QueryLogResponse {
    /// Seconds since the Unix epoch (fractional).
    timestamp_epoch: f64,
    /// Client socket address.
    src: String,
    domain: String,
    query_type: String,
    /// Which pipeline stage answered the query.
    path: String,
    rescode: String,
    latency_ms: f64,
}

/// Response for GET /stats.
#[derive(Serialize)]
struct StatsResponse {
    uptime_secs: u64,
    queries: QueriesStats,
    cache: CacheStats,
    overrides: OverrideStats,
    blocking: BlockingStatsResponse,
}

/// Per-resolution-path query counters.
#[derive(Serialize)]
struct QueriesStats {
    total: u64,
    forwarded: u64,
    cached: u64,
    local: u64,
    overridden: u64,
    blocked: u64,
    errors: u64,
}

/// Cache occupancy snapshot.
#[derive(Serialize)]
struct CacheStats {
    entries: usize,
    max_entries: usize,
}

/// Active-override count.
#[derive(Serialize)]
struct OverrideStats {
    active: usize,
}

/// Blocking summary embedded in the /stats response.
#[derive(Serialize)]
struct BlockingStatsResponse {
    enabled: bool,
    paused: bool,
    domains_loaded: usize,
    allowlist_size: usize,
}

/// Response for GET /diagnose/{domain}.
#[derive(Serialize)]
struct DiagnoseResponse {
    domain: String,
    query_type: String,
    /// One entry per pipeline stage, in evaluation order.
    steps: Vec<DiagnoseStep>,
}

/// Result of probing one resolution stage for the diagnose endpoint.
#[derive(Serialize)]
struct DiagnoseStep {
    /// Stage name: "override", "blocklist", "local_zone", "cache", "upstream".
    source: String,
    matched: bool,
    /// Human-readable detail when available.
    detail: Option<String>,
}

/// One row of GET /cache.
#[derive(Serialize)]
struct CacheEntryResponse {
    domain: String,
    query_type: String,
    /// Seconds of TTL left at snapshot time.
    ttl_remaining: u32,
}
// --- Handlers ---
async fn create_overrides(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<serde_json::Value>,
) -> Result<(StatusCode, Json<Vec<OverrideResponse>>), (StatusCode, String)> {
let requests: Vec<CreateOverrideRequest> = if req.is_array() {
serde_json::from_value(req).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
} else {
let single: CreateOverrideRequest =
serde_json::from_value(req).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
vec![single]
};
// Parse and validate all requests before acquiring the lock
let parsed: Vec<_> = requests
.into_iter()
.map(|req| {
let domain_lower = req.domain.to_lowercase();
Ok((domain_lower, req.target, req.ttl, req.duration_secs))
})
.collect::<Result<Vec<_>, (StatusCode, String)>>()?;
let mut store = ctx.overrides.lock().unwrap();
let mut responses = Vec::with_capacity(parsed.len());
for (domain, target, ttl, duration_secs) in parsed {
let qtype = store
.insert(&domain, &target, ttl, duration_secs)
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
responses.push(OverrideResponse {
domain,
target,
record_type: qtype.as_str().to_string(),
ttl,
remaining_secs: duration_secs,
});
}
Ok((StatusCode::CREATED, Json(responses)))
}
async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> {
let store = ctx.overrides.lock().unwrap();
let entries: Vec<OverrideResponse> = store
.list()
.into_iter()
.map(OverrideResponse::from)
.collect();
Json(entries)
}
async fn get_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> Result<Json<OverrideResponse>, StatusCode> {
let store = ctx.overrides.lock().unwrap();
let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(OverrideResponse::from(entry)))
}
async fn remove_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
let mut store = ctx.overrides.lock().unwrap();
if store.remove(&domain) {
StatusCode::NO_CONTENT
} else {
StatusCode::NOT_FOUND
}
}
async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.overrides.lock().unwrap().clear();
StatusCode::NO_CONTENT
}
/// POST /overrides/environment — bulk-load a set of overrides.
///
/// Each entry may carry its own `duration_secs`; otherwise the request-level
/// default applies.
///
/// NOTE(review): inserts are applied one by one, so a failure partway leaves
/// earlier entries in place (matches existing behavior).
async fn load_environment(
    State(ctx): State<Arc<ServerCtx>>,
    Json(req): Json<EnvironmentRequest>,
) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> {
    let mut store = ctx.overrides.lock().unwrap();
    for entry in &req.overrides {
        let duration = entry.duration_secs.or(req.duration_secs);
        // Lowercase for consistency with create_overrides: the store keys by
        // lowercase domain and DNS names are case-insensitive.
        let domain = entry.domain.to_lowercase();
        store
            .insert(&domain, &entry.target, entry.ttl, duration)
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
    }
    Ok((
        StatusCode::CREATED,
        Json(EnvironmentResponse {
            created: req.overrides.len(),
        }),
    ))
}
/// GET /diagnose/{domain} — walk the resolution pipeline for a domain and
/// report, stage by stage, whether it would match.
///
/// Stages mirror `handle_query`: override → blocklist → local zone → cache
/// → upstream. Only A-record resolution is probed.
async fn diagnose(
    State(ctx): State<Arc<ServerCtx>>,
    Path(domain): Path<String>,
) -> Json<DiagnoseResponse> {
    // Lookup tables are keyed by lowercase domain.
    let domain_lower = domain.to_lowercase();
    let qtype = QueryType::A;
    let mut steps = Vec::new();
    // Check overrides (lock scoped to this block so it is released before
    // the next stage).
    {
        let store = ctx.overrides.lock().unwrap();
        let entry = store.get(&domain_lower);
        steps.push(DiagnoseStep {
            source: "override".to_string(),
            matched: entry.is_some(),
            detail: entry
                .map(|e| format!("{} -> {} ({})", e.domain, e.target, e.query_type.as_str())),
        });
    }
    // Check blocklist
    {
        let bl = ctx.blocklist.lock().unwrap();
        let blocked = bl.is_blocked(&domain_lower);
        steps.push(DiagnoseStep {
            source: "blocklist".to_string(),
            matched: blocked,
            detail: if blocked {
                Some("domain is in blocklist".to_string())
            } else {
                None
            },
        });
    }
    // Check local zones (zone_map is immutable shared state — no lock).
    let zone_match = ctx
        .zone_map
        .get(domain_lower.as_str())
        .and_then(|m| m.get(&qtype));
    steps.push(DiagnoseStep {
        source: "local_zone".to_string(),
        matched: zone_match.is_some(),
        detail: zone_match.map(|records| format!("{} records", records.len())),
    });
    // Check cache — `lookup` takes `&mut` because it evicts expired entries.
    {
        let mut cache = ctx.cache.lock().unwrap();
        let cached = cache.lookup(&domain_lower, qtype);
        steps.push(DiagnoseStep {
            source: "cache".to_string(),
            matched: cached.is_some(),
            detail: cached.map(|p| format!("{} answers", p.answers.len())),
        });
    }
    // Check upstream (async, no locks held)
    let (upstream_matched, upstream_detail) =
        forward_query_for_diagnose(&domain_lower, ctx.upstream, ctx.timeout).await;
    steps.push(DiagnoseStep {
        source: "upstream".to_string(),
        matched: upstream_matched,
        detail: Some(upstream_detail),
    });
    Json(DiagnoseResponse {
        domain: domain_lower,
        query_type: qtype.as_str().to_string(),
        steps,
    })
}
/// Issue a one-off upstream A query for the diagnose endpoint.
///
/// Returns `(reachable, human-readable detail)`. Runs with no server locks
/// held.
async fn forward_query_for_diagnose(
    domain: &str,
    upstream: std::net::SocketAddr,
    timeout: std::time::Duration,
) -> (bool, String) {
    use crate::packet::DnsPacket;
    use crate::question::DnsQuestion;
    let mut probe = DnsPacket::new();
    // Fixed transaction id: this probe is independent of the main pipeline.
    probe.header.id = 0xBEEF;
    probe.header.recursion_desired = true;
    probe
        .questions
        .push(DnsQuestion::new(domain.to_string(), QueryType::A));
    match forward_query(&probe, upstream, timeout).await {
        Ok(reply) => {
            let detail = format!(
                "{} ({} answers)",
                reply.header.rescode.as_str(),
                reply.answers.len()
            );
            (true, detail)
        }
        Err(err) => (false, format!("error: {}", err)),
    }
}
/// GET /query-log — recent queries, optionally filtered by domain, record
/// type, and resolution path.
async fn query_log(
    State(ctx): State<Arc<ServerCtx>>,
    Query(params): Query<QueryLogParams>,
) -> Json<Vec<QueryLogResponse>> {
    let filter = QueryLogFilter {
        domain: params.domain,
        query_type: params.r#type.as_deref().and_then(QueryType::parse_str),
        path: params.path.as_deref().and_then(QueryPath::parse_str),
        since: None,
        limit: params.limit,
    };
    let log = ctx.query_log.lock().unwrap();
    let rows: Vec<QueryLogResponse> = log
        .query(&filter)
        .into_iter()
        .map(|entry| QueryLogResponse {
            // Seconds since the Unix epoch; a pre-epoch clock yields 0.0.
            timestamp_epoch: entry
                .timestamp
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs_f64(),
            src: entry.src_addr.to_string(),
            domain: entry.domain.clone(),
            query_type: entry.query_type.as_str().to_string(),
            path: entry.path.as_str().to_string(),
            rescode: entry.rescode.as_str().to_string(),
            latency_ms: entry.latency_us as f64 / 1000.0,
        })
        .collect();
    Json(rows)
}
/// GET /stats — aggregate counters for the dashboard.
///
/// Each subsystem lock is taken briefly and in sequence; nothing is held
/// across an `.await`.
async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
    let snap = ctx.stats.lock().unwrap().snapshot();
    let cache_guard = ctx.cache.lock().unwrap();
    let (entries, max_entries) = (cache_guard.len(), cache_guard.max_entries());
    drop(cache_guard);
    let active = ctx.overrides.lock().unwrap().active_count();
    let bl = ctx.blocklist.lock().unwrap().stats();
    Json(StatsResponse {
        uptime_secs: snap.uptime_secs,
        queries: QueriesStats {
            total: snap.total,
            forwarded: snap.forwarded,
            cached: snap.cached,
            local: snap.local,
            overridden: snap.overridden,
            blocked: snap.blocked,
            errors: snap.errors,
        },
        cache: CacheStats {
            entries,
            max_entries,
        },
        overrides: OverrideStats { active },
        blocking: BlockingStatsResponse {
            enabled: bl.enabled,
            paused: bl.paused,
            domains_loaded: bl.domains_loaded,
            allowlist_size: bl.allowlist_size,
        },
    })
}
/// GET /cache — snapshot of live (unexpired) cache entries.
async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> {
    let guard = ctx.cache.lock().unwrap();
    let mut entries = Vec::new();
    for info in guard.list() {
        entries.push(CacheEntryResponse {
            domain: info.domain,
            query_type: info.query_type.as_str().to_string(),
            ttl_remaining: info.ttl_remaining,
        });
    }
    Json(entries)
}

/// DELETE /cache — flush the whole cache.
async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
    let mut guard = ctx.cache.lock().unwrap();
    guard.clear();
    StatusCode::NO_CONTENT
}

/// DELETE /cache/{domain} — flush every record type cached for one domain.
async fn flush_cache_domain(
    State(ctx): State<Arc<ServerCtx>>,
    Path(domain): Path<String>,
) -> StatusCode {
    let mut guard = ctx.cache.lock().unwrap();
    guard.remove(&domain);
    StatusCode::NO_CONTENT
}

/// GET /health — liveness probe.
async fn health() -> Json<serde_json::Value> {
    Json(serde_json::json!({ "status": "ok" }))
}
// --- Blocking handlers ---

/// GET /blocking/stats — detailed blocklist state.
async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
    let s = ctx.blocklist.lock().unwrap().stats();
    let body = serde_json::json!({
        "enabled": s.enabled,
        "paused": s.paused,
        "domains_loaded": s.domains_loaded,
        "allowlist_size": s.allowlist_size,
        "list_sources": s.list_sources,
        "last_refresh_secs_ago": s.last_refresh_secs_ago,
    });
    Json(body)
}
/// Body for PUT /blocking/toggle.
#[derive(Deserialize)]
struct BlockingToggleRequest {
    enabled: bool,
}

/// PUT /blocking/toggle — turn blocking on or off; echoes the new state.
async fn blocking_toggle(
    State(ctx): State<Arc<ServerCtx>>,
    Json(req): Json<BlockingToggleRequest>,
) -> Json<serde_json::Value> {
    let enabled = req.enabled;
    ctx.blocklist.lock().unwrap().set_enabled(enabled);
    Json(serde_json::json!({ "enabled": enabled }))
}
/// Body for POST /blocking/pause.
#[derive(Deserialize)]
struct BlockingPauseRequest {
    /// How long to pause blocking; defaults to 5 minutes.
    #[serde(default = "default_pause_minutes")]
    minutes: u64,
}

/// Serde default for `BlockingPauseRequest::minutes`.
fn default_pause_minutes() -> u64 {
    5
}

/// POST /blocking/pause — temporarily disable blocking.
async fn blocking_pause(
    State(ctx): State<Arc<ServerCtx>>,
    Json(req): Json<BlockingPauseRequest>,
) -> Json<serde_json::Value> {
    // `minutes` is untrusted API input: a plain `minutes * 60` overflows for
    // absurd values (panic in debug, silent wrap in release). Saturate the
    // conversion and cap at ~1 century so `Instant + Duration` downstream
    // can never overflow either.
    const MAX_PAUSE_SECS: u64 = 100 * 365 * 24 * 60 * 60;
    let seconds = req.minutes.saturating_mul(60).min(MAX_PAUSE_SECS);
    ctx.blocklist.lock().unwrap().pause(seconds);
    Json(serde_json::json!({ "paused_minutes": req.minutes }))
}
/// GET /blocking/allowlist — list allowlisted domains.
async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> {
    Json(ctx.blocklist.lock().unwrap().allowlist())
}

/// Body for POST /blocking/allowlist.
#[derive(Deserialize)]
struct AllowlistRequest {
    domain: String,
}

/// POST /blocking/allowlist — exempt a domain from blocking.
async fn blocking_allowlist_add(
    State(ctx): State<Arc<ServerCtx>>,
    Json(req): Json<AllowlistRequest>,
) -> (StatusCode, Json<serde_json::Value>) {
    {
        let mut bl = ctx.blocklist.lock().unwrap();
        bl.add_to_allowlist(&req.domain);
    }
    let body = serde_json::json!({ "allowed": req.domain });
    (StatusCode::CREATED, Json(body))
}

/// DELETE /blocking/allowlist/{domain} — 404 when the domain wasn't allowlisted.
async fn blocking_allowlist_remove(
    State(ctx): State<Arc<ServerCtx>>,
    Path(domain): Path<String>,
) -> StatusCode {
    let removed = ctx.blocklist.lock().unwrap().remove_from_allowlist(&domain);
    if removed {
        StatusCode::NO_CONTENT
    } else {
        StatusCode::NOT_FOUND
    }
}

187
src/blocklist.rs Normal file
View File

@@ -0,0 +1,187 @@
use std::collections::HashSet;
use std::time::Instant;
use log::{info, warn};
/// In-memory DNS blocklist with an allowlist, a master switch, and a
/// temporary pause.
///
/// Lookup tables hold lowercase domains: `parse_blocklist` lowercases what
/// it loads and the allowlist methods lowercase their input, so `is_blocked`
/// normalizes query names the same way before matching.
pub struct BlocklistStore {
    /// Blocked domains (lowercase); a hit on any parent label also blocks
    /// subdomains.
    domains: HashSet<String>,
    /// Domains exempted from blocking (lowercase).
    allowlist: HashSet<String>,
    /// Master switch (config `blocking.enabled` + the toggle endpoint).
    enabled: bool,
    /// While `Instant::now()` is earlier than this, blocking is suspended.
    paused_until: Option<Instant>,
    /// URLs the current domain set was built from.
    list_sources: Vec<String>,
    /// When `swap_domains` last ran.
    last_refresh: Option<Instant>,
}

/// Snapshot of blocklist state for the API.
pub struct BlocklistStats {
    pub enabled: bool,
    pub paused: bool,
    pub domains_loaded: usize,
    pub allowlist_size: usize,
    pub list_sources: Vec<String>,
    /// Seconds since the last `swap_domains`, if it ever ran.
    pub last_refresh_secs_ago: Option<u64>,
}

impl Default for BlocklistStore {
    fn default() -> Self {
        Self::new()
    }
}

impl BlocklistStore {
    /// Empty store: blocking enabled, nothing loaded, nothing allowlisted.
    pub fn new() -> Self {
        BlocklistStore {
            domains: HashSet::new(),
            allowlist: HashSet::new(),
            enabled: true,
            paused_until: None,
            list_sources: Vec::new(),
            last_refresh: None,
        }
    }

    /// Whether `domain` should receive a block response.
    ///
    /// Case-insensitive: DNS names compare caselessly (RFC 4343) and real
    /// queries arrive with arbitrary case, so the input is normalized before
    /// any set lookup (previously a mixed-case query bypassed blocking).
    /// The allowlist is honored at every label level while walking up.
    pub fn is_blocked(&self, domain: &str) -> bool {
        if !self.enabled || self.is_paused() {
            return false;
        }
        // Only allocate when the input actually contains uppercase; the
        // common all-lowercase query stays allocation-free.
        let lowered;
        let mut d: &str = if domain.bytes().any(|b| b.is_ascii_uppercase()) {
            lowered = domain.to_ascii_lowercase();
            &lowered
        } else {
            domain
        };
        if self.allowlist.contains(d) {
            return false;
        }
        if self.domains.contains(d) {
            return true;
        }
        // Walk up: ads.tracker.example.com → tracker.example.com → example.com
        while let Some(dot) = d.find('.') {
            d = &d[dot + 1..];
            if self.allowlist.contains(d) {
                return false;
            }
            if self.domains.contains(d) {
                return true;
            }
        }
        false
    }

    /// Atomically swap in a new domain set. Build the set outside the lock,
    /// then call this to swap — keeps lock hold time sub-microsecond.
    pub fn swap_domains(&mut self, domains: HashSet<String>, sources: Vec<String>) {
        self.domains = domains;
        self.list_sources = sources;
        self.last_refresh = Some(Instant::now());
    }

    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Suspend blocking for `seconds` from now (replaces any earlier pause).
    pub fn pause(&mut self, seconds: u64) {
        self.paused_until = Some(Instant::now() + std::time::Duration::from_secs(seconds));
    }

    pub fn is_paused(&self) -> bool {
        self.paused_until
            .map(|until| Instant::now() < until)
            .unwrap_or(false)
    }

    pub fn add_to_allowlist(&mut self, domain: &str) {
        self.allowlist.insert(domain.to_lowercase());
    }

    /// Returns true when the domain was actually allowlisted.
    pub fn remove_from_allowlist(&mut self, domain: &str) -> bool {
        self.allowlist.remove(&domain.to_lowercase())
    }

    /// Allowlisted domains in arbitrary (hash) order.
    pub fn allowlist(&self) -> Vec<String> {
        self.allowlist.iter().cloned().collect()
    }

    pub fn stats(&self) -> BlocklistStats {
        BlocklistStats {
            enabled: self.is_enabled(),
            paused: self.is_paused(),
            domains_loaded: self.domains.len(),
            allowlist_size: self.allowlist.len(),
            list_sources: self.list_sources.clone(),
            last_refresh_secs_ago: self.last_refresh.map(|t| t.elapsed().as_secs()),
        }
    }
}
/// Parse blocklist text (hosts-file, AdBlock, or plain-domain format) into a
/// lowercase domain set.
///
/// Handles:
/// - comment lines starting with `#` or `!`, and trailing ` # …` comments
/// - hosts format: `0.0.0.0 domain`, `127.0.0.1 domain`, `:: domain`
/// - AdBlock filters: `||domain^`, `*.domain`, `$option` suffixes
/// - plain domains, one per line
pub fn parse_blocklist(text: &str) -> HashSet<String> {
    let mut domains = HashSet::new();
    for raw in text.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') || line.starts_with('!') {
            continue;
        }
        // Strip trailing comments so `example.com # note` still parses
        // (previously such lines were silently skipped). Only a '#' preceded
        // by whitespace counts, which leaves AdBlock `##` rules untouched.
        let line = match (line.find(" #"), line.find("\t#")) {
            (Some(a), Some(b)) => line[..a.min(b)].trim_end(),
            (Some(a), None) => line[..a].trim_end(),
            (None, Some(b)) => line[..b].trim_end(),
            (None, None) => line,
        };
        // Handle hosts-file format: "0.0.0.0 domain" or "127.0.0.1 domain" (space or tab)
        let domain = if line.starts_with("0.0.0.0")
            || line.starts_with("127.0.0.1")
            || line.starts_with("::")
        {
            line.split_whitespace()
                .nth(1)
                .unwrap_or("")
                .trim_end_matches('.')
        } else if line.contains(' ') || line.contains('\t') {
            // Unrecognized multi-field line — skip rather than guess.
            continue;
        } else {
            // Plain domain or adblock filter syntax
            let d = line.trim_start_matches("*.").trim_start_matches("||");
            let d = d.split('$').next().unwrap_or(d); // strip adblock $options
            d.trim_end_matches('^').trim_end_matches('.')
        };
        let domain = domain.to_lowercase();
        if !domain.is_empty()
            && domain.contains('.')
            && domain != "localhost"
            && domain != "localhost.localdomain"
        {
            domains.insert(domain);
        }
    }
    domains
}
/// Download each blocklist URL, returning `(url, body)` pairs for the ones
/// that succeeded. Failures are logged and skipped so a partial refresh still
/// yields data.
pub async fn download_blocklists(lists: &[String]) -> Vec<(String, String)> {
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(30))
        .build()
        // NOTE(review): if the builder fails, the Default client has no 30s
        // timeout configured — probably acceptable, but worth confirming.
        .unwrap_or_default();
    let mut results = Vec::new();
    // Sequential downloads — the list count is small.
    for url in lists {
        match client.get(url).send().await {
            Ok(resp) => match resp.text().await {
                Ok(text) => {
                    info!("downloaded blocklist: {} ({} bytes)", url, text.len());
                    results.push((url.clone(), text));
                }
                Err(e) => warn!("failed to read blocklist body {}: {}", url, e),
            },
            Err(e) => warn!("failed to download blocklist {}: {}", url, e),
        }
    }
    results
}

View File

@@ -11,8 +11,11 @@ struct CacheEntry {
ttl: Duration,
}
/// DNS cache using a two-level map (domain -> query_type -> entry) so that
/// lookups can borrow `&str` instead of allocating a `String` key.
pub struct DnsCache {
entries: HashMap<(String, QueryType), CacheEntry>,
entries: HashMap<String, HashMap<QueryType, CacheEntry>>,
entry_count: usize,
max_entries: usize,
min_ttl: u32,
max_ttl: u32,
@@ -23,6 +26,7 @@ impl DnsCache {
pub fn new(max_entries: usize, min_ttl: u32, max_ttl: u32) -> Self {
DnsCache {
entries: HashMap::new(),
entry_count: 0,
max_entries,
min_ttl,
max_ttl,
@@ -33,17 +37,22 @@ impl DnsCache {
pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
self.query_count += 1;
// Periodic eviction every 1000 queries
if self.query_count.is_multiple_of(1000) {
self.evict_expired();
}
let key = (domain.to_string(), qtype);
let entry = self.entries.get(&key)?;
let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?;
let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl {
self.entries.remove(&key);
// Expired: remove this entry
let type_map = self.entries.get_mut(domain).unwrap();
type_map.remove(&qtype);
self.entry_count -= 1;
if type_map.is_empty() {
self.entries.remove(domain);
}
return None;
}
@@ -59,10 +68,9 @@ impl DnsCache {
}
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
if self.entries.len() >= self.max_entries {
if self.entry_count >= self.max_entries {
self.evict_expired();
// If still full after eviction, skip insertion
if self.entries.len() >= self.max_entries {
if self.entry_count >= self.max_entries {
return;
}
}
@@ -71,9 +79,18 @@ impl DnsCache {
.unwrap_or(self.min_ttl)
.clamp(self.min_ttl, self.max_ttl);
let key = (domain.to_string(), qtype);
self.entries.insert(
key,
let type_map = if let Some(existing) = self.entries.get_mut(domain) {
existing
} else {
self.entries.entry(domain.to_string()).or_default()
};
if !type_map.contains_key(&qtype) {
self.entry_count += 1;
}
type_map.insert(
qtype,
CacheEntry {
packet: packet.clone(),
inserted_at: Instant::now(),
@@ -82,10 +99,64 @@ impl DnsCache {
);
}
fn evict_expired(&mut self) {
self.entries
.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
pub fn len(&self) -> usize {
self.entry_count
}
pub fn is_empty(&self) -> bool {
self.entry_count == 0
}
pub fn max_entries(&self) -> usize {
self.max_entries
}
pub fn clear(&mut self) {
self.entries.clear();
self.entry_count = 0;
}
pub fn remove(&mut self, domain: &str) {
let domain_lower = domain.to_lowercase();
if let Some(type_map) = self.entries.remove(&domain_lower) {
self.entry_count -= type_map.len();
}
}
pub fn list(&self) -> Vec<CacheInfo> {
let mut result = Vec::new();
for (domain, type_map) in &self.entries {
for (qtype, entry) in type_map {
let elapsed = entry.inserted_at.elapsed();
if elapsed < entry.ttl {
let remaining = (entry.ttl - elapsed).as_secs() as u32;
result.push(CacheInfo {
domain: domain.clone(),
query_type: *qtype,
ttl_remaining: remaining,
});
}
}
}
result
}
fn evict_expired(&mut self) {
let mut count = 0;
self.entries.retain(|_, type_map| {
let before = type_map.len();
type_map.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
count += before - type_map.len();
!type_map.is_empty()
});
self.entry_count -= count;
}
}
pub struct CacheInfo {
pub domain: String,
pub query_type: QueryType,
pub ttl_remaining: u32,
}
fn extract_min_ttl(records: &[DnsRecord]) -> Option<u32> {

View File

@@ -18,6 +18,8 @@ pub struct Config {
#[serde(default)]
pub cache: CacheConfig,
#[serde(default)]
pub blocking: BlockingConfig,
#[serde(default)]
pub zones: Vec<ZoneRecord>,
}
@@ -25,12 +27,15 @@ pub struct Config {
pub struct ServerConfig {
#[serde(default = "default_bind_addr")]
pub bind_addr: String,
#[serde(default = "default_api_port")]
pub api_port: u16,
}
impl Default for ServerConfig {
fn default() -> Self {
ServerConfig {
bind_addr: default_bind_addr(),
api_port: default_api_port(),
}
}
}
@@ -39,6 +44,10 @@ fn default_bind_addr() -> String {
"0.0.0.0:53".to_string()
}
fn default_api_port() -> u16 {
5380
}
#[derive(Deserialize)]
pub struct UpstreamConfig {
#[serde(default = "default_upstream_addr")]
@@ -108,6 +117,41 @@ pub struct ZoneRecord {
pub ttl: u32,
}
/// `[blocking]` section of the config file.
#[derive(Deserialize)]
pub struct BlockingConfig {
    /// Master switch for DNS-level blocking (default: true).
    #[serde(default = "default_blocking_enabled")]
    pub enabled: bool,
    /// Blocklist URLs to download; defaults to the Hagezi Pro hosts list.
    #[serde(default = "default_blocklists")]
    pub lists: Vec<String>,
    /// How often to re-download the lists, in hours (default: 24).
    #[serde(default = "default_refresh_hours")]
    pub refresh_hours: u64,
    /// Domains exempted from blocking at startup.
    #[serde(default)]
    pub allowlist: Vec<String>,
}

impl Default for BlockingConfig {
    /// Mirrors the per-field serde defaults so a missing `[blocking]` table
    /// behaves the same as an empty one.
    fn default() -> Self {
        BlockingConfig {
            enabled: default_blocking_enabled(),
            lists: default_blocklists(),
            refresh_hours: default_refresh_hours(),
            allowlist: Vec::new(),
        }
    }
}

// Serde default helpers for `BlockingConfig`.
fn default_blocking_enabled() -> bool {
    true
}

fn default_blocklists() -> Vec<String> {
    vec!["https://cdn.jsdelivr.net/gh/hagezi/dns-blocklists@latest/hosts/pro.txt".to_string()]
}

fn default_refresh_hours() -> u64 {
    24
}

/// Default TTL (seconds) for local zone records.
fn default_zone_ttl() -> u32 {
    300
}
@@ -118,6 +162,7 @@ pub fn load_config(path: &str) -> Result<Config> {
server: ServerConfig::default(),
upstream: UpstreamConfig::default(),
cache: CacheConfig::default(),
blocking: BlockingConfig::default(),
zones: Vec::new(),
});
}
@@ -126,10 +171,10 @@ pub fn load_config(path: &str) -> Result<Config> {
Ok(config)
}
pub fn build_zone_map(
zones: &[ZoneRecord],
) -> Result<HashMap<(String, QueryType), Vec<DnsRecord>>> {
let mut map: HashMap<(String, QueryType), Vec<DnsRecord>> = HashMap::new();
pub type ZoneMap = HashMap<String, HashMap<QueryType, Vec<DnsRecord>>>;
pub fn build_zone_map(zones: &[ZoneRecord]) -> Result<ZoneMap> {
let mut map: ZoneMap = HashMap::new();
for zone in zones {
let domain = zone.domain.to_lowercase();
@@ -203,7 +248,11 @@ pub fn build_zone_map(
}
};
map.entry((domain, qtype)).or_default().push(record);
map.entry(domain)
.or_default()
.entry(qtype)
.or_default()
.push(record);
}
Ok(map)

155
src/ctx.rs Normal file
View File

@@ -0,0 +1,155 @@
use std::net::SocketAddr;
use std::sync::Mutex;
use std::time::{Duration, Instant, SystemTime};
use log::{debug, error, info, warn};
use tokio::net::UdpSocket;
use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
use crate::cache::DnsCache;
use crate::config::ZoneMap;
use crate::forward::forward_query;
use crate::header::ResultCode;
use crate::override_store::OverrideStore;
use crate::packet::DnsPacket;
use crate::query_log::{QueryLog, QueryLogEntry};
use crate::record::DnsRecord;
use crate::stats::{QueryPath, ServerStats};
use crate::system_dns::ForwardingRule;
/// Shared state for the whole server: one instance lives behind an `Arc` and
/// is used by both the UDP query loop and the HTTP API.
///
/// Locking: each `Mutex` guards one subsystem; guards are kept to narrow
/// scopes and are never held across `.await` points (see `handle_query`).
pub struct ServerCtx {
    /// UDP socket the DNS server answers on.
    pub socket: UdpSocket,
    /// Static local zone records, domain -> qtype -> records (read-only).
    pub zone_map: ZoneMap,
    pub cache: Mutex<DnsCache>,
    pub stats: Mutex<ServerStats>,
    pub overrides: Mutex<OverrideStore>,
    pub blocklist: Mutex<BlocklistStore>,
    pub query_log: Mutex<QueryLog>,
    /// Conditional forwarding rules auto-discovered from the OS resolver
    /// config (read-only after startup).
    pub forwarding_rules: Vec<ForwardingRule>,
    /// Default upstream resolver when no forwarding rule matches.
    pub upstream: SocketAddr,
    /// Per-upstream-query timeout.
    pub timeout: Duration,
}
/// Handle one inbound DNS query end to end.
///
/// Resolution pipeline: overrides → blocklist → local zones → cache →
/// upstream (with conditional forwarding). Every `Mutex` guard is scoped so
/// nothing is held across an `.await`.
pub async fn handle_query(
    mut buffer: BytePacketBuffer,
    src_addr: SocketAddr,
    ctx: &ServerCtx,
) -> crate::Result<()> {
    let start = Instant::now();
    let query = match DnsPacket::from_buffer(&mut buffer) {
        Ok(packet) => packet,
        Err(e) => {
            // Malformed packet: log and drop without answering.
            warn!("{} | PARSE ERROR | {}", src_addr, e);
            return Ok(());
        }
    };
    let (qname, qtype) = match query.questions.first() {
        Some(q) => (q.name.clone(), q.qtype),
        None => return Ok(()),
    };
    // DNS names are case-insensitive (RFC 4343) but queries arrive with
    // arbitrary case, while every lookup table here is keyed by lowercase.
    // Normalize once for lookups; the original-case `qname` is kept for
    // answer records and logging.
    let lookup_name = qname.to_lowercase();
    let (response, path) = {
        // 1) Explicit overrides win over everything.
        let override_record = ctx.overrides.lock().unwrap().lookup(&lookup_name);
        if let Some(record) = override_record {
            let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
            resp.answers.push(record);
            (resp, QueryPath::Overridden)
        } else if ctx.blocklist.lock().unwrap().is_blocked(&lookup_name) {
            use crate::question::QueryType;
            // 2) Blocked: answer NOERROR with the unspecified address
            // (0.0.0.0 / ::) so clients don't retry other resolvers.
            let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
            match qtype {
                QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
                    domain: qname.clone(),
                    addr: std::net::Ipv6Addr::UNSPECIFIED,
                    ttl: 60,
                }),
                _ => resp.answers.push(DnsRecord::A {
                    domain: qname.clone(),
                    addr: std::net::Ipv4Addr::UNSPECIFIED,
                    ttl: 60,
                }),
            }
            (resp, QueryPath::Blocked)
        } else if let Some(records) = ctx
            .zone_map
            .get(lookup_name.as_str())
            .and_then(|m| m.get(&qtype))
        {
            // 3) Authoritative local zone data.
            let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
            resp.answers = records.clone();
            (resp, QueryPath::Local)
        } else {
            // 4) Cache, then 5) upstream.
            let cached = ctx.cache.lock().unwrap().lookup(&lookup_name, qtype);
            if let Some(cached) = cached {
                let mut resp = cached;
                // Re-stamp the cached packet with this query's transaction id.
                resp.header.id = query.header.id;
                (resp, QueryPath::Cached)
            } else {
                // Conditional forwarding (split-DNS) with fallback to the
                // default upstream.
                let upstream =
                    crate::system_dns::match_forwarding_rule(&lookup_name, &ctx.forwarding_rules)
                        .unwrap_or(ctx.upstream);
                match forward_query(&query, upstream, ctx.timeout).await {
                    Ok(resp) => {
                        ctx.cache.lock().unwrap().insert(&lookup_name, qtype, &resp);
                        (resp, QueryPath::Forwarded)
                    }
                    Err(e) => {
                        error!(
                            "{} | {:?} {} | UPSTREAM ERROR | {}",
                            src_addr, qtype, qname, e
                        );
                        (
                            DnsPacket::response_from(&query, ResultCode::SERVFAIL),
                            QueryPath::UpstreamError,
                        )
                    }
                }
            }
        }
    };
    let elapsed = start.elapsed();
    info!(
        "{} | {:?} {} | {} | {} | {}ms",
        src_addr,
        qtype,
        qname,
        path.as_str(),
        response.header.rescode.as_str(),
        elapsed.as_millis(),
    );
    debug!(
        "response: {} answers, {} authorities, {} resources",
        response.answers.len(),
        response.authorities.len(),
        response.resources.len(),
    );
    let mut resp_buffer = BytePacketBuffer::new();
    response.write(&mut resp_buffer)?;
    ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
    // Record stats and query log after the reply is on the wire, so
    // bookkeeping never delays the response.
    {
        let mut s = ctx.stats.lock().unwrap();
        let total = s.record(path);
        if total.is_multiple_of(1000) {
            s.log_summary();
        }
    }
    ctx.query_log.lock().unwrap().push(QueryLogEntry {
        timestamp: SystemTime::now(),
        src_addr,
        domain: qname,
        query_type: qtype,
        path,
        rescode: response.header.rescode,
        latency_us: elapsed.as_micros() as u64,
    });
    Ok(())
}

View File

@@ -1,12 +1,18 @@
// Crate module surface: DNS engine, blocking, HTTP API, and supporting stores.
pub mod api;
pub mod blocklist;
pub mod buffer;
pub mod cache;
pub mod config;
pub mod ctx;
pub mod forward;
pub mod header;
pub mod override_store;
pub mod packet;
pub mod query_log;
pub mod question;
pub mod record;
pub mod stats;
pub mod system_dns;

/// Crate-wide boxed error type; `Send + Sync` so it crosses task boundaries.
pub type Error = Box<dyn std::error::Error + Send + Sync>;
/// Crate-wide result alias over [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,47 +1,48 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::time::Duration;
use log::{debug, error, info, warn};
use log::{error, info};
use tokio::net::UdpSocket;
use dns_fun::buffer::BytePacketBuffer;
use dns_fun::cache::DnsCache;
use dns_fun::config::{build_zone_map, load_config};
use dns_fun::forward::forward_query;
use dns_fun::header::ResultCode;
use dns_fun::packet::DnsPacket;
use dns_fun::question::QueryType;
use dns_fun::record::DnsRecord;
use dns_fun::stats::{QueryPath, ServerStats};
struct ServerCtx {
socket: Arc<UdpSocket>,
zone_map: HashMap<(String, QueryType), Vec<DnsRecord>>,
cache: Mutex<DnsCache>,
stats: Mutex<ServerStats>,
upstream: SocketAddr,
timeout: Duration,
}
use numa::blocklist::{download_blocklists, parse_blocklist, BlocklistStore};
use numa::buffer::BytePacketBuffer;
use numa::cache::DnsCache;
use numa::config::{build_zone_map, load_config};
use numa::ctx::{handle_query, ServerCtx};
use numa::override_store::OverrideStore;
use numa::query_log::QueryLog;
use numa::stats::ServerStats;
use numa::system_dns::discover_forwarding_rules;
#[tokio::main]
async fn main() -> dns_fun::Result<()> {
async fn main() -> numa::Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
.format_timestamp_millis()
.init();
let config_path = std::env::args()
.nth(1)
.unwrap_or_else(|| "dns_fun.toml".to_string());
.unwrap_or_else(|| "numa.toml".to_string());
let config = load_config(&config_path)?;
let upstream: SocketAddr =
format!("{}:{}", config.upstream.address, config.upstream.port).parse()?;
let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?);
let api_port = config.server.api_port;
let mut blocklist = BlocklistStore::new();
for domain in &config.blocking.allowlist {
blocklist.add_to_allowlist(domain);
}
if !config.blocking.enabled {
blocklist.set_enabled(false);
}
// Auto-discover conditional forwarding rules from OS (Tailscale, VPN, etc.)
let forwarding_rules = discover_forwarding_rules();
let ctx = Arc::new(ServerCtx {
socket: Arc::clone(&socket),
socket: UdpSocket::bind(&config.server.bind_addr).await?,
zone_map: build_zone_map(&config.zones)?,
cache: Mutex::new(DnsCache::new(
config.cache.max_entries,
@@ -49,21 +50,72 @@ async fn main() -> dns_fun::Result<()> {
config.cache.max_ttl,
)),
stats: Mutex::new(ServerStats::new()),
overrides: Mutex::new(OverrideStore::new()),
blocklist: Mutex::new(blocklist),
query_log: Mutex::new(QueryLog::new(1000)),
forwarding_rules,
upstream,
timeout: Duration::from_millis(config.upstream.timeout_ms),
});
let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum();
eprintln!("\n\x1b[38;2;192;98;58m ╔══════════════════════════════════════════╗\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[1;38;2;192;98;58mNUMA\x1b[0m \x1b[3;38;2;163;152;136mDNS that governs itself\x1b[0m \x1b[38;2;192;98;58m║\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ╠══════════════════════════════════════════╣\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mDNS\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", config.server.bind_addr);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mAPI\x1b[0m http://localhost:{:<16}\x1b[38;2;192;98;58m║\x1b[0m", api_port);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mDashboard\x1b[0m http://localhost:{:<16}\x1b[38;2;192;98;58m║\x1b[0m", api_port);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mUpstream\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", upstream);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mZones\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", format!("{} records", zone_count));
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mCache\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", format!("max {} entries", config.cache.max_entries));
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mBlocking\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m",
if config.blocking.enabled { format!("{} lists", config.blocking.lists.len()) } else { "disabled".to_string() });
if !ctx.forwarding_rules.is_empty() {
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mRouting\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m",
format!("{} conditional rules", ctx.forwarding_rules.len()));
}
eprintln!("\x1b[38;2;192;98;58m ╚══════════════════════════════════════════╝\x1b[0m\n");
info!(
"dns_fun starting on {}, upstream {}, {} zone records, cache max {}",
config.server.bind_addr,
upstream,
ctx.zone_map.len(),
config.cache.max_entries,
"numa listening on {}, upstream {}, {} zone records, cache max {}, API on port {}",
config.server.bind_addr, upstream, zone_count, config.cache.max_entries, api_port,
);
// Download blocklists on startup
let blocklist_lists = config.blocking.lists.clone();
let refresh_hours = config.blocking.refresh_hours;
if config.blocking.enabled && !blocklist_lists.is_empty() {
let bl_ctx = Arc::clone(&ctx);
let bl_lists = blocklist_lists.clone();
tokio::spawn(async move {
load_blocklists(&bl_ctx, &bl_lists).await;
// Periodic refresh
let mut interval = tokio::time::interval(Duration::from_secs(refresh_hours * 3600));
interval.tick().await; // skip immediate tick
loop {
interval.tick().await;
info!("refreshing blocklists...");
load_blocklists(&bl_ctx, &bl_lists).await;
}
});
}
// Spawn HTTP API server
let api_ctx = Arc::clone(&ctx);
let api_addr: SocketAddr = format!("0.0.0.0:{}", api_port).parse()?;
tokio::spawn(async move {
let app = numa::api::router(api_ctx);
let listener = tokio::net::TcpListener::bind(api_addr).await.unwrap();
info!("HTTP API listening on {}", api_addr);
axum::serve(listener, app).await.unwrap();
});
// UDP DNS listener
#[allow(clippy::infinite_loop)]
loop {
let mut buffer = BytePacketBuffer::new();
let (_, src_addr) = socket.recv_from(&mut buffer.buf).await?;
let (_, src_addr) = ctx.socket.recv_from(&mut buffer.buf).await?;
let ctx = Arc::clone(&ctx);
tokio::spawn(async move {
@@ -74,87 +126,28 @@ async fn main() -> dns_fun::Result<()> {
}
}
async fn handle_query(
mut buffer: BytePacketBuffer,
src_addr: SocketAddr,
ctx: &ServerCtx,
) -> dns_fun::Result<()> {
let start = Instant::now();
async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
let downloaded = download_blocklists(lists).await;
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(packet) => packet,
Err(e) => {
warn!("{} | PARSE ERROR | {}", src_addr, e);
return Ok(());
}
};
let (qname, qtype) = match query.questions.first() {
Some(q) => (q.name.clone(), q.qtype),
None => return Ok(()),
};
// Pipeline: local zones -> cache -> upstream
// Each lock is scoped to avoid holding MutexGuard across await points.
let (response, path) = if let Some(records) = ctx.zone_map.get(&(qname.to_lowercase(), qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers = records.clone();
(resp, QueryPath::Local)
} else {
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype);
if let Some(cached) = cached {
let mut resp = cached;
resp.header.id = query.header.id;
(resp, QueryPath::Cached)
} else {
match forward_query(&query, ctx.upstream, ctx.timeout).await {
Ok(resp) => {
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded)
}
Err(e) => {
error!(
"{} | {:?} {} | UPSTREAM ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
}
};
let elapsed = start.elapsed();
info!(
"{} | {:?} {} | {} | {} | {}ms",
src_addr,
qtype,
qname,
path.as_str(),
response.header.rescode.as_str(),
elapsed.as_millis(),
);
debug!(
"response: {} answers, {} authorities, {} resources",
response.answers.len(),
response.authorities.len(),
response.resources.len(),
);
let mut resp_buffer = BytePacketBuffer::new();
response.write(&mut resp_buffer)?;
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
// Record stats and log summary every 1000 queries (single lock acquisition)
let mut s = ctx.stats.lock().unwrap();
let total = s.record(path);
if total.is_multiple_of(1000) {
s.log_summary();
// Parse outside the lock to avoid blocking DNS queries during parse (~100ms)
let mut all_domains = std::collections::HashSet::new();
let mut sources = Vec::new();
for (source, text) in &downloaded {
let domains = parse_blocklist(text);
info!("blocklist: {} domains from {}", domains.len(), source);
all_domains.extend(domains);
sources.push(source.clone());
}
let total = all_domains.len();
Ok(())
// Swap under lock — sub-microsecond
ctx.blocklist
.lock()
.unwrap()
.swap_domains(all_domains, sources);
info!(
"blocking enabled: {} unique domains from {} lists",
total,
downloaded.len()
);
}

153
src/override_store.rs Normal file
View File

@@ -0,0 +1,153 @@
use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::time::Instant;
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::Result;
/// A single DNS override: maps a domain to a synthesized record, optionally time-limited.
pub struct OverrideEntry {
    pub domain: String,             // lowercased domain this override answers for
    pub target: String,             // raw target string as supplied by the caller (IP literal or hostname)
    pub record: DnsRecord,          // pre-built record returned on lookup
    pub query_type: QueryType,      // A / AAAA / CNAME, derived from the format of `target`
    pub ttl: u32,                   // TTL baked into `record`
    pub created_at: Instant,        // creation time; basis for expiry
    pub duration_secs: Option<u64>, // lifetime in seconds; `None` = permanent
}
impl OverrideEntry {
    /// Absolute instant at which this entry expires, or `None` for a
    /// permanent override.
    pub fn expires_at(&self) -> Option<Instant> {
        let secs = self.duration_secs?;
        Some(self.created_at + std::time::Duration::from_secs(secs))
    }

    /// True once a time-limited override has reached its expiry instant.
    /// Permanent overrides never expire.
    pub fn is_expired(&self) -> bool {
        match self.expires_at() {
            Some(exp) => Instant::now() >= exp,
            None => false,
        }
    }

    /// Seconds left before expiry (0 if already expired); `None` if permanent.
    pub fn remaining_secs(&self) -> Option<u64> {
        let exp = self.expires_at()?;
        Some(exp.saturating_duration_since(Instant::now()).as_secs())
    }
}
/// In-memory store of DNS overrides, keyed by lowercased domain.
pub struct OverrideStore {
    entries: HashMap<String, OverrideEntry>, // expired entries are pruned lazily by `lookup`
}
impl Default for OverrideStore {
fn default() -> Self {
Self::new()
}
}
impl OverrideStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        OverrideStore {
            entries: HashMap::new(),
        }
    }

    /// Adds (or replaces) an override mapping `domain` to `target`.
    ///
    /// `target` is interpreted as an IPv4 literal, an IPv6 literal, or a
    /// CNAME host — in that order; the derived record type is returned.
    /// `duration_secs = None` makes the override permanent.
    pub fn insert(
        &mut self,
        domain: &str,
        target: &str,
        ttl: u32,
        duration_secs: Option<u64>,
    ) -> Result<QueryType> {
        let key = domain.to_lowercase();
        let (qtype, record) = parse_target(&key, target, ttl)?;
        let entry = OverrideEntry {
            domain: key.clone(),
            target: target.to_string(),
            record,
            query_type: qtype,
            ttl,
            created_at: Instant::now(),
            duration_secs,
        };
        self.entries.insert(key, entry);
        Ok(qtype)
    }

    /// Hot path: assumes `domain` is already lowercased (the parser does this).
    /// Expired entries are removed on first sight.
    pub fn lookup(&mut self, domain: &str) -> Option<DnsRecord> {
        let entry = self.entries.get(domain)?;
        if !entry.is_expired() {
            return Some(entry.record.clone());
        }
        // Lazily drop the expired entry so it stops occupying the map.
        self.entries.remove(domain);
        None
    }

    /// Borrows the entry for `domain` (any casing); expired entries read as absent.
    pub fn get(&self, domain: &str) -> Option<&OverrideEntry> {
        self.entries
            .get(&domain.to_lowercase())
            .filter(|e| !e.is_expired())
    }

    /// Removes the override for `domain`; returns whether one existed.
    pub fn remove(&mut self, domain: &str) -> bool {
        let key = domain.to_lowercase();
        self.entries.remove(&key).is_some()
    }

    /// All non-expired entries, in arbitrary map order.
    pub fn list(&self) -> Vec<&OverrideEntry> {
        let mut live = Vec::new();
        for entry in self.entries.values() {
            if !entry.is_expired() {
                live.push(entry);
            }
        }
        live
    }

    /// Drops every entry, expired or not.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Number of non-expired entries.
    pub fn active_count(&self) -> usize {
        self.list().len()
    }
}
/// Classifies `target` and builds the matching record:
/// IPv4 literal -> A, IPv6 literal -> AAAA, anything else -> CNAME.
/// Never fails today; the `Result` leaves room for future validation.
fn parse_target(domain: &str, target: &str, ttl: u32) -> Result<(QueryType, DnsRecord)> {
    let owned = domain.to_string();
    if let Ok(addr) = target.parse::<Ipv4Addr>() {
        Ok((
            QueryType::A,
            DnsRecord::A {
                domain: owned,
                addr,
                ttl,
            },
        ))
    } else if let Ok(addr) = target.parse::<Ipv6Addr>() {
        Ok((
            QueryType::AAAA,
            DnsRecord::AAAA {
                domain: owned,
                addr,
                ttl,
            },
        ))
    } else {
        Ok((
            QueryType::CNAME,
            DnsRecord::CNAME {
                domain: owned,
                host: target.to_string(),
                ttl,
            },
        ))
    }
}

77
src/query_log.rs Normal file
View File

@@ -0,0 +1,77 @@
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::time::SystemTime;
use crate::header::ResultCode;
use crate::question::QueryType;
use crate::stats::QueryPath;
/// One handled DNS query, as recorded in the in-memory query log.
pub struct QueryLogEntry {
    pub timestamp: SystemTime,  // wall-clock time the query was handled
    pub src_addr: SocketAddr,   // client that sent the query
    pub domain: String,         // queried name
    pub query_type: QueryType,  // requested record type
    pub path: QueryPath,        // how it was answered (local/cached/forwarded/...)
    pub rescode: ResultCode,    // DNS response code returned to the client
    pub latency_us: u64,        // handling latency in microseconds
}
/// Fixed-capacity ring of recent queries; the oldest entry is evicted first.
pub struct QueryLog {
    entries: VecDeque<QueryLogEntry>,
    capacity: usize, // maximum number of retained entries
}
impl QueryLog {
pub fn new(capacity: usize) -> Self {
QueryLog {
entries: VecDeque::with_capacity(capacity),
capacity,
}
}
pub fn push(&mut self, entry: QueryLogEntry) {
if self.entries.len() >= self.capacity {
self.entries.pop_front();
}
self.entries.push_back(entry);
}
pub fn query(&self, filter: &QueryLogFilter) -> Vec<&QueryLogEntry> {
self.entries
.iter()
.rev()
.filter(|e| {
if let Some(ref domain) = filter.domain {
if !e.domain.contains(domain.as_str()) {
return false;
}
}
if let Some(qtype) = filter.query_type {
if e.query_type != qtype {
return false;
}
}
if let Some(path) = filter.path {
if e.path != path {
return false;
}
}
if let Some(since) = filter.since {
if e.timestamp < since {
return false;
}
}
true
})
.take(filter.limit.unwrap_or(50))
.collect()
}
}
/// Filter for querying the log; every `None` field matches all entries.
pub struct QueryLogFilter {
    pub domain: Option<String>,       // substring match on the queried domain
    pub query_type: Option<QueryType>,
    pub path: Option<QueryPath>,
    pub since: Option<SystemTime>,    // keep entries at or after this time
    pub limit: Option<usize>,         // max results; defaults to 50 when `None`
}

View File

@@ -33,6 +33,33 @@ impl QueryType {
_ => QueryType::UNKNOWN(num),
}
}
pub fn as_str(&self) -> &'static str {
match self {
QueryType::A => "A",
QueryType::NS => "NS",
QueryType::CNAME => "CNAME",
QueryType::MX => "MX",
QueryType::AAAA => "AAAA",
QueryType::UNKNOWN(_) => "UNKNOWN",
}
}
pub fn parse_str(s: &str) -> Option<QueryType> {
if s.eq_ignore_ascii_case("A") {
Some(QueryType::A)
} else if s.eq_ignore_ascii_case("NS") {
Some(QueryType::NS)
} else if s.eq_ignore_ascii_case("CNAME") {
Some(QueryType::CNAME)
} else if s.eq_ignore_ascii_case("MX") {
Some(QueryType::MX)
} else if s.eq_ignore_ascii_case("AAAA") {
Some(QueryType::AAAA)
} else {
None
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@@ -240,7 +240,7 @@ impl DnsRecord {
}
}
DnsRecord::UNKNOWN { .. } => {
println!("Skipping record: {:?}", self);
log::debug!("Skipping record: {:?}", self);
}
}

View File

@@ -6,15 +6,18 @@ pub struct ServerStats {
queries_cached: u64,
queries_blocked: u64,
queries_local: u64,
queries_overridden: u64,
upstream_errors: u64,
started_at: Instant,
}
/// How a query was resolved; recorded per query for stats and the query log.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum QueryPath {
    Local,         // answered from configured zone records
    Cached,        // answered from the response cache
    Forwarded,     // resolved by the upstream server
    Blocked,       // denied by the blocklist
    Overridden,    // answered from the override store
    UpstreamError, // upstream lookup failed (reported as SERVFAIL)
}
@@ -25,9 +28,28 @@ impl QueryPath {
QueryPath::Cached => "CACHED",
QueryPath::Forwarded => "FORWARD",
QueryPath::Blocked => "BLOCKED",
QueryPath::Overridden => "OVERRIDE",
QueryPath::UpstreamError => "SERVFAIL",
}
}
pub fn parse_str(s: &str) -> Option<QueryPath> {
if s.eq_ignore_ascii_case("LOCAL") {
Some(QueryPath::Local)
} else if s.eq_ignore_ascii_case("CACHED") {
Some(QueryPath::Cached)
} else if s.eq_ignore_ascii_case("FORWARD") {
Some(QueryPath::Forwarded)
} else if s.eq_ignore_ascii_case("BLOCKED") {
Some(QueryPath::Blocked)
} else if s.eq_ignore_ascii_case("OVERRIDE") {
Some(QueryPath::Overridden)
} else if s.eq_ignore_ascii_case("SERVFAIL") {
Some(QueryPath::UpstreamError)
} else {
None
}
}
}
impl Default for ServerStats {
@@ -44,6 +66,7 @@ impl ServerStats {
queries_cached: 0,
queries_blocked: 0,
queries_local: 0,
queries_overridden: 0,
upstream_errors: 0,
started_at: Instant::now(),
}
@@ -56,6 +79,7 @@ impl ServerStats {
QueryPath::Cached => self.queries_cached += 1,
QueryPath::Forwarded => self.queries_forwarded += 1,
QueryPath::Blocked => self.queries_blocked += 1,
QueryPath::Overridden => self.queries_overridden += 1,
QueryPath::UpstreamError => self.upstream_errors += 1,
}
self.queries_total
@@ -65,6 +89,23 @@ impl ServerStats {
self.queries_total
}
/// Whole seconds elapsed since this stats object was created (server start).
pub fn uptime_secs(&self) -> u64 {
    self.started_at.elapsed().as_secs()
}
/// Copies the current counters into a plain `StatsSnapshot` value,
/// so callers can release the stats lock before serializing/reporting.
pub fn snapshot(&self) -> StatsSnapshot {
    StatsSnapshot {
        uptime_secs: self.uptime_secs(),
        total: self.queries_total,
        forwarded: self.queries_forwarded,
        cached: self.queries_cached,
        local: self.queries_local,
        overridden: self.queries_overridden,
        blocked: self.queries_blocked,
        errors: self.upstream_errors,
    }
}
pub fn log_summary(&self) {
let uptime = self.started_at.elapsed();
let hours = uptime.as_secs() / 3600;
@@ -72,14 +113,26 @@ impl ServerStats {
let secs = uptime.as_secs() % 60;
log::info!(
"STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | blocked {} | errors {}",
"STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | override {} | blocked {} | errors {}",
hours, mins, secs,
self.queries_total,
self.queries_forwarded,
self.queries_cached,
self.queries_local,
self.queries_overridden,
self.queries_blocked,
self.upstream_errors,
);
}
}
/// Plain-data copy of the server counters at a point in time
/// (produced by `ServerStats::snapshot`).
pub struct StatsSnapshot {
    pub uptime_secs: u64,
    pub total: u64,      // all queries handled
    pub forwarded: u64,  // resolved by the upstream server
    pub cached: u64,     // served from the response cache
    pub local: u64,      // served from configured zone records
    pub overridden: u64, // served from the override store
    pub blocked: u64,    // denied by the blocklist
    pub errors: u64,     // upstream failures (SERVFAIL)
}

144
src/system_dns.rs Normal file
View File

@@ -0,0 +1,144 @@
use std::net::SocketAddr;
use log::{debug, info, warn};
/// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`.
#[derive(Debug, Clone)]
pub struct ForwardingRule {
    pub suffix: String,       // domain suffix to match (no leading dot)
    dot_suffix: String,       // pre-computed ".suffix" for zero-alloc matching
    pub upstream: SocketAddr, // nameserver (port 53) that serves this suffix
}
/// Discover system DNS forwarding rules from the OS.
/// On macOS, parses `scutil --dns`. Returns rules sorted longest-suffix-first
/// so more specific matches take priority.
/// On every other platform this is a stub that logs and returns no rules.
pub fn discover_forwarding_rules() -> Vec<ForwardingRule> {
    #[cfg(target_os = "macos")]
    {
        discover_macos()
    }
    #[cfg(not(target_os = "macos"))]
    {
        info!("system DNS auto-discovery not implemented for this OS");
        Vec::new()
    }
}
/// Parses `scutil --dns` output into conditional forwarding rules.
///
/// Scans the non-scoped resolver blocks; a block contributes a rule only
/// when it has a usable `domain`, an IPv4 `nameserver[0]`, and the
/// `Supplemental` flag (which marks split-DNS entries, not the default
/// resolver). Returns rules sorted longest-suffix-first.
#[cfg(target_os = "macos")]
fn discover_macos() -> Vec<ForwardingRule> {
    let output = match std::process::Command::new("scutil").arg("--dns").output() {
        Ok(o) => o,
        Err(e) => {
            // scutil missing or unrunnable: degrade gracefully to no rules.
            warn!("failed to run scutil --dns: {}", e);
            return Vec::new();
        }
    };
    let text = String::from_utf8_lossy(&output.stdout);
    let mut rules = Vec::new();
    // Parse resolver blocks: look for blocks with both `domain` and `nameserver[0]`
    // that have the `Supplemental` flag (conditional forwarding, not default).
    // State accumulates per resolver block and is flushed on block boundaries.
    let mut current_domain: Option<String> = None;
    let mut current_nameserver: Option<String> = None;
    let mut is_supplemental = false;
    for line in text.lines() {
        let line = line.trim();
        if line.starts_with("resolver #") {
            // Emit previous block if valid (the `.take()` calls always clear state)
            if let (Some(domain), Some(ns), true) = (
                current_domain.take(),
                current_nameserver.take(),
                is_supplemental,
            ) {
                if let Some(rule) = make_rule(&domain, &ns) {
                    rules.push(rule);
                }
            }
            current_domain = None;
            current_nameserver = None;
            is_supplemental = false;
        } else if line.starts_with("domain") && line.contains(':') {
            // "domain : tailcee7cc.ts.net."
            if let Some(val) = line.split(':').nth(1) {
                let domain = val.trim().trim_end_matches('.').to_lowercase();
                // Skip empty names, mDNS (.local), and reverse-lookup pseudo-domains.
                if !domain.is_empty()
                    && domain != "local"
                    && !domain.ends_with("in-addr.arpa")
                    && !domain.ends_with("ip6.arpa")
                {
                    current_domain = Some(domain);
                }
            }
        } else if line.starts_with("nameserver[0]") && line.contains(':') {
            if let Some(val) = line.split(':').nth(1) {
                let ns = val.trim().to_string();
                // Only use IPv4 nameservers for now
                // (an IPv6 literal would be split apart by the ':' split above anyway)
                if ns.parse::<std::net::Ipv4Addr>().is_ok() {
                    current_nameserver = Some(ns);
                }
            }
        } else if line.starts_with("flags") && line.contains("Supplemental") {
            is_supplemental = true;
        } else if line.starts_with("DNS configuration (for scoped") {
            // Stop at scoped section — those are interface-specific, not conditional
            if let (Some(domain), Some(ns), true) = (
                current_domain.take(),
                current_nameserver.take(),
                is_supplemental,
            ) {
                if let Some(rule) = make_rule(&domain, &ns) {
                    rules.push(rule);
                }
            }
            break;
        }
    }
    // Emit last block
    if let (Some(domain), Some(ns), true) = (current_domain, current_nameserver, is_supplemental) {
        if let Some(rule) = make_rule(&domain, &ns) {
            rules.push(rule);
        }
    }
    // Sort longest suffix first for most-specific matching
    rules.sort_by(|a, b| b.suffix.len().cmp(&a.suffix.len()));
    for rule in &rules {
        info!(
            "auto-discovered forwarding: *.{} -> {}",
            rule.suffix, rule.upstream
        );
    }
    if rules.is_empty() {
        debug!("no conditional forwarding rules discovered from scutil --dns");
    }
    rules
}
/// Builds a rule from a domain suffix and a nameserver address string,
/// targeting port 53. Returns `None` when the address fails to parse.
fn make_rule(domain: &str, nameserver: &str) -> Option<ForwardingRule> {
    let upstream: SocketAddr = format!("{}:53", nameserver).parse().ok()?;
    Some(ForwardingRule {
        suffix: domain.to_string(),
        dot_suffix: format!(".{}", domain),
        upstream,
    })
}
/// Find the upstream for a domain by checking forwarding rules.
/// Returns None if no rule matches (use default upstream).
/// Zero-allocation on the hot path — dot_suffix is pre-computed.
pub fn match_forwarding_rule(domain: &str, rules: &[ForwardingRule]) -> Option<SocketAddr> {
    rules
        .iter()
        .find(|rule| domain == rule.suffix || domain.ends_with(&rule.dot_suffix))
        .map(|rule| rule.upstream)
}