perf: optimize DNS query hot path #15

Merged
razvandimescu merged 2 commits from perf/hot-path-optimizations into main 2026-03-27 08:01:08 +08:00
13 changed files with 729 additions and 77 deletions
Showing only changes of commit aed0e095e1 - Show all commits

267
Cargo.lock generated
View File

@@ -17,6 +17,12 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.21" version = "0.6.21"
@@ -237,6 +243,12 @@ version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.57" version = "1.2.57"
@@ -261,6 +273,58 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "clap"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351"
dependencies = [
"clap_builder",
]
[[package]]
name = "clap_builder"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
dependencies = [
"anstyle",
"clap_lex",
]
[[package]]
name = "clap_lex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
[[package]] [[package]]
name = "cmake" name = "cmake"
version = "0.1.57" version = "0.1.57"
@@ -302,6 +366,73 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
[[package]] [[package]]
name = "data-encoding" name = "data-encoding"
version = "2.10.0" version = "2.10.0"
@@ -348,6 +479,12 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]] [[package]]
name = "env_filter" name = "env_filter"
version = "1.0.0" version = "1.0.0"
@@ -548,12 +685,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "half"
version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
dependencies = [
"cfg-if",
"crunchy",
"zerocopy",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.16.1" version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]] [[package]]
name = "http" name = "http"
version = "1.4.0" version = "1.4.0"
@@ -790,12 +944,32 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "is-terminal"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.2" version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.17" version = "1.0.17"
@@ -971,6 +1145,7 @@ version = "0.5.0"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"axum", "axum",
"criterion",
"env_logger", "env_logger",
"futures", "futures",
"http-body-util", "http-body-util",
@@ -1010,6 +1185,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]] [[package]]
name = "pem" name = "pem"
version = "3.0.6" version = "3.0.6"
@@ -1038,6 +1219,34 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.13.1" version = "1.13.1"
@@ -1185,6 +1394,26 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
] ]
[[package]]
name = "rayon"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]] [[package]]
name = "rcgen" name = "rcgen"
version = "0.13.2" version = "0.13.2"
@@ -1346,6 +1575,15 @@ version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.228" version = "1.0.228"
@@ -1589,6 +1827,16 @@ dependencies = [
"zerovec", "zerovec",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.11.0" version = "1.11.0"
@@ -1807,6 +2055,16 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]] [[package]]
name = "want" name = "want"
version = "0.3.1" version = "0.3.1"
@@ -1919,6 +2177,15 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
[[package]]
name = "winapi-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "windows-link" name = "windows-link"
version = "0.2.1" version = "0.2.1"

View File

