add ad blocking, live dashboard, system DNS auto-discovery

- DNS-level ad blocking: 385K+ domains via Hagezi Pro blocklist, subdomain
  matching, one-click allowlist, pause/toggle, background refresh every 24h
- Live dashboard at :5380 with real-time stats, query log, override
  management (create/edit/delete), blocking controls
- System DNS auto-discovery: parses scutil --dns on macOS to find
  conditional forwarding rules (Tailscale, VPN split-DNS)
- REST API expanded to 18 endpoints (blocking, overrides, diagnostics)
- Startup banner with colored system info
- Performance benchmarks (bench/dns-bench.sh)
- Landing page updated with new positioning and comparison table
- CI, Dockerfile, LICENSE, development plan docs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Razvan Dimescu
2026-03-20 10:54:23 +02:00
parent e31188fb88
commit 4dc5b94c7a
23 changed files with 5494 additions and 226 deletions

26
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
env:
CARGO_TERM_COLOR: always
jobs:
check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- uses: Swatinem/rust-cache@v2
- name: fmt
run: cargo fmt --check
- name: clippy
run: cargo clippy -- -D warnings
- name: test
run: cargo test

1325
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,20 @@
[package]
name = "dns_fun"
name = "numa"
version = "0.1.0"
authors = ["razvandimescu <razvan@dimescu.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
edition = "2021"
description = "Ephemeral DNS overrides for development and testing. Point any hostname to any endpoint. Auto-revert when you're done."
license = "MIT"
repository = "https://github.com/razvandimescu/numa"
keywords = ["dns", "proxy", "override", "development", "networking"]
categories = ["network-programming", "development-tools"]
[dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
toml = "0.8"
axum = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
log = "0.4"
env_logger = "0.11"
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }

15
Dockerfile Normal file
View File

@@ -0,0 +1,15 @@
FROM rust:1.85-alpine AS builder
RUN apk add --no-cache musl-dev
WORKDIR /app
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && echo '' > src/lib.rs
RUN cargo build --release 2>/dev/null || true
RUN rm -rf src
COPY src/ src/
RUN touch src/main.rs src/lib.rs
RUN cargo build --release
FROM scratch
COPY --from=builder /app/target/release/numa /numa
EXPOSE 53/udp 5380/tcp
ENTRYPOINT ["/numa"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Razvan Dimescu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

184
README.md
View File

@@ -1,36 +1,64 @@
# dns_fun
# Numa
A DNS forwarding/caching proxy written from scratch in Rust. Parses and serializes DNS wire protocol (RFC 1035), serves local zone records from TOML config, caches upstream responses with TTL-aware expiration, and logs every query with structured output.
**DNS you own. Everywhere you go.**
No DNS libraries — just `tokio::net::UdpSocket` and manual packet parsing. Each query is handled concurrently via `tokio::spawn`.
Block ads and trackers. Override DNS for development. Cache for speed. A single portable binary built from scratch in Rust — no Raspberry Pi, no cloud, no account.
## Record Types
## Why
A, NS, CNAME, MX, AAAA
- **Ad blocking that travels with you** — 385K+ domains blocked out of the box. Works on any network: coffee shops, hotels, airports.
- **Developer overrides** — point any hostname to any IP with auto-revert. No more editing `/etc/hosts`.
- **Sub-millisecond caching** — cached lookups in 0ms. Faster than any public resolver.
- **Live dashboard** — real-time query stats, blocking controls, override management at `http://localhost:5380`.
- **Single binary, zero config** — just run it.
## Usage
## Quick Start
```bash
# Run with default config (dns_fun.toml)
sudo cargo run
# Run with custom config path
sudo cargo run -- path/to/config.toml
# Test
dig @127.0.0.1 google.com
dig @127.0.0.1 mysite.local
cargo build
sudo cargo run # binds to port 53
```
Requires root/sudo for binding to port 53.
Open the dashboard: **http://localhost:5380**
Test it:
```bash
dig @127.0.0.1 google.com # normal resolution
dig @127.0.0.1 ads.google.com # blocked → 0.0.0.0
```
## Resolution Pipeline
```
Query → Overrides → Blocklist → Local Zones → Cache → Upstream → Respond
```
1. **Overrides** — ephemeral, time-scoped redirects (highest priority)
2. **Blocklist** — 385K+ ad/tracker domains → returns `0.0.0.0` / `::`
3. **Local zones** — records defined in `[[zones]]` config
4. **Cache** — TTL-adjusted cached upstream responses (sub-ms)
5. **Forward** — query upstream resolver, cache the result
6. **SERVFAIL** — returned on upstream failure
## Dashboard
Live at `http://localhost:5380` when Numa is running:
- Total queries, cache hit rate, blocked count, uptime
- Resolution path breakdown (forward / cached / local / override / blocked)
- Scrolling query log with colored path tags
- Active overrides with create/edit/delete
- Blocking controls: toggle on/off, pause 5 minutes, one-click allowlist
- Cached domains list
## Configuration
Edit `dns_fun.toml`:
`numa.toml` (all sections optional, sensible defaults if missing):
```toml
[server]
bind_addr = "0.0.0.0:53"
api_port = 5380
[upstream]
address = "8.8.8.8"
@@ -39,85 +67,87 @@ timeout_ms = 3000
[cache]
max_entries = 10000
min_ttl = 60 # floor: cache at least 60s
max_ttl = 86400 # ceiling: never cache longer than 24h
min_ttl = 60
max_ttl = 86400
[blocking]
enabled = true
lists = [
"https://cdn.jsdelivr.net/gh/hagezi/dns-blocklists@latest/hosts/pro.txt",
]
refresh_hours = 24
allowlist = []
[[zones]]
domain = "mysite.local"
record_type = "A"
value = "127.0.0.1"
ttl = 60
[[zones]]
domain = "other.local"
record_type = "AAAA"
value = "::1"
ttl = 120
```
All sections are optional — sensible defaults are used if the config file is missing.
## HTTP API
## Request Pipeline
REST API on port 5380 (18 endpoints):
```
Query -> Parse -> Local Zones -> Cache -> Upstream Forward -> Respond
```
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/` | GET | Live dashboard |
| `/overrides` | POST | Create override(s) |
| `/overrides` | GET | List active overrides |
| `/overrides` | DELETE | Clear all overrides |
| `/overrides/environment` | POST | Batch load overrides |
| `/overrides/{domain}` | GET | Get specific override |
| `/overrides/{domain}` | DELETE | Remove specific override |
| `/blocking/stats` | GET | Blocklist stats (domains loaded, sources, enabled) |
| `/blocking/toggle` | PUT | Enable/disable blocking |
| `/blocking/pause` | POST | Pause blocking for N minutes |
| `/blocking/allowlist` | GET | List allowlisted domains |
| `/blocking/allowlist` | POST | Add domain to allowlist |
| `/blocking/allowlist/{domain}` | DELETE | Remove from allowlist |
| `/diagnose/{domain}` | GET | Trace resolution path |
| `/query-log` | GET | Recent queries (filterable) |
| `/stats` | GET | Server statistics |
| `/cache` | GET | List cached entries |
| `/cache` | DELETE | Flush cache |
| `/cache/{domain}` | DELETE | Flush specific domain |
| `/health` | GET | Health check |
1. **Local zones** — match against records defined in `[[zones]]`, respond immediately
2. **Cache** — return TTL-adjusted cached response if available
3. **Forward** — send query to upstream resolver, cache the response
4. **SERVFAIL** — returned to client on upstream failure
## How It Compares
## Caching
| | Pi-hole | NextDNS | Cloudflare | Numa |
|---|---|---|---|---|
| Ad blocking | Yes | Yes | Limited | 385K+ domains |
| Portable | No (Raspberry Pi) | Cloud only | Cloud only | Single binary |
| Developer overrides | No | No | No | REST API + auto-expiry |
| Data stays local | Yes | Cloud | Cloud | 100% local |
| Zero config | Complex setup | Yes | Yes | Works out of the box |
| Self-sovereign DNS | No | No | No | pkarr/DHT roadmap |
- TTL derived from minimum TTL across answer records
- Clamped to configured `min_ttl`/`max_ttl` bounds
- TTLs in cached responses decrease over time (adjusted on serve)
- Lazy eviction on capacity overflow + periodic sweep every 1000 queries
## Use Cases
## Logging
**Block ads everywhere** — Run Numa on your laptop. Your ad blocker works on any network.
Controlled via `RUST_LOG` environment variable:
**Mock external services**`Point api.stripe.com to localhost:8080 for 30 minutes`
**Provision dev environments** — Create overrides for `db.dev`, `api.dev`, `cache.dev`
**Debug DNS**`/diagnose/example.com` traces the full resolution path
## Docker
```bash
RUST_LOG=info sudo cargo run # default — one line per query
RUST_LOG=debug sudo cargo run # includes response details
RUST_LOG=warn sudo cargo run # errors only
```
Log output:
```
2026-03-10T14:23:01.123Z INFO 192.168.1.5:41234 | A google.com | FORWARD | NOERROR | 12ms
2026-03-10T14:23:01.456Z INFO 192.168.1.5:41235 | A mysite.local | LOCAL | NOERROR | 0ms
2026-03-10T14:23:02.789Z INFO 192.168.1.5:41236 | A google.com | CACHED | NOERROR | 0ms
```
Stats summary (total, forwarded, cached, local, blocked, errors) logged every 1000 queries.
## Project Structure
```
src/
main.rs # async startup, tokio event loop, ServerCtx, per-query task spawn
lib.rs # module declarations, Error/Result type aliases
buffer.rs # BytePacketBuffer — 512-byte DNS wire format read/write
header.rs # DnsHeader, ResultCode
question.rs # DnsQuestion, QueryType
record.rs # DnsRecord (A, NS, CNAME, MX, AAAA, UNKNOWN)
packet.rs # DnsPacket — full DNS message parse/serialize
config.rs # TOML config loading, zone map builder
cache.rs # TTL-aware DNS response cache with lazy eviction
forward.rs # async upstream forwarding
stats.rs # query counters and periodic summary
docker build -t numa .
docker run -p 53:53/udp -p 5380:5380 numa
```
## Dependencies
```toml
tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] }
toml = "0.8"
serde = { version = "1", features = ["derive"] }
log = "0.4"
env_logger = "0.11"
```
tokio, axum, serde, serde_json, toml, log, env_logger, reqwest
```
Zero external DNS libraries. Wire protocol (RFC 1035) parsed from scratch.
## License
MIT