@@ -28,3 +28,14 @@ time = "0.3"
rustls = "0.23" rustls = "0.23"
tokio-rustls = "0.26" tokio-rustls = "0.26"
arc-swap = "1" arc-swap = "1"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "hot_path"
harness = false
[[bench]]
name = "throughput"
harness = false

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt check audit test clean deploy .PHONY: all build lint fmt check audit test bench clean deploy blog
all: lint build all: lint build
@@ -19,6 +19,17 @@ audit:
test: test:
cargo test cargo test
bench:
cargo bench
blog:
@mkdir -p site/blog
@for f in blog/*.md; do \
name=$$(basename "$$f" .md); \
pandoc "$$f" --template=site/blog-template.html -o "site/blog/$$name.html"; \
echo " $$f → site/blog/$$name.html"; \
done
clean: clean:
cargo clean cargo clean

87
bench/README.md Normal file
View File

@@ -0,0 +1,87 @@
# Benchmarks
Numa has two benchmark suites measuring different layers of performance.
## Micro-benchmarks (`benches/`, criterion)
Nanosecond-precision measurement of individual operations on the hot path.
No running server required — these are pure Rust unit-level benchmarks.
```sh
cargo bench # run all
cargo bench --bench hot_path # parse, serialize, cache, clone
cargo bench --bench throughput # pipeline QPS, buffer alloc
```
### What's measured
**hot_path** — individual operations:
| Benchmark | What it measures |
|-----------|-----------------|
| `buffer_parse` | Wire bytes → DnsPacket (typical response with 4 records) |
| `buffer_serialize` | DnsPacket → wire bytes |
| `packet_clone` | Full DnsPacket clone (what cache hit costs) |
| `cache_lookup_hit` | Cache lookup on a single-entry cache |
| `cache_lookup_hit_populated` | Cache lookup with 1000 entries |
| `cache_lookup_miss` | HashMap miss (baseline) |
| `cache_insert` | Insert into cache with packet clone |
| `round_trip_cached` | Full cached path: parse query → cache hit → serialize response |
**throughput** — pipeline capacity:
| Benchmark | What it measures |
|-----------|-----------------|
| `pipeline_throughput/N` | N cached queries end-to-end (parse → lookup → serialize) |
| `buffer_alloc` | BytePacketBuffer 4KB zero-init cost |
### Reading results
Criterion auto-compares against the previous run:
```
round_trip_cached time: [710.5 ns 715.2 ns 720.1 ns]
change: [-2.48% -1.85% -1.21%] (p = 0.00 < 0.05)
Performance has improved.
```
- The three values are [lower bound, estimate, upper bound] of the mean
- `change` shows the delta vs the last saved baseline
- HTML reports with charts: `target/criterion/report/index.html`
To save a named baseline for comparison:
```sh
cargo bench -- --save-baseline before
# ... make changes ...
cargo bench -- --baseline before
```
## End-to-end benchmark (`bench/dns-bench.sh`)
Real-world latency comparison using `dig` against a running Numa instance
and public resolvers. Measures millisecond-level latency including network I/O.
```sh
# Start Numa first (default port 15353 for testing)
bash bench/dns-bench.sh [port] [rounds]
bash bench/dns-bench.sh 15353 20 # default
```
### What's measured
- **Numa (cold)**: cache flushed before each query — measures upstream forwarding
- **Numa (cached)**: queries hit cache — measures local processing
- **System / Google / Cloudflare / Quad9**: public resolver comparison
Results saved to `bench/results.json`.
### When to use which
| Question | Use |
|----------|-----|
| Did my code change make parsing faster? | `cargo bench --bench hot_path` |
| Is the cached path still sub-microsecond? | `cargo bench --bench hot_path` (round_trip_cached) |
| How many queries/sec can we handle? | `cargo bench --bench throughput` |
| Is Numa still competitive with system resolver? | `bench/dns-bench.sh` |
| Did upstream forwarding regress? | `bench/dns-bench.sh` |

186
benches/hot_path.rs Normal file
View File

@@ -0,0 +1,186 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use std::net::Ipv4Addr;
use numa::buffer::BytePacketBuffer;
use numa::cache::DnsCache;
use numa::header::{DnsHeader, ResultCode};
use numa::packet::DnsPacket;
use numa::question::{DnsQuestion, QueryType};
use numa::record::DnsRecord;
/// Builds a representative DNS response for `domain`: one A answer,
/// two NS authority records, and one additional (glue) A record —
/// the shape a typical upstream response has.
fn make_response(domain: &str) -> DnsPacket {
    let mut pkt = DnsPacket::new();
    pkt.header = DnsHeader::new();
    pkt.header.id = 0x1234;
    pkt.header.response = true;
    pkt.header.recursion_desired = true;
    pkt.header.recursion_available = true;
    pkt.header.rescode = ResultCode::NOERROR;

    let question = DnsQuestion::new(domain.to_string(), QueryType::A);
    pkt.questions.push(question);

    let answer = DnsRecord::A {
        domain: domain.to_string(),
        addr: Ipv4Addr::new(93, 184, 216, 34),
        ttl: 300,
    };
    pkt.answers.push(answer);

    // Typical response includes authority + additional records
    for ns in ["ns1", "ns2"] {
        pkt.authorities.push(DnsRecord::NS {
            domain: domain.to_string(),
            host: format!("{ns}.{domain}"),
            ttl: 172800,
        });
    }
    pkt.resources.push(DnsRecord::A {
        domain: format!("ns1.{domain}"),
        addr: Ipv4Addr::new(198, 51, 100, 1),
        ttl: 172800,
    });
    pkt
}
/// Serializes `pkt` into its wire-format byte representation.
fn to_wire(pkt: &DnsPacket) -> Vec<u8> {
    let mut buffer = BytePacketBuffer::new();
    pkt.write(&mut buffer).unwrap();
    Vec::from(buffer.filled())
}
/// Wire bytes → DnsPacket, for a typical multi-record response.
fn bench_buffer_parse(c: &mut Criterion) {
    let wire = to_wire(&make_response("example.com"));
    c.bench_function("buffer_parse", |b| {
        b.iter(|| {
            let mut buffer = BytePacketBuffer::from_bytes(black_box(&wire));
            DnsPacket::from_buffer(&mut buffer).unwrap()
        })
    });
}
/// DnsPacket → wire bytes.
fn bench_buffer_serialize(c: &mut Criterion) {
    let response = make_response("example.com");
    c.bench_function("buffer_serialize", |b| {
        b.iter(|| {
            let mut out = BytePacketBuffer::new();
            black_box(&response).write(&mut out).unwrap();
            black_box(out.pos());
        })
    });
}
/// Full deep clone of a DnsPacket — the per-hit cost the cache pays
/// when it hands back an owned packet.
fn bench_packet_clone(c: &mut Criterion) {
    let template = make_response("example.com");
    c.bench_function("packet_clone", |b| {
        b.iter(|| black_box(&template).clone())
    });
}
/// Cache hit on a cache holding a single entry.
fn bench_cache_lookup_hit(c: &mut Criterion) {
    let mut cache = DnsCache::new(10_000, 60, 86400);
    cache.insert("example.com", QueryType::A, &make_response("example.com"));
    c.bench_function("cache_lookup_hit", |b| {
        b.iter(|| cache.lookup(black_box("example.com"), QueryType::A).unwrap())
    });
}
/// HashMap miss on an empty cache — the baseline floor for any lookup.
///
/// `lookup` takes `&self` (this PR removed the `&mut` requirement), so the
/// binding no longer needs `mut`; the original produced an `unused_mut`
/// warning here.
fn bench_cache_lookup_miss(c: &mut Criterion) {
    let cache = DnsCache::new(10_000, 60, 86400);
    c.bench_function("cache_lookup_miss", |b| {
        b.iter(|| cache.lookup(black_box("nonexistent.com"), QueryType::A))
    });
}
/// Insert into the cache (includes the packet copy the cache makes).
///
/// Domain names are generated up front: the original formatted a fresh
/// `String` inside the timed closure, so the measurement included `format!`
/// allocation on every iteration rather than just `insert`.
fn bench_cache_insert(c: &mut Criterion) {
    let pkt = make_response("example.com");
    // 5000 distinct names — half the 10_000 capacity, so we can clear
    // once per full pass and never trigger capacity eviction mid-measure.
    let domains: Vec<String> = (0..5000)
        .map(|i| format!("bench-{i}.example.com"))
        .collect();
    c.bench_function("cache_insert", |b| {
        let mut cache = DnsCache::new(10_000, 60, 86400);
        let mut i = 0usize;
        b.iter(|| {
            cache.insert(
                black_box(&domains[i % domains.len()]),
                QueryType::A,
                black_box(&pkt),
            );
            i += 1;
            // Reset cache periodically to avoid filling up
            if i % domains.len() == 0 {
                cache.clear();
            }
        })
    });
}
/// Full cached hot path: parse an incoming query, hit the cache, echo the
/// client's transaction id, and serialize the response.
fn bench_round_trip(c: &mut Criterion) {
    let query_wire = {
        let mut q = DnsPacket::new();
        q.header.id = 0xABCD;
        q.header.recursion_desired = true;
        q.questions
            .push(DnsQuestion::new("example.com".to_string(), QueryType::A));
        to_wire(&q)
    };
    let mut cache = DnsCache::new(10_000, 60, 86400);
    cache.insert("example.com", QueryType::A, &make_response("example.com"));
    c.bench_function("round_trip_cached", |b| {
        b.iter(|| {
            // Parse the query off the wire.
            let mut in_buf = BytePacketBuffer::from_bytes(black_box(&query_wire));
            let query = DnsPacket::from_buffer(&mut in_buf).unwrap();
            let question = &query.questions[0];
            // Serve from cache, stamping the client's id onto the response.
            let mut response = cache.lookup(&question.name, question.qtype).unwrap();
            response.header.id = query.header.id;
            // Serialize back to wire format.
            let mut out_buf = BytePacketBuffer::new();
            response.write(&mut out_buf).unwrap();
            black_box(out_buf.pos());
        })
    });
}
/// Cache hit with 1000 entries resident — closer to steady-state load
/// than the single-entry variant.
fn bench_cache_populated_lookup(c: &mut Criterion) {
    let mut cache = DnsCache::new(10_000, 60, 86400);
    for i in 0..1000 {
        let name = format!("domain-{i}.example.com");
        cache.insert(&name, QueryType::A, &make_response(&name));
    }
    c.bench_function("cache_lookup_hit_populated", |b| {
        b.iter(|| {
            cache
                .lookup(black_box("domain-500.example.com"), QueryType::A)
                .unwrap()
        })
    });
}
// Registers every hot-path micro-benchmark with criterion's harness;
// run via `cargo bench --bench hot_path`.
criterion_group!(
    benches,
    bench_buffer_parse,
    bench_buffer_serialize,
    bench_packet_clone,
    bench_cache_lookup_hit,
    bench_cache_lookup_miss,
    bench_cache_insert,
    bench_round_trip,
    bench_cache_populated_lookup,
);
criterion_main!(benches);

94
benches/throughput.rs Normal file
View File

@@ -0,0 +1,94 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use std::net::Ipv4Addr;
use numa::buffer::BytePacketBuffer;
use numa::header::ResultCode;
use numa::packet::DnsPacket;
use numa::question::{DnsQuestion, QueryType};
use numa::record::DnsRecord;
/// Builds the wire-format bytes of a simple recursive A query for `domain`.
fn make_query_wire(domain: &str) -> Vec<u8> {
    let mut query = DnsPacket::new();
    query.header.id = 0xABCD;
    query.header.recursion_desired = true;
    query
        .questions
        .push(DnsQuestion::new(domain.to_string(), QueryType::A));
    let mut wire = BytePacketBuffer::new();
    query.write(&mut wire).unwrap();
    Vec::from(wire.filled())
}
/// Builds a minimal NOERROR response with a single A answer for `domain`.
fn make_response(domain: &str) -> DnsPacket {
    let mut response = DnsPacket::new();
    response.header.id = 0xABCD;
    response.header.response = true;
    response.header.recursion_desired = true;
    response.header.recursion_available = true;
    response.header.rescode = ResultCode::NOERROR;
    response
        .questions
        .push(DnsQuestion::new(domain.to_string(), QueryType::A));
    let answer = DnsRecord::A {
        domain: domain.to_string(),
        addr: Ipv4Addr::new(93, 184, 216, 34),
        ttl: 300,
    };
    response.answers.push(answer);
    response
}
/// Simulates the complete cached query pipeline (sans network I/O):
/// parse → cache lookup → TTL adjust → serialize response
fn simulate_cached_pipeline(query_wire: &[u8], cache: &mut numa::cache::DnsCache) -> usize {
    // Parse the incoming query.
    let mut in_buf = BytePacketBuffer::from_bytes(query_wire);
    let parsed = DnsPacket::from_buffer(&mut in_buf).unwrap();
    let question = &parsed.questions[0];
    // Serve from cache and echo the client's transaction id.
    let mut response = cache.lookup(&question.name, question.qtype).unwrap();
    response.header.id = parsed.header.id;
    // Serialize the response; return bytes written.
    let mut out = BytePacketBuffer::new();
    response.write(&mut out).unwrap();
    out.pos()
}
/// End-to-end cached pipeline throughput for batches of 1/10/100 queries,
/// rotating across 100 distinct pre-cached domains.
fn bench_pipeline_throughput(c: &mut Criterion) {
    let mut cache = numa::cache::DnsCache::new(10_000, 60, 86400);
    // Pre-populate the cache and pre-serialize one query per domain.
    let query_wires: Vec<Vec<u8>> = (0..100)
        .map(|i| {
            let domain = format!("domain-{i}.example.com");
            cache.insert(&domain, QueryType::A, &make_response(&domain));
            make_query_wire(&domain)
        })
        .collect();
    let mut group = c.benchmark_group("pipeline_throughput");
    for count in [1u64, 10, 100] {
        group.throughput(Throughput::Elements(count));
        group.bench_with_input(BenchmarkId::from_parameter(count), &count, |b, &count| {
            let mut next = 0usize;
            b.iter(|| {
                for _ in 0..count {
                    let wire = &query_wires[next % query_wires.len()];
                    simulate_cached_pipeline(wire, &mut cache);
                    next += 1;
                }
            });
        });
    }
    group.finish();
}
/// Measures the overhead of BytePacketBuffer allocation + zero-init.
///
/// The freshly-built buffer itself goes through `black_box`: `buf.pos()`
/// on a new buffer is a constant, so black-boxing only the position (as
/// the original did) leaves the optimizer free to elide the very
/// allocation this benchmark exists to measure.
fn bench_buffer_alloc(c: &mut Criterion) {
    c.bench_function("buffer_alloc", |b| {
        b.iter(|| {
            let buf = criterion::black_box(BytePacketBuffer::new());
            criterion::black_box(buf.pos());
        })
    });
}
// Throughput suite: batch pipeline QPS plus raw buffer-allocation cost;
// run via `cargo bench --bench throughput`.
criterion_group!(benches, bench_pipeline_throughput, bench_buffer_alloc,);
criterion_main!(benches);

View File

@@ -220,7 +220,7 @@ async fn create_overrides(
}) })
.collect::<Result<Vec<_>, (StatusCode, String)>>()?; .collect::<Result<Vec<_>, (StatusCode, String)>>()?;
let mut store = ctx.overrides.lock().unwrap(); let mut store = ctx.overrides.write().unwrap();
let mut responses = Vec::with_capacity(parsed.len()); let mut responses = Vec::with_capacity(parsed.len());
for (domain, target, ttl, duration_secs) in parsed { for (domain, target, ttl, duration_secs) in parsed {
@@ -241,7 +241,7 @@ async fn create_overrides(
} }
async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> { async fn list_overrides(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<OverrideResponse>> {
let store = ctx.overrides.lock().unwrap(); let store = ctx.overrides.read().unwrap();
let entries: Vec<OverrideResponse> = store let entries: Vec<OverrideResponse> = store
.list() .list()
.into_iter() .into_iter()
@@ -254,7 +254,7 @@ async fn get_override(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>, Path(domain): Path<String>,
) -> Result<Json<OverrideResponse>, StatusCode> { ) -> Result<Json<OverrideResponse>, StatusCode> {
let store = ctx.overrides.lock().unwrap(); let store = ctx.overrides.read().unwrap();
let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?; let entry = store.get(&domain).ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(OverrideResponse::from(entry))) Ok(Json(OverrideResponse::from(entry)))
} }
@@ -263,7 +263,7 @@ async fn remove_override(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>, Path(domain): Path<String>,
) -> StatusCode { ) -> StatusCode {
let mut store = ctx.overrides.lock().unwrap(); let mut store = ctx.overrides.write().unwrap();
if store.remove(&domain) { if store.remove(&domain) {
StatusCode::NO_CONTENT StatusCode::NO_CONTENT
} else { } else {
@@ -272,7 +272,7 @@ async fn remove_override(
} }
async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode { async fn clear_overrides(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.overrides.lock().unwrap().clear(); ctx.overrides.write().unwrap().clear();
StatusCode::NO_CONTENT StatusCode::NO_CONTENT
} }
@@ -280,7 +280,7 @@ async fn load_environment(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<EnvironmentRequest>, Json(req): Json<EnvironmentRequest>,
) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> { ) -> Result<(StatusCode, Json<EnvironmentResponse>), (StatusCode, String)> {
let mut store = ctx.overrides.lock().unwrap(); let mut store = ctx.overrides.write().unwrap();
for entry in &req.overrides { for entry in &req.overrides {
let duration = entry.duration_secs.or(req.duration_secs); let duration = entry.duration_secs.or(req.duration_secs);
@@ -307,7 +307,7 @@ async fn diagnose(
// Check overrides // Check overrides
{ {
let store = ctx.overrides.lock().unwrap(); let store = ctx.overrides.read().unwrap();
let entry = store.get(&domain_lower); let entry = store.get(&domain_lower);
steps.push(DiagnoseStep { steps.push(DiagnoseStep {
source: "override".to_string(), source: "override".to_string(),
@@ -319,7 +319,7 @@ async fn diagnose(
// Check blocklist // Check blocklist
{ {
let bl = ctx.blocklist.lock().unwrap(); let bl = ctx.blocklist.read().unwrap();
let blocked = bl.is_blocked(&domain_lower); let blocked = bl.is_blocked(&domain_lower);
steps.push(DiagnoseStep { steps.push(DiagnoseStep {
source: "blocklist".to_string(), source: "blocklist".to_string(),
@@ -345,7 +345,7 @@ async fn diagnose(
// Check cache // Check cache
{ {
let mut cache = ctx.cache.lock().unwrap(); let cache = ctx.cache.read().unwrap();
let cached = cache.lookup(&domain_lower, qtype); let cached = cache.lookup(&domain_lower, qtype);
steps.push(DiagnoseStep { steps.push(DiagnoseStep {
source: "cache".to_string(), source: "cache".to_string(),
@@ -443,11 +443,11 @@ async fn query_log(
async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> { async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
let snap = ctx.stats.lock().unwrap().snapshot(); let snap = ctx.stats.lock().unwrap().snapshot();
let (cache_len, cache_max) = { let (cache_len, cache_max) = {
let cache = ctx.cache.lock().unwrap(); let cache = ctx.cache.read().unwrap();
(cache.len(), cache.max_entries()) (cache.len(), cache.max_entries())
}; };
let override_count = ctx.overrides.lock().unwrap().active_count(); let override_count = ctx.overrides.read().unwrap().active_count();
let bl_stats = ctx.blocklist.lock().unwrap().stats(); let bl_stats = ctx.blocklist.read().unwrap().stats();
let upstream = ctx.upstream.lock().unwrap().to_string(); let upstream = ctx.upstream.lock().unwrap().to_string();
@@ -486,7 +486,7 @@ async fn stats(State(ctx): State<Arc<ServerCtx>>) -> Json<StatsResponse> {
} }
async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> { async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryResponse>> {
let cache = ctx.cache.lock().unwrap(); let cache = ctx.cache.read().unwrap();
let entries: Vec<CacheEntryResponse> = cache let entries: Vec<CacheEntryResponse> = cache
.list() .list()
.into_iter() .into_iter()
@@ -500,7 +500,7 @@ async fn list_cache(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<CacheEntryRes
} }
async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode { async fn flush_cache(State(ctx): State<Arc<ServerCtx>>) -> StatusCode {
ctx.cache.lock().unwrap().clear(); ctx.cache.write().unwrap().clear();
StatusCode::NO_CONTENT StatusCode::NO_CONTENT
} }
@@ -508,7 +508,7 @@ async fn flush_cache_domain(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>, Path(domain): Path<String>,
) -> StatusCode { ) -> StatusCode {
ctx.cache.lock().unwrap().remove(&domain); ctx.cache.write().unwrap().remove(&domain);
StatusCode::NO_CONTENT StatusCode::NO_CONTENT
} }
@@ -519,7 +519,7 @@ async fn health() -> Json<serde_json::Value> {
// --- Blocking handlers --- // --- Blocking handlers ---
async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> { async fn blocking_stats(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
let stats = ctx.blocklist.lock().unwrap().stats(); let stats = ctx.blocklist.read().unwrap().stats();
Json(serde_json::json!({ Json(serde_json::json!({
"enabled": stats.enabled, "enabled": stats.enabled,
"paused": stats.paused, "paused": stats.paused,
@@ -539,7 +539,7 @@ async fn blocking_toggle(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingToggleRequest>, Json(req): Json<BlockingToggleRequest>,
) -> Json<serde_json::Value> { ) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().set_enabled(req.enabled); ctx.blocklist.write().unwrap().set_enabled(req.enabled);
Json(serde_json::json!({ "enabled": req.enabled })) Json(serde_json::json!({ "enabled": req.enabled }))
} }
@@ -557,12 +557,12 @@ async fn blocking_pause(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<BlockingPauseRequest>, Json(req): Json<BlockingPauseRequest>,
) -> Json<serde_json::Value> { ) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().pause(req.minutes * 60); ctx.blocklist.write().unwrap().pause(req.minutes * 60);
Json(serde_json::json!({ "paused_minutes": req.minutes })) Json(serde_json::json!({ "paused_minutes": req.minutes }))
} }
async fn blocking_unpause(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> { async fn blocking_unpause(State(ctx): State<Arc<ServerCtx>>) -> Json<serde_json::Value> {
ctx.blocklist.lock().unwrap().unpause(); ctx.blocklist.write().unwrap().unpause();
Json(serde_json::json!({ "paused": false })) Json(serde_json::json!({ "paused": false }))
} }
@@ -570,12 +570,12 @@ async fn blocking_check(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>, Path(domain): Path<String>,
) -> Json<crate::blocklist::BlockCheckResult> { ) -> Json<crate::blocklist::BlockCheckResult> {
let result = ctx.blocklist.lock().unwrap().check(&domain); let result = ctx.blocklist.read().unwrap().check(&domain);
Json(result) Json(result)
} }
async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> { async fn blocking_allowlist(State(ctx): State<Arc<ServerCtx>>) -> Json<Vec<String>> {
let list = ctx.blocklist.lock().unwrap().allowlist(); let list = ctx.blocklist.read().unwrap().allowlist();
Json(list) Json(list)
} }
@@ -588,7 +588,7 @@ async fn blocking_allowlist_add(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Json(req): Json<AllowlistRequest>, Json(req): Json<AllowlistRequest>,
) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
ctx.blocklist.lock().unwrap().add_to_allowlist(&req.domain); ctx.blocklist.write().unwrap().add_to_allowlist(&req.domain);
( (
StatusCode::CREATED, StatusCode::CREATED,
Json(serde_json::json!({ "allowed": req.domain })), Json(serde_json::json!({ "allowed": req.domain })),
@@ -599,7 +599,12 @@ async fn blocking_allowlist_remove(
State(ctx): State<Arc<ServerCtx>>, State(ctx): State<Arc<ServerCtx>>,
Path(domain): Path<String>, Path(domain): Path<String>,
) -> StatusCode { ) -> StatusCode {
if ctx.blocklist.lock().unwrap().remove_from_allowlist(&domain) { if ctx
.blocklist
.write()
.unwrap()
.remove_from_allowlist(&domain)
{
StatusCode::NO_CONTENT StatusCode::NO_CONTENT
} else { } else {
StatusCode::NOT_FOUND StatusCode::NOT_FOUND

View File

@@ -19,7 +19,6 @@ pub struct DnsCache {
max_entries: usize, max_entries: usize,
min_ttl: u32, min_ttl: u32,
max_ttl: u32, max_ttl: u32,
query_count: u64,
} }
impl DnsCache { impl DnsCache {
@@ -30,29 +29,16 @@ impl DnsCache {
max_entries, max_entries,
min_ttl, min_ttl,
max_ttl, max_ttl,
query_count: 0,
} }
} }
pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option<DnsPacket> { /// Read-only lookup — expired entries are left in place (cleaned up on insert).
self.query_count += 1; pub fn lookup(&self, domain: &str, qtype: QueryType) -> Option<DnsPacket> {
if self.query_count.is_multiple_of(1000) {
self.evict_expired();
}
let type_map = self.entries.get(domain)?; let type_map = self.entries.get(domain)?;
let entry = type_map.get(&qtype)?; let entry = type_map.get(&qtype)?;
let elapsed = entry.inserted_at.elapsed(); let elapsed = entry.inserted_at.elapsed();
if elapsed >= entry.ttl { if elapsed >= entry.ttl {
// Expired: remove this entry
let type_map = self.entries.get_mut(domain).unwrap();
type_map.remove(&qtype);
self.entry_count -= 1;
if type_map.is_empty() {
self.entries.remove(domain);
}
return None; return None;
} }

View File

@@ -1,6 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Mutex; use std::sync::{Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime}; use std::time::{Duration, Instant, SystemTime};
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
@@ -27,10 +27,10 @@ use crate::system_dns::ForwardingRule;
pub struct ServerCtx { pub struct ServerCtx {
pub socket: UdpSocket, pub socket: UdpSocket,
pub zone_map: ZoneMap, pub zone_map: ZoneMap,
pub cache: Mutex<DnsCache>, pub cache: RwLock<DnsCache>,
pub stats: Mutex<ServerStats>, pub stats: Mutex<ServerStats>,
pub overrides: Mutex<OverrideStore>, pub overrides: RwLock<OverrideStore>,
pub blocklist: Mutex<BlocklistStore>, pub blocklist: RwLock<BlocklistStore>,
pub query_log: Mutex<QueryLog>, pub query_log: Mutex<QueryLog>,
pub services: Mutex<ServiceStore>, pub services: Mutex<ServiceStore>,
pub lan_peers: Mutex<PeerStore>, pub lan_peers: Mutex<PeerStore>,
@@ -73,7 +73,7 @@ pub async fn handle_query(
// Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream // Pipeline: overrides -> .tld interception -> blocklist -> local zones -> cache -> upstream
// Each lock is scoped to avoid holding MutexGuard across await points. // Each guard (Mutex or RwLock) is scoped so it is never held across an await point.
let (response, path) = { let (response, path) = {
let override_record = ctx.overrides.lock().unwrap().lookup(&qname); let override_record = ctx.overrides.read().unwrap().lookup(&qname);
if let Some(record) = override_record { if let Some(record) = override_record {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
resp.answers.push(record); resp.answers.push(record);
@@ -116,7 +116,7 @@ pub async fn handle_query(
}), }),
} }
(resp, QueryPath::Local) (resp, QueryPath::Local)
} else if ctx.blocklist.lock().unwrap().is_blocked(&qname) { } else if ctx.blocklist.read().unwrap().is_blocked(&qname) {
let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR);
match qtype { match qtype {
QueryType::AAAA => resp.answers.push(DnsRecord::AAAA { QueryType::AAAA => resp.answers.push(DnsRecord::AAAA {
@@ -136,7 +136,7 @@ pub async fn handle_query(
resp.answers = records.clone(); resp.answers = records.clone();
(resp, QueryPath::Local) (resp, QueryPath::Local)
} else { } else {
let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype); let cached = ctx.cache.read().unwrap().lookup(&qname, qtype);
if let Some(cached) = cached { if let Some(cached) = cached {
let mut resp = cached; let mut resp = cached;
resp.header.id = query.header.id; resp.header.id = query.header.id;
@@ -149,7 +149,7 @@ pub async fn handle_query(
}; };
match forward_query(&query, &upstream, ctx.timeout).await { match forward_query(&query, &upstream, ctx.timeout).await {
Ok(resp) => { Ok(resp) => {
ctx.cache.lock().unwrap().insert(&qname, qtype, &resp); ctx.cache.write().unwrap().insert(&qname, qtype, &resp);
(resp, QueryPath::Forwarded) (resp, QueryPath::Forwarded)
} }
Err(e) => { Err(e) => {

View File

@@ -1,5 +1,5 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration; use std::time::Duration;
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
@@ -170,14 +170,14 @@ async fn main() -> numa::Result<()> {
let ctx = Arc::new(ServerCtx { let ctx = Arc::new(ServerCtx {
socket: UdpSocket::bind(&config.server.bind_addr).await?, socket: UdpSocket::bind(&config.server.bind_addr).await?,
zone_map: build_zone_map(&config.zones)?, zone_map: build_zone_map(&config.zones)?,
cache: Mutex::new(DnsCache::new( cache: RwLock::new(DnsCache::new(
config.cache.max_entries, config.cache.max_entries,
config.cache.min_ttl, config.cache.min_ttl,
config.cache.max_ttl, config.cache.max_ttl,
)), )),
stats: Mutex::new(ServerStats::new()), stats: Mutex::new(ServerStats::new()),
overrides: Mutex::new(OverrideStore::new()), overrides: RwLock::new(OverrideStore::new()),
blocklist: Mutex::new(blocklist), blocklist: RwLock::new(blocklist),
query_log: Mutex::new(QueryLog::new(1000)), query_log: Mutex::new(QueryLog::new(1000)),
services: Mutex::new(service_store), services: Mutex::new(service_store),
lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)), lan_peers: Mutex::new(numa::lan::PeerStore::new(config.lan.peer_timeout_secs)),
@@ -541,7 +541,7 @@ async fn load_blocklists(ctx: &ServerCtx, lists: &[String]) {
// Swap under lock — sub-microsecond // Swap under lock — sub-microsecond
ctx.blocklist ctx.blocklist
.lock() .write()
.unwrap() .unwrap()
.swap_domains(all_domains, sources); .swap_domains(all_domains, sources);
info!( info!(

View File

@@ -64,6 +64,9 @@ impl OverrideStore {
ttl: u32, ttl: u32,
duration_secs: Option<u64>, duration_secs: Option<u64>,
) -> Result<QueryType> { ) -> Result<QueryType> {
// Clean up expired entries on write
self.entries.retain(|_, e| !e.is_expired());
let domain_lower = domain.to_lowercase(); let domain_lower = domain.to_lowercase();
let (qtype, record) = parse_target(&domain_lower, target, ttl)?; let (qtype, record) = parse_target(&domain_lower, target, ttl)?;
@@ -84,10 +87,10 @@ impl OverrideStore {
} }
/// Hot path: assumes `domain` is already lowercased (the parser does this). /// Hot path: assumes `domain` is already lowercased (the parser does this).
pub fn lookup(&mut self, domain: &str) -> Option<DnsRecord> { /// Read-only — expired entries are left in place (cleaned up on write operations).
pub fn lookup(&self, domain: &str) -> Option<DnsRecord> {
let entry = self.entries.get(domain)?; let entry = self.entries.get(domain)?;
if entry.is_expired() { if entry.is_expired() {
self.entries.remove(domain);
return None; return None;
} }
Some(entry.record.clone()) Some(entry.record.clone())

View File

@@ -46,7 +46,7 @@ impl DnsPacket {
result.header.read(buffer)?; result.header.read(buffer)?;
for _ in 0..result.header.questions { for _ in 0..result.header.questions {
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); let mut question = DnsQuestion::new(String::with_capacity(64), QueryType::UNKNOWN(0));
question.read(buffer)?; question.read(buffer)?;
result.questions.push(question); result.questions.push(question);
} }
@@ -68,34 +68,36 @@ impl DnsPacket {
} }
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> {
// Filter out UNKNOWN records (e.g. EDNS OPT) that we can't re-serialize // Count known records without allocating filter Vecs
let answers: Vec<_> = self.answers.iter().filter(|r| !r.is_unknown()).collect(); let answer_count = self.answers.iter().filter(|r| !r.is_unknown()).count() as u16;
let authorities: Vec<_> = self let auth_count = self.authorities.iter().filter(|r| !r.is_unknown()).count() as u16;
.authorities let res_count = self.resources.iter().filter(|r| !r.is_unknown()).count() as u16;
.iter()
.filter(|r| !r.is_unknown())
.collect();
let resources: Vec<_> = self.resources.iter().filter(|r| !r.is_unknown()).collect();
let mut header = self.header.clone(); let mut header = self.header.clone();
header.questions = self.questions.len() as u16; header.questions = self.questions.len() as u16;
header.answers = answers.len() as u16; header.answers = answer_count;
header.authoritative_entries = authorities.len() as u16; header.authoritative_entries = auth_count;
header.resource_entries = resources.len() as u16; header.resource_entries = res_count;
header.write(buffer)?; header.write(buffer)?;
for question in &self.questions { for question in &self.questions {
question.write(buffer)?; question.write(buffer)?;
} }
for rec in answers { for rec in &self.answers {
rec.write(buffer)?; if !rec.is_unknown() {
rec.write(buffer)?;
}
} }
for rec in authorities { for rec in &self.authorities {
rec.write(buffer)?; if !rec.is_unknown() {
rec.write(buffer)?;
}
} }
for rec in resources { for rec in &self.resources {
rec.write(buffer)?; if !rec.is_unknown() {
rec.write(buffer)?;
}
} }
Ok(()) Ok(())

View File

@@ -70,7 +70,7 @@ impl DnsRecord {
} }
pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> { pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> {
let mut domain = String::new(); let mut domain = String::with_capacity(64);
buffer.read_qname(&mut domain)?; buffer.read_qname(&mut domain)?;
let qtype_num = buffer.read_u16()?; let qtype_num = buffer.read_u16()?;
@@ -110,7 +110,7 @@ impl DnsRecord {
Ok(DnsRecord::AAAA { domain, addr, ttl }) Ok(DnsRecord::AAAA { domain, addr, ttl })
} }
QueryType::NS => { QueryType::NS => {
let mut ns = String::new(); let mut ns = String::with_capacity(64);
buffer.read_qname(&mut ns)?; buffer.read_qname(&mut ns)?;
Ok(DnsRecord::NS { Ok(DnsRecord::NS {
@@ -120,7 +120,7 @@ impl DnsRecord {
}) })
} }
QueryType::CNAME => { QueryType::CNAME => {
let mut cname = String::new(); let mut cname = String::with_capacity(64);
buffer.read_qname(&mut cname)?; buffer.read_qname(&mut cname)?;
Ok(DnsRecord::CNAME { Ok(DnsRecord::CNAME {
@@ -131,7 +131,7 @@ impl DnsRecord {
} }
QueryType::MX => { QueryType::MX => {
let priority = buffer.read_u16()?; let priority = buffer.read_u16()?;
let mut mx = String::new(); let mut mx = String::with_capacity(64);
buffer.read_qname(&mut mx)?; buffer.read_qname(&mut mx)?;
Ok(DnsRecord::MX { Ok(DnsRecord::MX {