151
bench/dns-bench.sh Executable file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""DNS performance benchmark — compares Numa against public resolvers."""
import subprocess
import sys
import re
import statistics
import json
NUMA_PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 15353
ROUNDS = int(sys.argv[2]) if len(sys.argv) > 2 else 20
DOMAINS = [
"google.com", "github.com", "amazon.com", "cloudflare.com",
"reddit.com", "stackoverflow.com", "rust-lang.org", "wikipedia.org",
"netflix.com", "twitter.com",
]
RESOLVERS = [
("Numa(cold)", "127.0.0.1", NUMA_PORT),
("Numa(cached)", "127.0.0.1", NUMA_PORT),
("System", "", 53),
]
# Detect system resolver
try:
out = subprocess.run(["scutil", "--dns"], capture_output=True, text=True)
m = re.search(r"nameserver\[0\]\s*:\s*([\d.]+)", out.stdout)
if m:
RESOLVERS[2] = ("System", m.group(1), 53)
except Exception:
pass
# Add public resolvers — skip if unreachable
for name, ip in [("Google", "8.8.8.8"), ("Cloudflare", "1.1.1.1"), ("Quad9", "9.9.9.9")]:
try:
out = subprocess.run(
["dig", f"@{ip}", "example.com", "+short", "+time=2", "+tries=1"],
capture_output=True, text=True, timeout=4
)
if out.stdout.strip():
RESOLVERS.append((name, ip, 53))
except Exception:
pass
A = "\033[38;2;192;98;58m"
T = "\033[38;2;107;124;78m"
D = "\033[38;2;163;152;136m"
B = "\033[1m"
R = "\033[0m"
def query_ms(server, port, domain):
try:
out = subprocess.run(
["dig", f"@{server}", "-p", str(port), domain,
"+noall", "+stats", "+tries=1", "+time=3"],
capture_output=True, text=True, timeout=5
)
m = re.search(r"Query time:\s+(\d+)\s+msec", out.stdout)
return int(m.group(1)) if m else None
except Exception:
return None
def flush_cache(domain=None):
try:
url = f"http://localhost:5380/cache/{domain}" if domain else "http://localhost:5380/cache"
subprocess.run(["curl", "-s", "-X", "DELETE", url],
capture_output=True, timeout=3)
except Exception:
pass
print()
print(f"{A} ╔══════════════════════════════════════════════════════════╗{R}")
print(f"{A} ║{R} {B}{A}NUMA{R} DNS Performance Benchmark {A}║{R}")
print(f"{A} ╚══════════════════════════════════════════════════════════╝{R}")
print()
print(f"{D} Domains: {len(DOMAINS)} | Rounds: {ROUNDS} | Total: {len(DOMAINS) * ROUNDS} queries per resolver{R}")
print()
results = {}
for name, server, port in RESOLVERS:
print(f" {T}Testing{R} {B}{name}{R}...", end="", flush=True)
if name == "Numa(cold)":
flush_cache()
latencies = []
for r in range(ROUNDS):
for domain in DOMAINS:
if name == "Numa(cold)":
flush_cache(domain)
ms = query_ms(server, port, domain)
if ms is not None:
latencies.append(ms)
if latencies:
latencies.sort()
n = len(latencies)
results[name] = {
"avg": round(statistics.mean(latencies), 1),
"p50": latencies[n // 2],
"p99": latencies[int(n * 0.99)],
"min": min(latencies),
"max": max(latencies),
"count": n,
}
print(f" {D}done ({len(latencies)} queries){R}")
print()
print(f"{A} ┌──────────────┬────────┬────────┬────────┬────────┬────────┐{R}")
print(f"{A} │{R} {B}Resolver{R} {A}│{R} {B}Avg{R} {A}│{R} {B}P50{R} {A}│{R} {B}P99{R} {A}│{R} {B}Min{R} {A}│{R} {B}Max{R} {A}│{R}")
print(f"{A} ├──────────────┼────────┼────────┼────────┼────────┼────────┤{R}")
for name, _, _ in RESOLVERS:
if name not in results:
continue
r = results[name]
if "cached" in name.lower():
c = T
elif "cold" in name.lower():
c = A
else:
c = D
print(f"{c} │ {name:<12s} │ {r['avg']:5.1f}ms │ {r['p50']:4d}ms │ {r['p99']:4d}ms │ {r['min']:4d}ms │ {r['max']:4d}ms │{R}")
print(f"{A} └──────────────┴────────┴────────┴────────┴────────┴────────┘{R}")
# Summary comparison
cached = results.get("Numa(cached)", {})
cold = results.get("Numa(cold)", {})
print()
if cached and cached["avg"] > 0:
for name in [n for n, _, _ in RESOLVERS if n not in ("Numa(cold)", "Numa(cached)")]:
other = results.get(name, {})
if other and other["avg"] > 0:
x = other["avg"] / max(cached["avg"], 0.1)
print(f" {T}Numa cached is ~{x:.0f}x faster than {name} (avg){R}")
if cold and cold["avg"] > 0:
x = cold["avg"] / max(cached["avg"], 0.1)
print(f" {T}Numa cached is ~{x:.0f}x faster than Numa cold (avg){R}")
# Save raw results as JSON
out_path = "bench/results.json"
with open(out_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\n {D}Raw results saved to {out_path}{R}")
print()

View File

@@ -1,5 +1,6 @@
[server]
bind_addr = "0.0.0.0:53"
api_port = 5380
[upstream]
address = "8.8.8.8"

842
site/dashboard.html Normal file
View File

@@ -0,0 +1,842 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Numa — Dashboard</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Instrument+Serif:ital@0;1&family=DM+Sans:opsz,wght@9..40,400;9..40,500;9..40,600&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
<style>
*, *::before, *::after { margin: 0; padding: 0; box-sizing: border-box; }
:root {
--bg-deep: #f5f0e8;
--bg-surface: #ece5da;
--bg-elevated: #e3dbce;
--bg-card: #faf7f2;
--amber: #c0623a;
--amber-dim: #9e4e2d;
--teal: #6b7c4e;
--teal-dim: #566540;
--violet: #64748b;
--violet-dim: #4a5568;
--emerald: #527a52;
--rose: #b5443a;
--cyan: #4a7c8a;
--text-primary: #2c2418;
--text-secondary: #6b5e4f;
--text-dim: #a39888;
--border: rgba(0, 0, 0, 0.08);
--border-amber: rgba(192, 98, 58, 0.22);
--font-display: 'Instrument Serif', Georgia, serif;
--font-body: 'DM Sans', system-ui, sans-serif;
--font-mono: 'JetBrains Mono', 'SF Mono', monospace;
}
html { font-size: 15px; }
body {
font-family: var(--font-body);
background: var(--bg-deep);
color: var(--text-primary);
min-height: 100vh;
-webkit-font-smoothing: antialiased;
}
/* Header */
.header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 1.2rem 2rem;
border-bottom: 1px solid var(--border);
background: var(--bg-card);
}
.header-left {
display: flex;
align-items: baseline;
gap: 1rem;
}
.logo {
font-family: var(--font-display);
font-size: 1.8rem;
color: var(--amber);
letter-spacing: 0.04em;
}
.tagline {
font-size: 0.85rem;
color: var(--text-dim);
font-style: italic;
font-family: var(--font-display);
}
.status-badge {
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.8rem;
color: var(--text-dim);
font-family: var(--font-mono);
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: var(--emerald);
animation: pulse 2s ease-in-out infinite;
}
.status-dot.error { background: var(--rose); animation: none; }
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.4; }
}
/* Layout */
.dashboard {
max-width: 1400px;
margin: 0 auto;
padding: 1.5rem 2rem;
display: flex;
flex-direction: column;
gap: 1.2rem;
}
/* Stat cards row */
.stats-row {
display: grid;
grid-template-columns: repeat(5, 1fr);
gap: 1rem;
}
.stat-card {
background: var(--bg-card);
border: 1px solid var(--border);
border-radius: 10px;
padding: 1.2rem 1.4rem;
position: relative;
overflow: hidden;
}
.stat-card::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
height: 3px;
}
.stat-card.queries::before { background: var(--amber); }
.stat-card.cache::before { background: var(--teal); }
.stat-card.blocked::before { background: var(--rose); }
.stat-card.overrides::before { background: var(--violet); }
.stat-card.uptime::before { background: var(--cyan); }
.stat-label {
font-size: 0.7rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.1em;
color: var(--text-dim);
margin-bottom: 0.5rem;
}
.stat-value {
font-family: var(--font-mono);
font-size: 2rem;
font-weight: 500;
line-height: 1;
}
.stat-card.queries .stat-value { color: var(--amber); }
.stat-card.cache .stat-value { color: var(--teal); }
.stat-card.blocked .stat-value { color: var(--rose); }
.stat-card.overrides .stat-value { color: var(--violet); }
.stat-card.uptime .stat-value { color: var(--cyan); }
.stat-sub {
font-family: var(--font-mono);
font-size: 0.8rem;
color: var(--text-dim);
margin-top: 0.3rem;
}
/* Two-column main area */
.main-grid {
display: grid;
grid-template-columns: 1fr 340px;
gap: 1.2rem;
}
/* Panels */
.panel {
background: var(--bg-card);
border: 1px solid var(--border);
border-radius: 10px;
overflow: hidden;
}
.panel-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 0.9rem 1.2rem;
border-bottom: 1px solid var(--border);
background: var(--bg-surface);
}
.panel-title {
font-size: 0.7rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.1em;
color: var(--text-secondary);
}
.panel-body {
padding: 1rem 1.2rem;
}
/* Resolution paths bar chart */
.path-bar-row {
display: flex;
align-items: center;
gap: 0.8rem;
margin-bottom: 0.6rem;
}
.path-bar-row:last-child { margin-bottom: 0; }
.path-label {
font-family: var(--font-mono);
font-size: 0.75rem;
width: 70px;
text-align: right;
color: var(--text-secondary);
flex-shrink: 0;
}
.path-bar-track {
flex: 1;
height: 22px;
background: var(--bg-surface);
border-radius: 4px;
overflow: hidden;
}
.path-bar-fill {
height: 100%;
border-radius: 4px;
transition: width 0.6s ease;
min-width: 2px;
}
.path-bar-fill.forward { background: var(--amber); }
.path-bar-fill.cached { background: var(--teal); }
.path-bar-fill.local { background: var(--violet); }
.path-bar-fill.override { background: var(--emerald); }
.path-bar-fill.error { background: var(--rose); }
.path-bar-fill.blocked { background: var(--text-dim); }
.path-pct {
font-family: var(--font-mono);
font-size: 0.75rem;
width: 42px;
color: var(--text-dim);
flex-shrink: 0;
}
/* Query log table */
.query-log {
max-height: 380px;
overflow-y: auto;
scrollbar-width: thin;
scrollbar-color: var(--bg-elevated) transparent;
}
.query-log table {
width: 100%;
border-collapse: collapse;
font-family: var(--font-mono);
font-size: 0.75rem;
}
.query-log th {
text-align: left;
padding: 0.5rem 0.6rem;
font-size: 0.65rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--text-dim);
border-bottom: 1px solid var(--border);
position: sticky;
top: 0;
background: var(--bg-card);
z-index: 1;
}
.query-log td {
padding: 0.4rem 0.6rem;
border-bottom: 1px solid var(--border);
white-space: nowrap;
color: var(--text-secondary);
}
.query-log tr:hover td {
background: var(--bg-surface);
}
.query-log .domain-cell {
max-width: 220px;
overflow: hidden;
text-overflow: ellipsis;
color: var(--text-primary);
}
.path-tag {
display: inline-block;
padding: 0.1rem 0.4rem;
border-radius: 3px;
font-size: 0.65rem;
font-weight: 500;
}
.path-tag.FORWARD { background: rgba(192, 98, 58, 0.12); color: var(--amber-dim); }
.path-tag.CACHED { background: rgba(107, 124, 78, 0.12); color: var(--teal-dim); }
.path-tag.LOCAL { background: rgba(100, 116, 139, 0.12); color: var(--violet-dim); }
.path-tag.OVERRIDE { background: rgba(82, 122, 82, 0.12); color: var(--emerald); }
.path-tag.SERVFAIL { background: rgba(181, 68, 58, 0.12); color: var(--rose); }
.path-tag.BLOCKED { background: rgba(163, 152, 136, 0.15); color: var(--text-dim); }
/* Sidebar panels */
.sidebar {
display: flex;
flex-direction: column;
gap: 1.2rem;
}
/* Overrides list */
.override-item {
display: flex;
flex-direction: column;
gap: 0.2rem;
padding: 0.6rem 0;
border-bottom: 1px solid var(--border);
}
.override-item:last-child { border-bottom: none; }
.override-domain {
font-family: var(--font-mono);
font-size: 0.8rem;
font-weight: 500;
color: var(--emerald);
}
.override-target {
font-family: var(--font-mono);
font-size: 0.72rem;
color: var(--text-dim);
}
.override-ttl {
font-family: var(--font-mono);
font-size: 0.68rem;
color: var(--amber);
}
.empty-state {
font-size: 0.8rem;
color: var(--text-dim);
font-style: italic;
padding: 0.8rem 0;
}
/* Cache panel */
.cache-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.35rem 0;
border-bottom: 1px solid var(--border);
font-family: var(--font-mono);
font-size: 0.72rem;
}
.cache-item:last-child { border-bottom: none; }
.cache-domain {
color: var(--text-primary);
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
max-width: 200px;
}
.cache-ttl {
color: var(--text-dim);
flex-shrink: 0;
}
/* Override form */
.override-form {
display: flex;
flex-direction: column;
gap: 0.5rem;
padding-bottom: 0.8rem;
margin-bottom: 0.6rem;
border-bottom: 1px solid var(--border);
}
.override-form input {
font-family: var(--font-mono);
font-size: 0.75rem;
padding: 0.45rem 0.6rem;
border: 1px solid var(--border);
border-radius: 5px;
background: var(--bg-surface);
color: var(--text-primary);
outline: none;
transition: border-color 0.2s;
}
.override-form input:focus {
border-color: var(--amber);
}
.override-form input::placeholder {
color: var(--text-dim);
}
.override-form-row {
display: flex;
gap: 0.4rem;
}
.override-form-row input {
flex: 1;
min-width: 0;
}
.btn {
font-family: var(--font-body);
font-size: 0.72rem;
font-weight: 600;
padding: 0.4rem 0.8rem;
border: none;
border-radius: 5px;
cursor: pointer;
transition: opacity 0.2s;
}
.btn:hover { opacity: 0.85; }
.btn:active { opacity: 0.7; }
.btn-add {
background: var(--emerald);
color: white;
}
.btn-delete {
background: none;
border: none;
cursor: pointer;
color: var(--text-dim);
font-size: 0.75rem;
padding: 0.15rem 0.3rem;
border-radius: 3px;
transition: color 0.2s, background 0.2s;
}
.btn-delete:hover {
color: var(--rose);
background: rgba(181, 68, 58, 0.08);
}
.override-item-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.override-error {
font-size: 0.7rem;
color: var(--rose);
display: none;
}
/* Responsive */
@media (max-width: 1100px) {
.main-grid { grid-template-columns: 1fr; }
}
@media (max-width: 700px) {
.stats-row { grid-template-columns: repeat(2, 1fr); }
.dashboard { padding: 1rem; }
.header { padding: 1rem; }
}
</style>
</head>
<body>
<div class="header">
<div class="header-left">
<div class="logo">Numa</div>
<div class="tagline">DNS that governs itself</div>
</div>
<div style="display:flex;align-items:center;gap:1.2rem;">
<button class="btn" id="pauseBtn" onclick="pauseBlocking()" style="background:var(--amber);color:white;font-family:var(--font-mono);font-size:0.7rem;display:none;">Pause 5m</button>
<button class="btn" id="toggleBtn" onclick="toggleBlocking()" style="background:var(--rose);color:white;font-family:var(--font-mono);font-size:0.7rem;display:none;"></button>
<div class="status-badge">
<span class="status-dot" id="statusDot"></span>
<span id="statusText">connecting...</span>
</div>
</div>
</div>
<div class="dashboard">
<!-- Stat cards -->
<div class="stats-row">
<div class="stat-card queries">
<div class="stat-label">Total Queries</div>
<div class="stat-value" id="totalQueries"></div>
<div class="stat-sub" id="qps"></div>
</div>
<div class="stat-card cache">
<div class="stat-label">Cache Hit Rate</div>
<div class="stat-value" id="cacheRate"></div>
<div class="stat-sub" id="cacheEntries"></div>
</div>
<div class="stat-card blocked">
<div class="stat-label">Blocked</div>
<div class="stat-value" id="blockedCount"></div>
<div class="stat-sub" id="blockedSub">&nbsp;</div>
</div>
<div class="stat-card overrides">
<div class="stat-label">Active Overrides</div>
<div class="stat-value" id="overrideCount"></div>
<div class="stat-sub">&nbsp;</div>
</div>
<div class="stat-card uptime">
<div class="stat-label">Uptime</div>
<div class="stat-value" id="uptime"></div>
<div class="stat-sub" id="uptimeSub">&nbsp;</div>
</div>
</div>
<!-- Resolution paths -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">Resolution Paths</span>
</div>
<div class="panel-body" id="pathBars">
<!-- Populated by JS -->
</div>
</div>
<!-- Main grid: query log + sidebar -->
<div class="main-grid">
<!-- Query log -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">Recent Queries</span>
<span class="panel-title" id="queryCount" style="color: var(--text-dim)"></span>
</div>
<div class="query-log" id="queryLog">
<table>
<thead>
<tr>
<th>Time</th>
<th>Type</th>
<th>Domain</th>
<th>Path</th>
<th>Result</th>
<th>Latency</th>
</tr>
</thead>
<tbody id="queryLogBody">
</tbody>
</table>
</div>
</div>
<!-- Sidebar -->
<div class="sidebar">
<!-- Active overrides -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">Active Overrides</span>
</div>
<div class="panel-body">
<form class="override-form" id="overrideForm" onsubmit="return addOverride(event)">
<input type="text" id="ovDomain" placeholder="domain (e.g. api.dev)" required>
<input type="text" id="ovTarget" placeholder="target IP (e.g. 127.0.0.1)" required>
<div class="override-form-row">
<input type="number" id="ovTTL" placeholder="TTL" value="60" min="1">
<input type="number" id="ovDuration" placeholder="Duration (s)" value="300" min="1">
</div>
<button type="submit" class="btn btn-add">Add Override</button>
<div class="override-error" id="overrideError"></div>
</form>
<div id="overridesList">
<div class="empty-state">No active overrides</div>
</div>
</div>
</div>
<!-- Cache entries -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">Cached Domains</span>
<span class="panel-title" id="cacheCount" style="color: var(--text-dim)"></span>
</div>
<div class="panel-body" id="cacheList" style="max-height: 240px; overflow-y: auto; scrollbar-width: thin;">
<div class="empty-state">Cache empty</div>
</div>
</div>
</div>
</div>
</div>
<script>
const API = '';
let prevTotal = null;
let prevTime = null;
async function fetchJSON(path) {
const res = await fetch(API + path);
if (!res.ok) throw new Error(res.status);
return res.json();
}
function formatUptime(secs) {
if (secs < 60) return `${secs}s`;
if (secs < 3600) return `${Math.floor(secs / 60)}m`;
const h = Math.floor(secs / 3600);
const m = Math.floor((secs % 3600) / 60);
return `${h}h ${m}m`;
}
function formatUptimeSub(secs) {
const d = Math.floor(secs / 86400);
const h = Math.floor((secs % 86400) / 3600);
const m = Math.floor((secs % 3600) / 60);
const s = secs % 60;
if (d > 0) return `${d}d ${h}h ${m}m ${s}s`;
if (h > 0) return `${h}h ${m}m ${s}s`;
if (m > 0) return `${m}m ${s}s`;
return `${s}s`;
}
function formatNumber(n) {
if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M';
if (n >= 1000) return (n / 1000).toFixed(1) + 'K';
return n.toString();
}
function formatTime(epoch) {
const d = new Date(epoch * 1000);
return d.toLocaleTimeString([], { hour12: false });
}
function formatRemaining(secs) {
if (secs == null) return 'permanent';
if (secs < 60) return `${secs}s left`;
if (secs < 3600) return `${Math.floor(secs / 60)}m ${secs % 60}s left`;
return `${Math.floor(secs / 3600)}h ${Math.floor((secs % 3600) / 60)}m left`;
}
const PATH_DEFS = [
{ key: 'forwarded', label: 'Forward', cls: 'forward' },
{ key: 'cached', label: 'Cached', cls: 'cached' },
{ key: 'local', label: 'Local', cls: 'local' },
{ key: 'overridden', label: 'Override', cls: 'override' },
{ key: 'blocked', label: 'Blocked', cls: 'blocked' },
{ key: 'errors', label: 'Errors', cls: 'error' },
];
function renderPaths(queries) {
const total = queries.total || 1;
const container = document.getElementById('pathBars');
container.innerHTML = PATH_DEFS.map(p => {
const count = queries[p.key] || 0;
const pct = ((count / total) * 100).toFixed(1);
return `
<div class="path-bar-row">
<span class="path-label">${p.label}</span>
<div class="path-bar-track">
<div class="path-bar-fill ${p.cls}" style="width: ${pct}%"></div>
</div>
<span class="path-pct">${pct}%</span>
</div>`;
}).join('');
}
function renderQueryLog(entries) {
const tbody = document.getElementById('queryLogBody');
document.getElementById('queryCount').textContent = `last ${entries.length}`;
tbody.innerHTML = entries.map(e => {
const allowBtn = e.path === 'BLOCKED'
? ` <button class="btn-delete" onclick="allowDomain('${e.domain}')" title="Allow this domain" style="color:var(--emerald);font-size:0.65rem;">allow</button>`
: '';
return `
<tr>
<td>${formatTime(e.timestamp_epoch)}</td>
<td>${e.query_type}</td>
<td class="domain-cell" title="${e.domain}">${e.domain}${allowBtn}</td>
<td><span class="path-tag ${e.path}">${e.path}</span></td>
<td>${e.rescode}</td>
<td>${e.latency_ms.toFixed(1)}ms</td>
</tr>`;
}).join('');
}
function renderOverrides(entries) {
const el = document.getElementById('overridesList');
if (!entries.length) {
el.innerHTML = '<div class="empty-state">No active overrides</div>';
return;
}
el.innerHTML = entries.map(e => `
<div class="override-item">
<div class="override-item-header">
<span class="override-domain" style="cursor:pointer" onclick="editOverride('${e.domain}','${e.target}',${e.ttl || 60},${e.remaining_secs || 300})" title="Click to edit">${e.domain}</span>
<button class="btn-delete" onclick="deleteOverride('${e.domain}')" title="Remove override">&times;</button>
</div>
<div class="override-target">${e.record_type} &rarr; ${e.target}</div>
<div class="override-ttl">${e.remaining_secs != null ? formatRemaining(e.remaining_secs) : 'permanent'}</div>
</div>
`).join('');
}
async function addOverride(event) {
event.preventDefault();
const errEl = document.getElementById('overrideError');
errEl.style.display = 'none';
try {
const body = {
domain: document.getElementById('ovDomain').value.trim(),
target: document.getElementById('ovTarget').value.trim(),
ttl: parseInt(document.getElementById('ovTTL').value) || 60,
duration_secs: parseInt(document.getElementById('ovDuration').value) || 300,
};
const res = await fetch(API + '/overrides', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body),
});
if (!res.ok) {
const text = await res.text();
throw new Error(text);
}
document.getElementById('ovDomain').value = '';
document.getElementById('ovTarget').value = '';
refresh();
} catch (err) {
errEl.textContent = err.message;
errEl.style.display = 'block';
}
return false;
}
async function deleteOverride(domain) {
try {
await fetch(API + '/overrides/' + encodeURIComponent(domain), { method: 'DELETE' });
refresh();
} catch (err) { /* next refresh will update */ }
}
function editOverride(domain, target, ttl, duration) {
document.getElementById('ovDomain').value = domain;
document.getElementById('ovTarget').value = target;
document.getElementById('ovTTL').value = ttl;
document.getElementById('ovDuration').value = duration;
document.getElementById('ovDomain').focus();
}
function renderCache(entries) {
const el = document.getElementById('cacheList');
document.getElementById('cacheCount').textContent = entries.length ? `${entries.length} entries` : '';
if (!entries.length) {
el.innerHTML = '<div class="empty-state">Cache empty</div>';
return;
}
// Show first 50, sorted by TTL remaining desc
const sorted = entries.sort((a, b) => b.ttl_remaining - a.ttl_remaining).slice(0, 50);
el.innerHTML = sorted.map(e => `
<div class="cache-item">
<span class="cache-domain" title="${e.domain}">${e.domain}</span>
<span class="cache-ttl">${e.query_type} ${e.ttl_remaining}s</span>
</div>
`).join('');
}
async function refresh() {
try {
const [stats, logs, overrides, cache] = await Promise.all([
fetchJSON('/stats'),
fetchJSON('/query-log?limit=100'),
fetchJSON('/overrides'),
fetchJSON('/cache'),
]);
// Connection status
document.getElementById('statusDot').className = 'status-dot';
document.getElementById('statusText').textContent = 'connected';
// Stats cards
const q = stats.queries;
document.getElementById('totalQueries').textContent = formatNumber(q.total);
document.getElementById('uptime').textContent = formatUptime(stats.uptime_secs);
document.getElementById('uptimeSub').textContent = formatUptimeSub(stats.uptime_secs);
document.getElementById('overrideCount').textContent = stats.overrides.active;
document.getElementById('blockedCount').textContent = formatNumber(q.blocked);
const bl = stats.blocking;
document.getElementById('blockedSub').textContent =
bl.domains_loaded > 0 ? `${formatNumber(bl.domains_loaded)} in blocklist` : 'loading...';
// Blocking controls
const toggleBtn = document.getElementById('toggleBtn');
const pauseBtn = document.getElementById('pauseBtn');
toggleBtn.style.display = 'inline-block';
pauseBtn.style.display = bl.enabled && !bl.paused ? 'inline-block' : 'none';
if (bl.paused) {
toggleBtn.textContent = 'Paused';
toggleBtn.style.background = 'var(--amber)';
} else if (bl.enabled) {
toggleBtn.textContent = 'Blocking On';
toggleBtn.style.background = 'var(--emerald)';
} else {
toggleBtn.textContent = 'Blocking Off';
toggleBtn.style.background = 'var(--rose)';
}
document.getElementById('cacheEntries').textContent =
`${stats.cache.entries} / ${formatNumber(stats.cache.max_entries)} entries`;
// QPS calculation
const now = Date.now();
if (prevTotal !== null && prevTime !== null) {
const dt = (now - prevTime) / 1000;
const dq = q.total - prevTotal;
const qps = dt > 0 ? (dq / dt).toFixed(1) : '0.0';
document.getElementById('qps').textContent = `~${qps}/s`;
}
prevTotal = q.total;
prevTime = now;
// Cache hit rate
const answered = q.cached + q.forwarded + q.local + q.overridden;
const hitRate = answered > 0 ? ((q.cached / answered) * 100).toFixed(1) : '0.0';
document.getElementById('cacheRate').textContent = hitRate + '%';
// Panels
renderPaths(q);
renderQueryLog(logs);
renderOverrides(overrides);
renderCache(cache);
} catch (err) {
document.getElementById('statusDot').className = 'status-dot error';
document.getElementById('statusText').textContent = 'disconnected';
}
}
async function toggleBlocking() {
try {
const stats = await fetchJSON('/blocking/stats');
const newState = !stats.enabled;
await fetch(API + '/blocking/toggle', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ enabled: newState }),
});
refresh();
} catch (err) {}
}
async function pauseBlocking() {
try {
await fetch(API + '/blocking/pause', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ minutes: 5 }),
});
refresh();
} catch (err) {}
}
async function allowDomain(domain) {
try {
await fetch(API + '/blocking/allowlist', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ domain: domain }),
});
refresh();
} catch (err) {}
}
// Initial load + polling
refresh();
setInterval(refresh, 2000);
</script>
</body>
</html>

1395
site/index.html Normal file

File diff suppressed because it is too large Load Diff

565
src/api.rs Normal file
View File

@@ -0,0 +1,565 @@
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{delete, get, post, put};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use crate::ctx::ServerCtx;
use crate::forward::forward_query;
use crate::query_log::QueryLogFilter;
use crate::question::QueryType;
use crate::stats::QueryPath;
const DASHBOARD_HTML: &str = include_str!("../site/dashboard.html");
pub fn router(ctx: Arc<ServerCtx>) -> Router {
Router::new()
.route("/", get(dashboard))
.route("/overrides", post(create_overrides))
.route("/overrides", get(list_overrides))
.route("/overrides", delete(clear_overrides))
.route("/overrides/environment", post(load_environment))
.route("/overrides/{domain}", get(get_override))
.route("/overrides/{domain}", delete(remove_override))
.route("/diagnose/{domain}", get(diagnose))
.route("/query-log", get(query_log))
.route("/stats", get(stats))
.route("/cache", get(list_cache))
.route("/cache", delete(flush_cache))
.route("/cache/{domain}", delete(flush_cache_domain))
.route("/health", get(health))
.route("/blocking/stats", get(blocking_stats))
.route("/blocking/toggle", put(blocking_toggle))
.route("/blocking/pause", post(blocking_pause))
.route("/blocking/allowlist", get(blocking_allowlist))
.route("/blocking/allowlist", post(blocking_allowlist_add))
.route(
"/blocking/allowlist/{domain}",
delete(blocking_allowlist_remove),
)
.with_state(ctx)
}
async fn dashboard() -> impl IntoResponse {
(
[(header::CONTENT_TYPE, "text/html; charset=utf-8")],
DASHBOARD_HTML,
)
}
// --- Request/Response DTOs ---
#[derive(Deserialize)]
struct CreateOverrideRequest {
domain: String,
target: String,
#[serde(default = "default_ttl")]
ttl: u32,
duration_secs: Option<u64>,
}
fn default_ttl() -> u32 {
60
}
#[derive(Serialize)]
struct OverrideResponse {
domain: String,
target: String,
record_type: String,
ttl: u32,
remaining_secs: Option<u64>,
}
impl From<&crate::override_store::OverrideEntry> for OverrideResponse {
fn from(e: &crate::override_store::OverrideEntry) -> Self {
OverrideResponse {
domain: e.domain.clone(),
target: e.target.clone(),
record_type: e.query_type.as_str().to_string(),
ttl: e.ttl,
remaining_secs: e.remaining_secs(),
}
}
}
#[derive(Deserialize)]
struct EnvironmentRequest {
#[serde(default)]
duration_secs: Option<u64>,
overrides: Vec<CreateOverrideRequest>,
}
#[derive(Serialize)]
struct EnvironmentResponse {
created: usize,
}
#[derive(Deserialize)]
struct QueryLogParams {
domain: Option<String>,
r#type: Option<String>,
path: Option<String>,
limit: Option<usize>,
}
#[derive(Serialize)]
struct QueryLogResponse {
timestamp_epoch: f64,
src: String,
domain: String,
query_type: String,
path: String,
rescode: String,
latency_ms: f64,
}
#[derive(Serialize)]
struct StatsResponse {
uptime_secs: u64,
queries: QueriesStats,
cache: CacheStats,
overrides: OverrideStats,
blocking: BlockingStatsResponse,
}
#[derive(Serialize)]
struct QueriesStats {
total: u64,
forwarded: u64,
cached: u64,
local: u64,
overridden: u64,
blocked: u64,
errors: u64,
}
#[derive(Serialize)]
struct CacheStats {
entries: usize,
max_entries: usize,
}
#[derive(Serialize)]
struct OverrideStats {
active: usize,
}
#[derive(Serialize)]
struct BlockingStatsResponse {
enabled: bool,
paused: bool,
domains_loaded: usize,
allowlist_size: usize,
}
#[derive(Serialize)]
struct DiagnoseResponse {
domain: String,
query_type: String,
steps: Vec<DiagnoseStep>,
}
#[derive(Serialize)]
struct DiagnoseStep {
source: String,
matched: bool,
detail: Option<String>,
}
#[derive(Serialize)]
struct CacheEntryResponse {
domain: String,
query_type: String,
ttl_remaining: u32,
}
// --- Handlers ---
async fn create_overrides(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<serde_json::Value>,
) -> Result<(StatusCode, Json<Vec<OverrideResponse>>), (StatusCode, String)> {
let requests: Vec<CreateOverrideRequest> = if req.is_array() {
serde_json::from_value(req).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
} else {
let single: CreateOverrideRequest =
serde_json::from_value(req).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
vec![single]
};
// Parse and validate all requests before acquiring the lock
let parsed: Vec<_> = requests
.into_iter()
.map(|req| {
let domain_lower = req.domain.to_lowercase();
Ok((domain_lower, req.target, req.ttl, req.duration_secs))
})
.collect::<Result<Vec<_>, (StatusCode, String)>>()?;
let mut store = ctx.overrides.lock().unwrap();
let mut responses = Vec::with_capacity(parsed.len());
for (domain, target, ttl, duration_secs) in parsed {
let qtype = store
.insert(&domain, &target, ttl, duration_secs)
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
responses.push(OverrideResponse {
domain,
target,
record_type: qtype.as_str().to_string(),
ttl,
remaining_secs: duration_secs,
});
}
Ok((StatusCode::CREATED, Json(responses)))
}
async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> {
let store = ctx.overrides.lock().unwrap();
let entries: Vec<OverrideResponse> = store
.list()
.into_iter()
.map(OverrideResponse::from)
.collect();
Json(entries)
}
async fn get_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> Result<Json<OverrideResponse>, StatusCode> {
let store = ctx.overrides.lock().unwrap();
let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(OverrideResponse::from(entry)))
}
async fn remove_override(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
let mut store = ctx.overrides.lock().unwrap();
if store.remove(&domain) {
StatusCode::NO_CONTENT
} else {
StatusCode::NOT_FOUND
}
}
async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.overrides.lock().unwrap().clear();
StatusCode::NO_CONTENT
}
async fn load_environment(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<EnvironmentRequest>,
) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> {
let mut store = ctx.overrides.lock().unwrap();
for entry in &req.overrides {
let duration = entry.duration_secs.or(req.duration_secs);
store
.insert(&entry.domain, &entry.target, entry.ttl, duration)
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
}
Ok((
StatusCode::CREATED,
Json(EnvironmentResponse {
created: req.overrides.len(),
}),
))
}
async fn diagnose(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> Json<DiagnoseResponse> {
let domain_lower = domain.to_lowercase();
let qtype = QueryType::A;
let mut steps = Vec::new();
// Check overrides
{
let store = ctx.overrides.lock().unwrap();
let entry = store.get(&domain_lower);
steps.push(DiagnoseStep {
source: "override".to_string(),
matched: entry.is_some(),
detail: entry
.map(|e| format!("{} -> {} ({})", e.domain, e.target, e.query_type.as_str())),
});
}
// Check blocklist
{
let bl = ctx.blocklist.lock().unwrap();
let blocked = bl.is_blocked(&domain_lower);
steps.push(DiagnoseStep {
source: "blocklist".to_string(),
matched: blocked,
detail: if blocked {
Some("domain is in blocklist".to_string())
} else {
None
},
});
}
// Check local zones
let zone_match = ctx
.zone_map
.get(domain_lower.as_str())
.and_then(|m| m.get(&qtype));
steps.push(DiagnoseStep {
source: "local_zone".to_string(),
matched: zone_match.is_some(),
detail: zone_match.map(|records| format!("{} records", records.len())),
});
// Check cache
{
let mut cache = ctx.cache.lock().unwrap();
let cached = cache.lookup(&domain_lower, qtype);
steps.push(DiagnoseStep {
source: "cache".to_string(),
matched: cached.is_some(),
detail: cached.map(|p| format!("{} answers", p.answers.len())),
});
}
// Check upstream (async, no locks held)
let (upstream_matched, upstream_detail) =
forward_query_for_diagnose(&domain_lower, ctx.upstream, ctx.timeout).await;
steps.push(DiagnoseStep {
source: "upstream".to_string(),
matched: upstream_matched,
detail: Some(upstream_detail),
});
Json(DiagnoseResponse {
domain: domain_lower,
query_type: qtype.as_str().to_string(),
steps,
})
}
async fn forward_query_for_diagnose(
domain: &str,
upstream: std::net::SocketAddr,
timeout: std::time::Duration,
) -> (bool, String) {
use crate::packet::DnsPacket;
use crate::question::DnsQuestion;
let mut query = DnsPacket::new();
query.header.id = 0xBEEF;
query.header.recursion_desired = true;
query
.questions
.push(DnsQuestion::new(domain.to_string(), QueryType::A));
match forward_query(&query, upstream, timeout).await {
Ok(resp) => (
true,
format!(
"{} ({} answers)",
resp.header.rescode.as_str(),
resp.answers.len()
),
),
Err(e) => (false, format!("error: {}", e)),
}
}
async fn query_log(
State(ctx): State<Arc<ServerCtx>>,
Query(params): Query<QueryLogParams>,
) -> Json<Vec<QueryLogResponse>> {
let qtype = params.r#type.as_deref().and_then(QueryType::parse_str);
let path = params.path.as_deref().and_then(QueryPath::parse_str);
let filter = QueryLogFilter {
domain: params.domain,
query_type: qtype,
path,
since: None,
limit: params.limit,
};
let raw_entries: Vec<QueryLogResponse> = {
let log = ctx.query_log.lock().unwrap();
log.query(&filter)
.into_iter()
.map(|e| {
let epoch = e
.timestamp
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
QueryLogResponse {
timestamp_epoch: epoch,
src: e.src_addr.to_string(),
domain: e.domain.clone(),
query_type: e.query_type.as_str().to_string(),
path: e.path.as_str().to_string(),
rescode: e.rescode.as_str().to_string(),
latency_ms: e.latency_us as f64 / 1000.0,
}
})
.collect()
};
Json(raw_entries)
}
async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
let snap = ctx.stats.lock().unwrap().snapshot();
let (cache_len, cache_max) = {
let cache = ctx.cache.lock().unwrap();
(cache.len(), cache.max_entries())
};
let override_count = ctx.overrides.lock().unwrap().active_count();
let bl_stats = ctx.blocklist.lock().unwrap().stats();
Json(StatsResponse {
uptime_secs: snap.uptime_secs,
queries: QueriesStats {
total: snap.total,
forwarded: snap.forwarded,
cached: snap.cached,
local: snap.local,
overridden: snap.overridden,
blocked: snap.blocked,
errors: snap.errors,
},
cache: CacheStats {
entries: cache_len,
max_entries: cache_max,
},
overrides: OverrideStats {
active: override_count,
},
blocking: BlockingStatsResponse {
enabled: bl_stats.enabled,
paused: bl_stats.paused,
domains_loaded: bl_stats.domains_loaded,
allowlist_size: bl_stats.allowlist_size,
},
})
}
async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> {
let cache = ctx.cache.lock().unwrap();
let entries: Vec<CacheEntryResponse> = cache
.list()
.into_iter()
.map(|info| CacheEntryResponse {
domain: info.domain,
query_type: info.query_type.as_str().to_string(),
ttl_remaining: info.ttl_remaining,
})
.collect();
Json(entries)
}
async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.cache.lock().unwrap().clear();
StatusCode::NO_CONTENT
}
async fn flush_cache_domain(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
ctx.cache.lock().unwrap().remove(&domain);
StatusCode::NO_CONTENT
}
async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({ "status": "ok" }))
}
// --- Blocking handlers ---
async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
let stats = ctx.blocklist.lock().unwrap().stats();
Json(serde_json::json!({
"enabled": stats.enabled,
"paused": stats.paused,
"domains_loaded": stats.domains_loaded,
"allowlist_size": stats.allowlist_size,
"list_sources": stats.list_sources,
"last_refresh_secs_ago": stats.last_refresh_secs_ago,
}))
}
#[derive(Deserialize)]
struct BlockingToggleRequest {
enabled: bool,
}
async fn blocking_toggle(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingToggleRequest>,
) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().set_enabled(req.enabled);
Json(serde_json::json!({ "enabled": req.enabled }))
}
#[derive(Deserialize)]
struct BlockingPauseRequest {
#[serde(default = "default_pause_minutes")]
minutes: u64,
}
fn default_pause_minutes() -> u64 {
5
}
async fn blocking_pause(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingPauseRequest>,
) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().pause(req.minutes * 60);
Json(serde_json::json!({ "paused_minutes": req.minutes }))
}
async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> {
let list = ctx.blocklist.lock().unwrap().allowlist();
Json(list)
}
#[derive(Deserialize)]
struct AllowlistRequest {
domain: String,
}
async fn blocking_allowlist_add(
State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<AllowlistRequest>,
) -> (StatusCode, Json<serde_json::Value>) {
ctx.blocklist.lock().unwrap().add_to_allowlist(&req.domain);
(
StatusCode::CREATED,
Json(serde_json::json!({ "allowed": req.domain })),
)
}
async fn blocking_allowlist_remove(
State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>,
) -> StatusCode {
if ctx.blocklist.lock().unwrap().remove_from_allowlist(&domain) {
StatusCode::NO_CONTENT
} else {
StatusCode::NOT_FOUND
}
}

187
src/blocklist.rs Normal file
View File

@@ -0,0 +1,187 @@
use std::collections::HashSet;
use std::time::Instant;
use log::{info, warn};
pub struct BlocklistStore {
domains: HashSet<String>,
allowlist: HashSet<String>,
enabled: bool,
paused_until: Option<Instant>,
list_sources: Vec<String>,
last_refresh: Option<Instant>,
}
pub struct BlocklistStats {
pub enabled: bool,
pub paused: bool,
pub domains_loaded: usize,
pub allowlist_size: usize,
pub list_sources: Vec<String>,
pub last_refresh_secs_ago: Option<u64>,
}
impl Default for BlocklistStore {
fn default() -> Self {
Self::new()
}
}
impl BlocklistStore {
pub fn new() -> Self {
BlocklistStore {
domains: HashSet::new(),
allowlist: HashSet::new(),
enabled: true,
paused_until: None,
list_sources: Vec::new(),
last_refresh: None,
}
}
pub fn is_blocked(&self, domain: &str) -> bool {
if !self.enabled {
return false;
}
if let Some(until) = self.paused_until {
if Instant::now() < until {
return false;
}
}
if self.allowlist.contains(domain) {
return false;
}
if self.domains.contains(domain) {
return true;
}
// Walk up: ads.tracker.example.com → tracker.example.com → example.com
let mut d = domain;
while let Some(dot) = d.find('.') {
d = &d[dot + 1..];
if self.allowlist.contains(d) {
return false;
}
if self.domains.contains(d) {
return true;
}
}
false
}
/// Atomically swap in a new domain set. Build the set outside the lock,
/// then call this to swap — keeps lock hold time sub-microsecond.
pub fn swap_domains(&mut self, domains: HashSet<String>, sources: Vec<String>) {
self.domains = domains;
self.list_sources = sources;
self.last_refresh = Some(Instant::now());
}
pub fn set_enabled(&mut self, enabled: bool) {
self.enabled = enabled;
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn pause(&mut self, seconds: u64) {
self.paused_until = Some(Instant::now() + std::time::Duration::from_secs(seconds));
}
pub fn is_paused(&self) -> bool {
self.paused_until
.map(|until| Instant::now() < until)
.unwrap_or(false)
}
pub fn add_to_allowlist(&mut self, domain: &str) {
self.allowlist.insert(domain.to_lowercase());
}
pub fn remove_from_allowlist(&mut self, domain: &str) -> bool {
self.allowlist.remove(&domain.to_lowercase())
}
pub fn allowlist(&self) -> Vec<String> {
self.allowlist.iter().cloned().collect()
}
pub fn stats(&self) -> BlocklistStats {
BlocklistStats {
enabled: self.is_enabled(),
paused: self.is_paused(),
domains_loaded: self.domains.len(),
allowlist_size: self.allowlist.len(),
list_sources: self.list_sources.clone(),
last_refresh_secs_ago: self.last_refresh.map(|t| t.elapsed().as_secs()),
}
}
}
/// Parse a blocklist text file into a set of domains.
pub fn parse_blocklist(text: &str) -> HashSet<String> {
let mut domains = HashSet::new();
for line in text.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') || line.starts_with('!') {
continue;
}
// Handle hosts-file format: "0.0.0.0 domain" or "127.0.0.1 domain" (space or tab)
let domain = if line.starts_with("0.0.0.0")
|| line.starts_with("127.0.0.1")
|| line.starts_with("::")
{
line.split_whitespace()
.nth(1)
.unwrap_or("")
.trim_end_matches('.')
} else if line.contains(' ') || line.contains('\t') {
continue;
} else {
// Plain domain or adblock filter syntax
let d = line.trim_start_matches("*.").trim_start_matches("||");
let d = d.split('$').next().unwrap_or(d); // strip adblock $options
d.trim_end_matches('^').trim_end_matches('.')
};
let domain = domain.to_lowercase();
if !domain.is_empty()
&& domain.contains('.')
&& domain != "localhost"
&& domain != "localhost.localdomain"
{
domains.insert(domain);
}
}
domains
}
pub async fn download_blocklists(lists: &[String]) -> Vec<(String, String)> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_default();
let mut results = Vec::new();
for url in lists {
match client.get(url).send().await {
Ok(resp) => match resp.text().await {
Ok(text) => {
info!("downloaded blocklist: {} ({} bytes)", url, text.len());
results.push((url.clone(), text));
}
Err(e) => warn!("failed to read blocklist body {}: {}", url, e),
},
Err(e) => warn!("failed to download blocklist {}: {}", url, e),
}
}
results
}

View File

@@ -11,8 +11,11 @@ struct CacheEntry {
ttl: Duration,
}
/// DNS cache using a two-level map (domain -> query_type -> entry) so that
/// lookups can borrow `&str` instead of allocating a `String` key.
pub struct DnsCache {
entries: HashMap<(String, QueryType), CacheEntry>,
entries: HashMap<String, HashMap<QueryType, CacheEntry>>,
entry_count: usize,
max_entries: usize,
min_ttl: u32,
max_ttl: u32,
@@ -23,6 +26,7 @@ impl DnsCache {
pub fn new(max_entries: usize, min_ttl: u32, max_ttl: u32) -> Self {
DnsCache {
entries: HashMap::new(),
entry_count: 0,
max_entries,
min_ttl,
max_ttl,
@@ -33,17 +37,22 @@ impl DnsCache {
pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
self.query_count += 1;
// Periodic eviction every 1000 queries
if self.query_count.is_multiple_of(1000) {
self.evict_expired();
}
let key = (domain.to_string(), qtype);
let entry = self.entries.get(&key)?;
let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?;
let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl {
self.entries.remove(&key);
// Expired: remove this entry
let type_map = self.entries.get_mut(domain).unwrap();
type_map.remove(&qtype);
self.entry_count -= 1;
if type_map.is_empty() {
self.entries.remove(domain);
}
return None;
}
@@ -59,10 +68,9 @@ impl DnsCache {
}
pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) {
if self.entries.len() >= self.max_entries {
if self.entry_count >= self.max_entries {
self.evict_expired();
// If still full after eviction, skip insertion
if self.entries.len() >= self.max_entries {
if self.entry_count >= self.max_entries {
return;
}
}
@@ -71,9 +79,18 @@ impl DnsCache {
.unwrap_or(self.min_ttl)
.clamp(self.min_ttl, self.max_ttl);
let key = (domain.to_string(), qtype);
self.entries.insert(
key,
let type_map = if let Some(existing) = self.entries.get_mut(domain) {
existing
} else {
self.entries.entry(domain.to_string()).or_default()
};
if !type_map.contains_key(&qtype) {
self.entry_count += 1;
}
type_map.insert(
qtype,
CacheEntry {
packet: packet.clone(),
inserted_at: Instant::now(),
@@ -82,10 +99,64 @@ impl DnsCache {
);
}
fn evict_expired(&mut self) {
self.entries
.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
pub fn len(&self) -> usize {
self.entry_count
}
pub fn is_empty(&self) -> bool {
self.entry_count == 0
}
pub fn max_entries(&self) -> usize {
self.max_entries
}
pub fn clear(&mut self) {
self.entries.clear();
self.entry_count = 0;
}
pub fn remove(&mut self, domain: &str) {
let domain_lower = domain.to_lowercase();
if let Some(type_map) = self.entries.remove(&domain_lower) {
self.entry_count -= type_map.len();
}
}
pub fn list(&self) -> Vec<CacheInfo> {
let mut result = Vec::new();
for (domain, type_map) in &self.entries {
for (qtype, entry) in type_map {
let elapsed = entry.inserted_at.elapsed();
if elapsed < entry.ttl {
let remaining = (entry.ttl - elapsed).as_secs() as u32;
result.push(CacheInfo {
domain: domain.clone(),
query_type: *qtype,
ttl_remaining: remaining,
});
}
}
}
result
}
fn evict_expired(&mut self) {
let mut count = 0;
self.entries.retain(|_, type_map| {
let before = type_map.len();
type_map.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
count += before - type_map.len();
!type_map.is_empty()
});
self.entry_count -= count;
}
}
pub struct CacheInfo {
pub domain: String,
pub query_type: QueryType,
pub ttl_remaining: u32,
}
fn extract_min_ttl(records: &[DnsRecord]) -> Option<u32> {

View File

@@ -18,6 +18,8 @@ pub struct Config {
#[serde(default)]
pub cache: CacheConfig,
#[serde(default)]
pub blocking: BlockingConfig,
#[serde(default)]
pub zones: Vec<ZoneRecord>,
}
@@ -25,12 +27,15 @@ pub struct Config {
pub struct ServerConfig {
#[serde(default = "default_bind_addr")]
pub bind_addr: String,
#[serde(default = "default_api_port")]
pub api_port: u16,
}
impl Default for ServerConfig {
fn default() -> Self {
ServerConfig {
bind_addr: default_bind_addr(),
api_port: default_api_port(),
}
}
}
@@ -39,6 +44,10 @@ fn default_bind_addr() -> String {
"0.0.0.0:53".to_string()
}
fn default_api_port() -> u16 {
5380
}
#[derive(Deserialize)]
pub struct UpstreamConfig {
#[serde(default = "default_upstream_addr")]
@@ -108,6 +117,41 @@ pub struct ZoneRecord {
pub ttl: u32,
}
#[derive(Deserialize)]
pub struct BlockingConfig {
#[serde(default = "default_blocking_enabled")]
pub enabled: bool,
#[serde(default = "default_blocklists")]
pub lists: Vec<String>,
#[serde(default = "default_refresh_hours")]
pub refresh_hours: u64,
#[serde(default)]
pub allowlist: Vec<String>,
}
impl Default for BlockingConfig {
fn default() -> Self {
BlockingConfig {
enabled: default_blocking_enabled(),
lists: default_blocklists(),
refresh_hours: default_refresh_hours(),
allowlist: Vec::new(),
}
}
}
fn default_blocking_enabled() -> bool {
true
}
fn default_blocklists() -> Vec<String> {
vec!["https://cdn.jsdelivr.net/gh/hagezi/dns-blocklists@latest/hosts/pro.txt".to_string()]
}
fn default_refresh_hours() -> u64 {
24
}
fn default_zone_ttl() -> u32 {
300
}
@@ -118,6 +162,7 @@ pub fn load_config(path: &str) -> Result<Config> {
server: ServerConfig::default(),
upstream: UpstreamConfig::default(),
cache: CacheConfig::default(),
blocking: BlockingConfig::default(),
zones: Vec::new(),
});
}
@@ -126,10 +171,10 @@ pub fn load_config(path: &str) -> Result<Config> {
Ok(config)
}
pub fn build_zone_map(
zones: &[ZoneRecord],
) -> Result<HashMap<(String, QueryType), Vec<DnsRecord>>> {
let mut map: HashMap<(String, QueryType), Vec<DnsRecord>> = HashMap::new();
pub type ZoneMap = HashMap<String, HashMap<QueryType, Vec<DnsRecord>>>;
pub fn build_zone_map(zones: &[ZoneRecord]) -> Result<ZoneMap> {
let mut map: ZoneMap = HashMap::new();
for zone in zones {
let domain = zone.domain.to_lowercase();
@@ -203,7 +248,11 @@ pub fn build_zone_map(
}
};
map.entry((domain, qtype)).or_default().push(record);
map.entry(domain)
.or_default()
.entry(qtype)
.or_default()
.push(record);
}
Ok(map)

155
src/ctx.rs Normal file
View File

@@ -0,0 +1,155 @@
use std::net::SocketAddr;
use std::sync::Mutex;
use std::time::{Duration, Instant, SystemTime};
use log::{debug, error, info, warn};
use tokio::net::UdpSocket;
use crate::blocklist::BlocklistStore;
use crate::buffer::BytePacketBuffer;
use crate::cache::DnsCache;
use crate::config::ZoneMap;
use crate::forward::forward_query;
use crate::header::ResultCode;
use crate::override_store::OverrideStore;
use crate::packet::DnsPacket;
use crate::query_log::{QueryLog, QueryLogEntry};
use crate::record::DnsRecord;
use crate::stats::{QueryPath, ServerStats};
use crate::system_dns::ForwardingRule;
pub struct ServerCtx {
pub socket: UdpSocket,
pub zone_map: ZoneMap,
pub cache: Mutex<DnsCache>,
pub stats: Mutex<ServerStats>,
pub overrides: Mutex<OverrideStore>,
pub blocklist: Mutex<BlocklistStore>,
pub query_log: Mutex<QueryLog>,
pub forwarding_rules: Vec<ForwardingRule>,
pub upstream: SocketAddr,
pub timeout: Duration,
}
pub async fn handle_query(
mut buffer: BytePacketBuffer,
src_addr: SocketAddr,
ctx: &ServerCtx,
) -> crate::Result<()> {
let start = Instant::now();
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(packet) => packet,
Err(e) => {
warn!("{} | PARSE ERROR | {}", src_addr, e);
return Ok(());
}
};
let (qname, qtype) = match query.questions.first() {
Some(q) => (q.name.clone(), q.qtype),
None => return Ok(()),
};
// Pipeline: overrides -> blocklist -> local zones -> cache -> upstream
// Each lock is scoped to avoid holding MutexGuard across await points.
let (response, path) = {
let override_record = ctx.overrides.lock().unwrap().lookup(&qname);
if let Some(record) = override_record {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers.push(record);
(resp, QueryPath::Overridden)
} else if ctx.blocklist.lock().unwrap().is_blocked(&qname) {
use crate::question::QueryType;
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
match qtype {
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
domain: qname.clone(),
addr: std::net::Ipv6Addr::UNSPECIFIED,
ttl: 60,
}),
_ => resp.answers.push(DnsRecord::A {
domain: qname.clone(),
addr: std::net::Ipv4Addr::UNSPECIFIED,
ttl: 60,
}),
}
(resp, QueryPath::Blocked)
} else if let Some(records) = ctx.zone_map.get(qname.as_str()).and_then(|m| m.get(&qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers = records.clone();
(resp, QueryPath::Local)
} else {
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype);
if let Some(cached) = cached {
let mut resp = cached;
resp.header.id = query.header.id;
(resp, QueryPath::Cached)
} else {
let upstream =
crate::system_dns::match_forwarding_rule(&qname, &ctx.forwarding_rules)
.unwrap_or(ctx.upstream);
match forward_query(&query, upstream, ctx.timeout).await {
Ok(resp) => {
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded)
}
Err(e) => {
error!(
"{} | {:?} {} | UPSTREAM ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
}
}
};
let elapsed = start.elapsed();
info!(
"{} | {:?} {} | {} | {} | {}ms",
src_addr,
qtype,
qname,
path.as_str(),
response.header.rescode.as_str(),
elapsed.as_millis(),
);
debug!(
"response: {} answers, {} authorities, {} resources",
response.answers.len(),
response.authorities.len(),
response.resources.len(),
);
let mut resp_buffer = BytePacketBuffer::new();
response.write(&mut resp_buffer)?;
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
// Record stats and query log
{
let mut s = ctx.stats.lock().unwrap();
let total = s.record(path);
if total.is_multiple_of(1000) {
s.log_summary();
}
}
ctx.query_log.lock().unwrap().push(QueryLogEntry {
timestamp: SystemTime::now(),
src_addr,
domain: qname,
query_type: qtype,
path,
rescode: response.header.rescode,
latency_us: elapsed.as_micros() as u64,
});
Ok(())
}

View File

@@ -1,12 +1,18 @@
pub mod api;
pub mod blocklist;
pub mod buffer;
pub mod cache;
pub mod config;
pub mod ctx;
pub mod forward;
pub mod header;
pub mod override_store;
pub mod packet;
pub mod query_log;
pub mod question;
pub mod record;
pub mod stats;
pub mod system_dns;
pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,47 +1,48 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::time::Duration;
use log::{debug, error, info, warn};
use log::{error, info};
use tokio::net::UdpSocket;
use dns_fun::buffer::BytePacketBuffer;
use dns_fun::cache::DnsCache;
use dns_fun::config::{build_zone_map, load_config};
use dns_fun::forward::forward_query;
use dns_fun::header::ResultCode;
use dns_fun::packet::DnsPacket;
use dns_fun::question::QueryType;
use dns_fun::record::DnsRecord;
use dns_fun::stats::{QueryPath, ServerStats};
struct ServerCtx {
socket: Arc<UdpSocket>,
zone_map: HashMap<(String, QueryType), Vec<DnsRecord>>,
cache: Mutex<DnsCache>,
stats: Mutex<ServerStats>,
upstream: SocketAddr,
timeout: Duration,
}
use numa::blocklist::{download_blocklists, parse_blocklist, BlocklistStore};
use numa::buffer::BytePacketBuffer;
use numa::cache::DnsCache;
use numa::config::{build_zone_map, load_config};
use numa::ctx::{handle_query, ServerCtx};
use numa::override_store::OverrideStore;
use numa::query_log::QueryLog;
use numa::stats::ServerStats;
use numa::system_dns::discover_forwarding_rules;
#[tokio::main]
async fn main() -> dns_fun::Result<()> {
async fn main() -> numa::Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
.format_timestamp_millis()
.init();
let config_path = std::env::args()
.nth(1)
.unwrap_or_else(|| "dns_fun.toml".to_string());
.unwrap_or_else(|| "numa.toml".to_string());
let config = load_config(&config_path)?;
let upstream: SocketAddr =
format!("{}:{}", config.upstream.address, config.upstream.port).parse()?;
let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?);
let api_port = config.server.api_port;
let mut blocklist = BlocklistStore::new();
for domain in &config.blocking.allowlist {
blocklist.add_to_allowlist(domain);
}
if !config.blocking.enabled {
blocklist.set_enabled(false);
}
// Auto-discover conditional forwarding rules from OS (Tailscale, VPN, etc.)
let forwarding_rules = discover_forwarding_rules();
let ctx = Arc::new(ServerCtx {
socket: Arc::clone(&socket),
socket: UdpSocket::bind(&config.server.bind_addr).await?,
zone_map: build_zone_map(&config.zones)?,
cache: Mutex::new(DnsCache::new(
config.cache.max_entries,
@@ -49,21 +50,72 @@ async fn main() -> dns_fun::Result<()> {
config.cache.max_ttl,
)),
stats: Mutex::new(ServerStats::new()),
overrides: Mutex::new(OverrideStore::new()),
blocklist: Mutex::new(blocklist),
query_log: Mutex::new(QueryLog::new(1000)),
forwarding_rules,
upstream,
timeout: Duration::from_millis(config.upstream.timeout_ms),
});
let zone_count: usize = ctx.zone_map.values().map(|m| m.len()).sum();
eprintln!("\n\x1b[38;2;192;98;58m ╔══════════════════════════════════════════╗\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[1;38;2;192;98;58mNUMA\x1b[0m \x1b[3;38;2;163;152;136mDNS that governs itself\x1b[0m \x1b[38;2;192;98;58m║\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ╠══════════════════════════════════════════╣\x1b[0m");
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mDNS\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", config.server.bind_addr);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mAPI\x1b[0m http://localhost:{:<16}\x1b[38;2;192;98;58m║\x1b[0m", api_port);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mDashboard\x1b[0m http://localhost:{:<16}\x1b[38;2;192;98;58m║\x1b[0m", api_port);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mUpstream\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", upstream);
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mZones\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", format!("{} records", zone_count));
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mCache\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m", format!("max {} entries", config.cache.max_entries));
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mBlocking\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m",
if config.blocking.enabled { format!("{} lists", config.blocking.lists.len()) } else { "disabled".to_string() });
if !ctx.forwarding_rules.is_empty() {
eprintln!("\x1b[38;2;192;98;58m ║\x1b[0m \x1b[38;2;107;124;78mRouting\x1b[0m {:<30}\x1b[38;2;192;98;58m║\x1b[0m",
format!("{} conditional rules", ctx.forwarding_rules.len()));
}
eprintln!("\x1b[38;2;192;98;58m ╚══════════════════════════════════════════╝\x1b[0m\n");
info!(
"dns_fun starting on {}, upstream {}, {} zone records, cache max {}",
config.server.bind_addr,
upstream,
ctx.zone_map.len(),
config.cache.max_entries,
"numa listening on {}, upstream {}, {} zone records, cache max {}, API on port {}",
config.server.bind_addr, upstream, zone_count, config.cache.max_entries, api_port,
);
// Download blocklists on startup
let blocklist_lists = config.blocking.lists.clone();
let refresh_hours = config.blocking.refresh_hours;
if config.blocking.enabled && !blocklist_lists.is_empty() {
let bl_ctx = Arc::clone(&ctx);
let bl_lists = blocklist_lists.clone();
tokio::spawn(async move {
load_blocklists(&bl_ctx, &bl_lists).await;
// Periodic refresh
let mut interval = tokio::time::interval(Duration::from_secs(refresh_hours * 3600));
interval.tick().await; // skip immediate tick
loop {
interval.tick().await;
info!("refreshing blocklists...");
load_blocklists(&bl_ctx, &bl_lists).await;
}
});
}
// Spawn HTTP API server
let api_ctx = Arc::clone(&ctx);
let api_addr: SocketAddr = format!("0.0.0.0:{}", api_port).parse()?;
tokio::spawn(async move {
let app = numa::api::router(api_ctx);
let listener = tokio::net::TcpListener::bind(api_addr).await.unwrap();
info!("HTTP API listening on {}", api_addr);
axum::serve(listener, app).await.unwrap();
});
// UDP DNS listener
#[allow(clippy::infinite_loop)]
loop {
let mut buffer = BytePacketBuffer::new();
let (_, src_addr) = socket.recv_from(&mut buffer.buf).await?;
let (_, src_addr) = ctx.socket.recv_from(&mut buffer.buf).await?;
let ctx = Arc::clone(&ctx);
tokio::spawn(async move {
@@ -74,87 +126,28 @@ async fn main() -> dns_fun::Result<()> {
}
}
async fn handle_query(
mut buffer: BytePacketBuffer,
src_addr: SocketAddr,
ctx: &ServerCtx,
) -> dns_fun::Result<()> {
let start = Instant::now();
async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
let downloaded = download_blocklists(lists).await;
let query = match DnsPacket::from_buffer(&mut buffer) {
Ok(packet) => packet,
Err(e) => {
warn!("{} | PARSE ERROR | {}", src_addr, e);
return Ok(());
}
};
let (qname, qtype) = match query.questions.first() {
Some(q) => (q.name.clone(), q.qtype),
None => return Ok(()),
};
// Pipeline: local zones -> cache -> upstream
// Each lock is scoped to avoid holding MutexGuard across await points.
let (response, path) = if let Some(records) = ctx.zone_map.get(&(qname.to_lowercase(), qtype)) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers = records.clone();
(resp, QueryPath::Local)
} else {
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype);
if let Some(cached) = cached {
let mut resp = cached;
resp.header.id = query.header.id;
(resp, QueryPath::Cached)
} else {
match forward_query(&query, ctx.upstream, ctx.timeout).await {
Ok(resp) => {
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded)
}
Err(e) => {
error!(
"{} | {:?} {} | UPSTREAM ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
}
};
let elapsed = start.elapsed();
info!(
"{} | {:?} {} | {} | {} | {}ms",
src_addr,
qtype,
qname,
path.as_str(),
response.header.rescode.as_str(),
elapsed.as_millis(),
);
debug!(
"response: {} answers, {} authorities, {} resources",
response.answers.len(),
response.authorities.len(),
response.resources.len(),
);
let mut resp_buffer = BytePacketBuffer::new();
response.write(&mut resp_buffer)?;
ctx.socket.send_to(resp_buffer.filled(), src_addr).await?;
// Record stats and log summary every 1000 queries (single lock acquisition)
let mut s = ctx.stats.lock().unwrap();
let total = s.record(path);
if total.is_multiple_of(1000) {
s.log_summary();
// Parse outside the lock to avoid blocking DNS queries during parse (~100ms)
let mut all_domains = std::collections::HashSet::new();
let mut sources = Vec::new();
for (source, text) in &downloaded {
let domains = parse_blocklist(text);
info!("blocklist: {} domains from {}", domains.len(), source);
all_domains.extend(domains);
sources.push(source.clone());
}
let total = all_domains.len();
Ok(())
// Swap under lock — sub-microsecond
ctx.blocklist
.lock()
.unwrap()
.swap_domains(all_domains, sources);
info!(
"blocking enabled: {} unique domains from {} lists",
total,
downloaded.len()
);
}

153
src/override_store.rs Normal file
View File

@@ -0,0 +1,153 @@
use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::time::Instant;
use crate::question::QueryType;
use crate::record::DnsRecord;
use crate::Result;
pub struct OverrideEntry {
pub domain: String,
pub target: String,
pub record: DnsRecord,
pub query_type: QueryType,
pub ttl: u32,
pub created_at: Instant,
pub duration_secs: Option<u64>,
}
impl OverrideEntry {
pub fn expires_at(&self) -> Option<Instant> {
self.duration_secs
.map(|d| self.created_at + std::time::Duration::from_secs(d))
}
pub fn is_expired(&self) -> bool {
self.expires_at()
.map(|exp| Instant::now() >= exp)
.unwrap_or(false)
}
pub fn remaining_secs(&self) -> Option<u64> {
self.expires_at().map(|exp| {
let now = Instant::now();
if now >= exp {
0
} else {
(exp - now).as_secs()
}
})
}
}
pub struct OverrideStore {
entries: HashMap<String, OverrideEntry>,
}
impl Default for OverrideStore {
fn default() -> Self {
Self::new()
}
}
impl OverrideStore {
pub fn new() -> Self {
OverrideStore {
entries: HashMap::new(),
}
}
pub fn insert(
&mut self,
domain: &str,
target: &str,
ttl: u32,
duration_secs: Option<u64>,
) -> Result<QueryType> {
let domain_lower = domain.to_lowercase();
let (qtype, record) = parse_target(&domain_lower, target, ttl)?;
self.entries.insert(
domain_lower.clone(),
OverrideEntry {
domain: domain_lower,
target: target.to_string(),
record,
query_type: qtype,
ttl,
created_at: Instant::now(),
duration_secs,
},
);
Ok(qtype)
}
/// Hot path: assumes `domain` is already lowercased (the parser does this).
pub fn lookup(&mut self, domain: &str) -> Option<DnsRecord> {
let entry = self.entries.get(domain)?;
if entry.is_expired() {
self.entries.remove(domain);
return None;
}
Some(entry.record.clone())
}
pub fn get(&self, domain: &str) -> Option<&OverrideEntry> {
let key = domain.to_lowercase();
let entry = self.entries.get(&key)?;
if entry.is_expired() {
return None;
}
Some(entry)
}
pub fn remove(&mut self, domain: &str) -> bool {
self.entries.remove(&domain.to_lowercase()).is_some()
}
pub fn list(&self) -> Vec<&OverrideEntry> {
self.entries.values().filter(|e| !e.is_expired()).collect()
}
pub fn clear(&mut self) {
self.entries.clear();
}
pub fn active_count(&self) -> usize {
self.entries.values().filter(|e| !e.is_expired()).count()
}
}
fn parse_target(domain: &str, target: &str, ttl: u32) -> Result<(QueryType, DnsRecord)> {
if let Ok(addr) = target.parse::<Ipv4Addr>() {
return Ok((
QueryType::A,
DnsRecord::A {
domain: domain.to_string(),
addr,
ttl,
},
));
}
if let Ok(addr) = target.parse::<Ipv6Addr>() {
return Ok((
QueryType::AAAA,
DnsRecord::AAAA {
domain: domain.to_string(),
addr,
ttl,
},
));
}
Ok((
QueryType::CNAME,
DnsRecord::CNAME {
domain: domain.to_string(),
host: target.to_string(),
ttl,
},
))
}

77
src/query_log.rs Normal file
View File

@@ -0,0 +1,77 @@
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::time::SystemTime;
use crate::header::ResultCode;
use crate::question::QueryType;
use crate::stats::QueryPath;
pub struct QueryLogEntry {
pub timestamp: SystemTime,
pub src_addr: SocketAddr,
pub domain: String,
pub query_type: QueryType,
pub path: QueryPath,
pub rescode: ResultCode,
pub latency_us: u64,
}
pub struct QueryLog {
entries: VecDeque<QueryLogEntry>,
capacity: usize,
}
impl QueryLog {
pub fn new(capacity: usize) -> Self {
QueryLog {
entries: VecDeque::with_capacity(capacity),
capacity,
}
}
pub fn push(&mut self, entry: QueryLogEntry) {
if self.entries.len() >= self.capacity {
self.entries.pop_front();
}
self.entries.push_back(entry);
}
pub fn query(&self, filter: &QueryLogFilter) -> Vec<&QueryLogEntry> {
self.entries
.iter()
.rev()
.filter(|e| {
if let Some(ref domain) = filter.domain {
if !e.domain.contains(domain.as_str()) {
return false;
}
}
if let Some(qtype) = filter.query_type {
if e.query_type != qtype {
return false;
}
}
if let Some(path) = filter.path {
if e.path != path {
return false;
}
}
if let Some(since) = filter.since {
if e.timestamp < since {
return false;
}
}
true
})
.take(filter.limit.unwrap_or(50))
.collect()
}
}
pub struct QueryLogFilter {
pub domain: Option<String>,
pub query_type: Option<QueryType>,
pub path: Option<QueryPath>,
pub since: Option<SystemTime>,
pub limit: Option<usize>,
}

View File

@@ -33,6 +33,33 @@ impl QueryType {
_ => QueryType::UNKNOWN(num),
}
}
pub fn as_str(&self) -> &'static str {
match self {
QueryType::A => "A",
QueryType::NS => "NS",
QueryType::CNAME => "CNAME",
QueryType::MX => "MX",
QueryType::AAAA => "AAAA",
QueryType::UNKNOWN(_) => "UNKNOWN",
}
}
pub fn parse_str(s: &str) -> Option<QueryType> {
if s.eq_ignore_ascii_case("A") {
Some(QueryType::A)
} else if s.eq_ignore_ascii_case("NS") {
Some(QueryType::NS)
} else if s.eq_ignore_ascii_case("CNAME") {
Some(QueryType::CNAME)
} else if s.eq_ignore_ascii_case("MX") {
Some(QueryType::MX)
} else if s.eq_ignore_ascii_case("AAAA") {
Some(QueryType::AAAA)
} else {
None
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@@ -240,7 +240,7 @@ impl DnsRecord {
}
}
DnsRecord::UNKNOWN { .. } => {
println!("Skipping record: {:?}", self);
log::debug!("Skipping record: {:?}", self);
}
}

View File

@@ -6,15 +6,18 @@ pub struct ServerStats {
queries_cached: u64,
queries_blocked: u64,
queries_local: u64,
queries_overridden: u64,
upstream_errors: u64,
started_at: Instant,
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum QueryPath {
Local,
Cached,
Forwarded,
Blocked,
Overridden,
UpstreamError,
}
@@ -25,9 +28,28 @@ impl QueryPath {
QueryPath::Cached => "CACHED",
QueryPath::Forwarded => "FORWARD",
QueryPath::Blocked => "BLOCKED",
QueryPath::Overridden => "OVERRIDE",
QueryPath::UpstreamError => "SERVFAIL",
}
}
pub fn parse_str(s: &str) -> Option<QueryPath> {
if s.eq_ignore_ascii_case("LOCAL") {
Some(QueryPath::Local)
} else if s.eq_ignore_ascii_case("CACHED") {
Some(QueryPath::Cached)
} else if s.eq_ignore_ascii_case("FORWARD") {
Some(QueryPath::Forwarded)
} else if s.eq_ignore_ascii_case("BLOCKED") {
Some(QueryPath::Blocked)
} else if s.eq_ignore_ascii_case("OVERRIDE") {
Some(QueryPath::Overridden)
} else if s.eq_ignore_ascii_case("SERVFAIL") {
Some(QueryPath::UpstreamError)
} else {
None
}
}
}
impl Default for ServerStats {
@@ -44,6 +66,7 @@ impl ServerStats {
queries_cached: 0,
queries_blocked: 0,
queries_local: 0,
queries_overridden: 0,
upstream_errors: 0,
started_at: Instant::now(),
}
@@ -56,6 +79,7 @@ impl ServerStats {
QueryPath::Cached => self.queries_cached += 1,
QueryPath::Forwarded => self.queries_forwarded += 1,
QueryPath::Blocked => self.queries_blocked += 1,
QueryPath::Overridden => self.queries_overridden += 1,
QueryPath::UpstreamError => self.upstream_errors += 1,
}
self.queries_total
@@ -65,6 +89,23 @@ impl ServerStats {
self.queries_total
}
pub fn uptime_secs(&self) -> u64 {
self.started_at.elapsed().as_secs()
}
pub fn snapshot(&self) -> StatsSnapshot {
StatsSnapshot {
uptime_secs: self.uptime_secs(),
total: self.queries_total,
forwarded: self.queries_forwarded,
cached: self.queries_cached,
local: self.queries_local,
overridden: self.queries_overridden,
blocked: self.queries_blocked,
errors: self.upstream_errors,
}
}
pub fn log_summary(&self) {
let uptime = self.started_at.elapsed();
let hours = uptime.as_secs() / 3600;
@@ -72,14 +113,26 @@ impl ServerStats {
let secs = uptime.as_secs() % 60;
log::info!(
"STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | blocked {} | errors {}",
"STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | override {} | blocked {} | errors {}",
hours, mins, secs,
self.queries_total,
self.queries_forwarded,
self.queries_cached,
self.queries_local,
self.queries_overridden,
self.queries_blocked,
self.upstream_errors,
);
}
}
pub struct StatsSnapshot {
pub uptime_secs: u64,
pub total: u64,
pub forwarded: u64,
pub cached: u64,
pub local: u64,
pub overridden: u64,
pub blocked: u64,
pub errors: u64,
}

144
src/system_dns.rs Normal file
View File

@@ -0,0 +1,144 @@
use std::net::SocketAddr;
use log::{debug, info, warn};
/// A conditional forwarding rule: domains matching `suffix` are forwarded to `upstream`.
#[derive(Debug, Clone)]
pub struct ForwardingRule {
pub suffix: String,
dot_suffix: String, // pre-computed ".suffix" for zero-alloc matching
pub upstream: SocketAddr,
}
/// Discover system DNS forwarding rules from the OS.
/// On macOS, parses `scutil --dns`. Returns rules sorted longest-suffix-first
/// so more specific matches take priority.
pub fn discover_forwarding_rules() -> Vec<ForwardingRule> {
#[cfg(target_os = "macos")]
{
discover_macos()
}
#[cfg(not(target_os = "macos"))]
{
info!("system DNS auto-discovery not implemented for this OS");
Vec::new()
}
}
#[cfg(target_os = "macos")]
fn discover_macos() -> Vec<ForwardingRule> {
let output = match std::process::Command::new("scutil").arg("--dns").output() {
Ok(o) => o,
Err(e) => {
warn!("failed to run scutil --dns: {}", e);
return Vec::new();
}
};
let text = String::from_utf8_lossy(&output.stdout);
let mut rules = Vec::new();
// Parse resolver blocks: look for blocks with both `domain` and `nameserver[0]`
// that have the `Supplemental` flag (conditional forwarding, not default)
let mut current_domain: Option<String> = None;
let mut current_nameserver: Option<String> = None;
let mut is_supplemental = false;
for line in text.lines() {
let line = line.trim();
if line.starts_with("resolver #") {
// Emit previous block if valid
if let (Some(domain), Some(ns), true) = (
current_domain.take(),
current_nameserver.take(),
is_supplemental,
) {
if let Some(rule) = make_rule(&domain, &ns) {
rules.push(rule);
}
}
current_domain = None;
current_nameserver = None;
is_supplemental = false;
} else if line.starts_with("domain") && line.contains(':') {
// "domain : tailcee7cc.ts.net."
if let Some(val) = line.split(':').nth(1) {
let domain = val.trim().trim_end_matches('.').to_lowercase();
if !domain.is_empty()
&& domain != "local"
&& !domain.ends_with("in-addr.arpa")
&& !domain.ends_with("ip6.arpa")
{
current_domain = Some(domain);
}
}
} else if line.starts_with("nameserver[0]") && line.contains(':') {
if let Some(val) = line.split(':').nth(1) {
let ns = val.trim().to_string();
// Only use IPv4 nameservers for now
if ns.parse::<std::net::Ipv4Addr>().is_ok() {
current_nameserver = Some(ns);
}
}
} else if line.starts_with("flags") && line.contains("Supplemental") {
is_supplemental = true;
} else if line.starts_with("DNS configuration (for scoped") {
// Stop at scoped section — those are interface-specific, not conditional
if let (Some(domain), Some(ns), true) = (
current_domain.take(),
current_nameserver.take(),
is_supplemental,
) {
if let Some(rule) = make_rule(&domain, &ns) {
rules.push(rule);
}
}
break;
}
}
// Emit last block
if let (Some(domain), Some(ns), true) = (current_domain, current_nameserver, is_supplemental) {
if let Some(rule) = make_rule(&domain, &ns) {
rules.push(rule);
}
}
// Sort longest suffix first for most-specific matching
rules.sort_by(|a, b| b.suffix.len().cmp(&a.suffix.len()));
for rule in &rules {
info!(
"auto-discovered forwarding: *.{} -> {}",
rule.suffix, rule.upstream
);
}
if rules.is_empty() {
debug!("no conditional forwarding rules discovered from scutil --dns");
}
rules
}
fn make_rule(domain: &str, nameserver: &str) -> Option<ForwardingRule> {
let addr: SocketAddr = format!("{}:53", nameserver).parse().ok()?;
Some(ForwardingRule {
dot_suffix: format!(".{}", domain),
suffix: domain.to_string(),
upstream: addr,
})
}
/// Find the upstream for a domain by checking forwarding rules.
/// Returns None if no rule matches (use default upstream).
/// Zero-allocation on the hot path — dot_suffix is pre-computed.
pub fn match_forwarding_rule(domain: &str, rules: &[ForwardingRule]) -> Option<SocketAddr> {
for rule in rules {
if domain == rule.suffix || domain.ends_with(&rule.dot_suffix) {
return Some(rule.upstream);
}
}
None
}