From 7327a96e82952592deedd7074474fd834143712e Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Tue, 10 Mar 2026 04:50:16 +0200 Subject: [PATCH 1/4] refactor to async tokio with modular architecture - Replace synchronous std::net::UdpSocket with tokio async runtime - Spawn concurrent task per incoming DNS query via tokio::spawn - Extract monolithic main.rs into modules: buffer, header, question, record, packet, config, cache, forward, stats - Share state across tasks via Arc with scoped Mutex locks - Add TOML config loading, TTL-aware cache, structured logging, stats - Add CLAUDE.md, README, dns_fun.toml config, and design docs Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 59 ++++ Cargo.lock | 425 ++++++++++++++++++++++ Cargo.toml | 5 + README.md | 121 +++++++ dns_fun.toml | 25 ++ src/buffer.rs | 181 ++++++++++ src/cache.rs | 95 +++++ src/config.rs | 169 +++++++++ src/forward.rs | 27 ++ src/header.rs | 133 +++++++ src/lib.rs | 12 + src/main.rs | 917 +++++++----------------------------------------- src/packet.rs | 105 ++++++ src/question.rs | 64 ++++ src/record.rs | 249 +++++++++++++ src/stats.rs | 79 +++++ 16 files changed, 1879 insertions(+), 787 deletions(-) create mode 100644 CLAUDE.md create mode 100644 Cargo.lock create mode 100644 dns_fun.toml create mode 100644 src/buffer.rs create mode 100644 src/cache.rs create mode 100644 src/config.rs create mode 100644 src/forward.rs create mode 100644 src/header.rs create mode 100644 src/lib.rs create mode 100644 src/packet.rs create mode 100644 src/question.rs create mode 100644 src/record.rs create mode 100644 src/stats.rs diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..6079f31 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,59 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +A DNS forwarding/caching proxy in Rust. Serves local zone records from TOML config, caches upstream responses with TTL-based expiration, forwards unknown queries to an upstream resolver, and logs all queries with structured output. + +## Build & Run + +```bash +cargo build # compile +sudo cargo run # run with default config (dns_fun.toml) +sudo cargo run -- path/to/config # run with custom config path +RUST_LOG=debug sudo cargo run # verbose logging +``` + +Test with: `dig @127.0.0.1 google.com` + +No tests or linter configured. + +## Architecture + +``` +src/ + lib.rs # module declarations, Error/Result type aliases + main.rs # startup, config load, UDP listen loop, request pipeline + buffer.rs # BytePacketBuffer — 512-byte DNS wire format read/write + header.rs # DnsHeader, ResultCode — 12-byte header bitfield parsing + question.rs # DnsQuestion, QueryType — query section (A, NS, CNAME, MX, AAAA) + record.rs # DnsRecord — resource record variants with read/write + packet.rs # DnsPacket — top-level: header + questions + answers + authorities + resources + config.rs # Config loading from TOML, zone map builder + cache.rs # DnsCache — TTL-aware cache with lazy eviction + forward.rs # forward_query() — sends query to upstream, build_servfail() — error response + stats.rs # ServerStats — query counters and periodic summary +``` + +## Request Pipeline + +``` +Query → Parse → Log → Local Zones → Cache → Upstream Forward (+ cache result) → Log → Respond +``` + +## Config + +`dns_fun.toml` at project root. Sections: `[server]`, `[upstream]`, `[cache]`, `[[zones]]`. Falls back to sensible defaults if file is missing. + +## Logging + +Controlled via `RUST_LOG` env var. Default level: `info` (one structured line per query). `debug` adds response details. Stats summary every 1000 queries. + +## Key Details + +- Rust 2018 edition, deps: `serde`, `toml`, `log`, `env_logger` +- DNS packet size limited to 512 bytes (standard UDP DNS) +- `BytePacketBuffer::read_qname` handles label compression (pointer jumps) +- `type Error = Box` / `type Result` aliased in `lib.rs` +- Cache: TTL clamped between `min_ttl` and `max_ttl`, lazy eviction every 1000 queries diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..208c6e5 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,425 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "dns_fun" +version = "0.1.0" +dependencies = [ + "env_logger", + "log", + "serde", + "tokio", + "toml", +] + +[[package]] +name = "env_filter" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml index f02037f..5e7a2d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,8 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } +toml = "0.8" +serde = { version = "1", features = ["derive"] } +log = "0.4" +env_logger = "0.11" diff --git a/README.md b/README.md index 590d19d..26234fe 100644 --- a/README.md +++ b/README.md @@ -1 +1,122 @@ # dns_fun + +A DNS forwarding/caching proxy written from scratch in Rust. Parses and serializes DNS wire protocol (RFC 1035), serves local zone records from TOML config, caches upstream responses with TTL-aware expiration, and logs every query with structured output. + +No async runtime, no DNS libraries — just `std::net::UdpSocket` and manual packet parsing. + +## Record Types + +A, NS, CNAME, MX, AAAA + +## Usage + +```bash +# Run with default config (dns_fun.toml) +sudo cargo run + +# Run with custom config path +sudo cargo run -- path/to/config.toml + +# Test +dig @127.0.0.1 google.com +dig @127.0.0.1 mysite.local +``` + +Requires root/sudo for binding to port 53. + +## Configuration + +Edit `dns_fun.toml`: + +```toml +[server] +bind_addr = "0.0.0.0:53" + +[upstream] +address = "8.8.8.8" +port = 53 +timeout_ms = 3000 + +[cache] +max_entries = 10000 +min_ttl = 60 # floor: cache at least 60s +max_ttl = 86400 # ceiling: never cache longer than 24h + +[[zones]] +domain = "mysite.local" +record_type = "A" +value = "127.0.0.1" +ttl = 60 + +[[zones]] +domain = "other.local" +record_type = "AAAA" +value = "::1" +ttl = 120 +``` + +All sections are optional — sensible defaults are used if the config file is missing. + +## Request Pipeline + +``` +Query -> Parse -> Local Zones -> Cache -> Upstream Forward -> Respond +``` + +1. **Local zones** — match against records defined in `[[zones]]`, respond immediately +2. **Cache** — return TTL-adjusted cached response if available +3. **Forward** — send query to upstream resolver, cache the response +4. **SERVFAIL** — returned to client on upstream failure + +## Caching + +- TTL derived from minimum TTL across answer records +- Clamped to configured `min_ttl`/`max_ttl` bounds +- TTLs in cached responses decrease over time (adjusted on serve) +- Lazy eviction on capacity overflow + periodic sweep every 1000 queries + +## Logging + +Controlled via `RUST_LOG` environment variable: + +```bash +RUST_LOG=info sudo cargo run # default — one line per query +RUST_LOG=debug sudo cargo run # includes response details +RUST_LOG=warn sudo cargo run # errors only +``` + +Log output: + +``` +2026-03-10T14:23:01.123Z INFO 192.168.1.5:41234 | A google.com | FORWARD | NOERROR | 12ms +2026-03-10T14:23:01.456Z INFO 192.168.1.5:41235 | A mysite.local | LOCAL | NOERROR | 0ms +2026-03-10T14:23:02.789Z INFO 192.168.1.5:41236 | A google.com | CACHED | NOERROR | 0ms +``` + +Stats summary (total, forwarded, cached, local, blocked, errors) logged every 1000 queries. + +## Project Structure + +``` +src/ + main.rs # startup, config load, UDP listen loop, request pipeline + lib.rs # module declarations, Error/Result type aliases + buffer.rs # BytePacketBuffer — 512-byte DNS wire format read/write + header.rs # DnsHeader, ResultCode + question.rs # DnsQuestion, QueryType + record.rs # DnsRecord (A, NS, CNAME, MX, AAAA, UNKNOWN) + packet.rs # DnsPacket — full DNS message parse/serialize + config.rs # TOML config loading, zone map builder + cache.rs # TTL-aware DNS response cache with lazy eviction + forward.rs # upstream forwarding, SERVFAIL builder + stats.rs # query counters and periodic summary +``` + +## Dependencies + +```toml +toml = "0.8" +serde = { version = "1", features = ["derive"] } +log = "0.4" +env_logger = "0.11" +``` diff --git a/dns_fun.toml b/dns_fun.toml new file mode 100644 index 0000000..68355e0 --- /dev/null +++ b/dns_fun.toml @@ -0,0 +1,25 @@ +[server] +bind_addr = "0.0.0.0:53" + +[upstream] +address = "8.8.8.8" +port = 53 +timeout_ms = 3000 + +[cache] +max_entries = 10000 +min_ttl = 60 +max_ttl = 86400 + +# Example zone records: +# [[zones]] +# domain = "dimescu.ro" +# record_type = "A" +# value = "3.120.139.105" +# ttl = 30 + +# [[zones]] +# domain = "test.local" +# record_type = "A" +# value = "127.0.0.1" +# ttl = 60 diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..5c82f0e --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,181 @@ +use crate::{Result}; + +pub struct BytePacketBuffer { + pub buf: [u8; 512], + pub pos: usize, +} + +impl BytePacketBuffer { + pub fn new() -> BytePacketBuffer { + BytePacketBuffer { + buf: [0; 512], + pos: 0, + } + } + + pub fn pos(&self) -> usize { + self.pos + } + + pub fn filled(&self) -> &[u8] { + &self.buf[..self.pos] + } + + pub fn step(&mut self, steps: usize) -> Result<()> { + self.pos += steps; + Ok(()) + } + + pub fn seek(&mut self, pos: usize) -> Result<()> { + self.pos = pos; + Ok(()) + } + + pub fn read(&mut self) -> Result { + if self.pos >= 512 { + return Err("End of buffer".into()); + } + let res = self.buf[self.pos]; + self.pos += 1; + Ok(res) + } + + pub fn get(&self, pos: usize) -> Result { + if pos >= 512 { + return Err("End of buffer".into()); + } + Ok(self.buf[pos]) + } + + pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8]> { + if start + len > 512 { + return Err("End of buffer".into()); + } + Ok(&self.buf[start..start + len]) + } + + pub fn read_u16(&mut self) -> Result { + let res = ((self.read()? as u16) << 8) | (self.read()? as u16); + Ok(res) + } + + pub fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); + Ok(res) + } + + /// Read a qname, handling label compression (pointer jumps). + /// Converts wire format like [3]www[6]google[3]com[0] into "www.google.com". + pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> { + let mut pos = self.pos(); + let mut jumped = false; + let max_jumps = 5; + let mut jumps_performed = 0; + let mut delim = ""; + + loop { + if jumps_performed > max_jumps { + return Err(format!("Limit of {} jumps exceeded", max_jumps).into()); + } + + let len = self.get(pos)?; + + if (len & 0xC0) == 0xC0 { + if !jumped { + self.seek(pos + 2)?; + } + + let b2 = self.get(pos + 1)? as u16; + let offset = (((len as u16) ^ 0xC0) << 8) | b2; + pos = offset as usize; + + jumped = true; + jumps_performed += 1; + continue; + } else { + pos += 1; + + if len == 0 { + break; + } + + outstr.push_str(delim); + + let str_buffer = self.get_range(pos, len as usize)?; + for &b in str_buffer { + outstr.push(b.to_ascii_lowercase() as char); + } + + delim = "."; + pos += len as usize; + } + } + + if !jumped { + self.seek(pos)?; + } + + Ok(()) + } + + pub fn write(&mut self, val: u8) -> Result<()> { + if self.pos >= 512 { + return Err("End of buffer".into()); + } + self.buf[self.pos] = val; + self.pos += 1; + Ok(()) + } + + pub fn write_u8(&mut self, val: u8) -> Result<()> { + self.write(val) + } + + pub fn write_u16(&mut self, val: u16) -> Result<()> { + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; + Ok(()) + } + + pub fn write_u32(&mut self, val: u32) -> Result<()> { + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; + Ok(()) + } + + pub fn write_qname(&mut self, qname: &str) -> Result<()> { + for label in qname.split('.') { + let len = label.len(); + if len > 0x3f { + return Err("Single label exceeds 63 characters of length".into()); + } + + self.write_u8(len as u8)?; + for b in label.as_bytes() { + self.write_u8(*b)?; + } + } + + self.write_u8(0)?; + Ok(()) + } + + pub fn set(&mut self, pos: usize, val: u8) -> Result<()> { + if pos >= 512 { + return Err("End of buffer".into()); + } + self.buf[pos] = val; + Ok(()) + } + + pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; + Ok(()) + } +} diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..6dd2e45 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use crate::packet::DnsPacket; +use crate::question::QueryType; +use crate::record::DnsRecord; + +struct CacheEntry { + packet: DnsPacket, + inserted_at: Instant, + ttl: Duration, +} + +pub struct DnsCache { + entries: HashMap<(String, QueryType), CacheEntry>, + max_entries: usize, + min_ttl: u32, + max_ttl: u32, + query_count: u64, +} + +impl DnsCache { + pub fn new(max_entries: usize, min_ttl: u32, max_ttl: u32) -> Self { + DnsCache { + entries: HashMap::new(), + max_entries, + min_ttl, + max_ttl, + query_count: 0, + } + } + + pub fn lookup(&mut self, domain: &str, qtype: QueryType) -> Option { + self.query_count += 1; + + // Periodic eviction every 1000 queries + if self.query_count % 1000 == 0 { + self.evict_expired(); + } + + let key = (domain.to_string(), qtype); + let entry = self.entries.get(&key)?; + + let elapsed = entry.inserted_at.elapsed(); + if elapsed >= entry.ttl { + self.entries.remove(&key); + return None; + } + + let remaining_secs = (entry.ttl - elapsed).as_secs() as u32; + let remaining = remaining_secs.max(1); + + let mut packet = entry.packet.clone(); + adjust_ttls(&mut packet.answers, remaining); + adjust_ttls(&mut packet.authorities, remaining); + adjust_ttls(&mut packet.resources, remaining); + + Some(packet) + } + + pub fn insert(&mut self, domain: &str, qtype: QueryType, packet: &DnsPacket) { + if self.entries.len() >= self.max_entries { + self.evict_expired(); + // If still full after eviction, skip insertion + if self.entries.len() >= self.max_entries { + return; + } + } + + let min_ttl = extract_min_ttl(&packet.answers) + .unwrap_or(self.min_ttl) + .clamp(self.min_ttl, self.max_ttl); + + let key = (domain.to_string(), qtype); + self.entries.insert(key, CacheEntry { + packet: packet.clone(), + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + }); + } + + fn evict_expired(&mut self) { + self.entries.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl); + } +} + +fn extract_min_ttl(records: &[DnsRecord]) -> Option { + records.iter().map(|r| r.ttl()).min() +} + +fn adjust_ttls(records: &mut [DnsRecord], new_ttl: u32) { + for record in records.iter_mut() { + record.set_ttl(new_ttl); + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..3d13f23 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,169 @@ +use std::collections::HashMap; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; +use std::path::Path; + +use serde::Deserialize; + +use crate::question::QueryType; +use crate::record::DnsRecord; +use crate::Result; + +#[derive(Deserialize)] +pub struct Config { + #[serde(default)] + pub server: ServerConfig, + #[serde(default)] + pub upstream: UpstreamConfig, + #[serde(default)] + pub cache: CacheConfig, + #[serde(default)] + pub zones: Vec, +} + +#[derive(Deserialize)] +pub struct ServerConfig { + #[serde(default = "default_bind_addr")] + pub bind_addr: String, +} + +impl Default for ServerConfig { + fn default() -> Self { + ServerConfig { + bind_addr: default_bind_addr(), + } + } +} + +fn default_bind_addr() -> String { + "0.0.0.0:53".to_string() +} + +#[derive(Deserialize)] +pub struct UpstreamConfig { + #[serde(default = "default_upstream_addr")] + pub address: String, + #[serde(default = "default_upstream_port")] + pub port: u16, + #[serde(default = "default_timeout_ms")] + pub timeout_ms: u64, +} + +impl Default for UpstreamConfig { + fn default() -> Self { + UpstreamConfig { + address: default_upstream_addr(), + port: default_upstream_port(), + timeout_ms: default_timeout_ms(), + } + } +} + +fn default_upstream_addr() -> String { + "8.8.8.8".to_string() +} +fn default_upstream_port() -> u16 { + 53 +} +fn default_timeout_ms() -> u64 { + 3000 +} + +#[derive(Deserialize)] +pub struct CacheConfig { + #[serde(default = "default_max_entries")] + pub max_entries: usize, + #[serde(default = "default_min_ttl")] + pub min_ttl: u32, + #[serde(default = "default_max_ttl")] + pub max_ttl: u32, +} + +impl Default for CacheConfig { + fn default() -> Self { + CacheConfig { + max_entries: default_max_entries(), + min_ttl: default_min_ttl(), + max_ttl: default_max_ttl(), + } + } +} + +fn default_max_entries() -> usize { + 10000 +} +fn default_min_ttl() -> u32 { + 60 +} +fn default_max_ttl() -> u32 { + 86400 +} + +#[derive(Deserialize)] +pub struct ZoneRecord { + pub domain: String, + pub record_type: String, + pub value: String, + #[serde(default = "default_zone_ttl")] + pub ttl: u32, +} + +fn default_zone_ttl() -> u32 { + 300 +} + +pub fn load_config(path: &str) -> Result { + if !Path::new(path).exists() { + return Ok(Config { + server: ServerConfig::default(), + upstream: UpstreamConfig::default(), + cache: CacheConfig::default(), + zones: Vec::new(), + }); + } + let contents = std::fs::read_to_string(path)?; + let config: Config = toml::from_str(&contents)?; + Ok(config) +} + +pub fn build_zone_map(zones: &[ZoneRecord]) -> Result>> { + let mut map: HashMap<(String, QueryType), Vec> = HashMap::new(); + + for zone in zones { + let domain = zone.domain.to_lowercase(); + let (qtype, record) = match zone.record_type.to_uppercase().as_str() { + "A" => { + let addr: Ipv4Addr = zone.value.parse() + .map_err(|e| format!("invalid A record value '{}': {}", zone.value, e))?; + (QueryType::A, DnsRecord::A { domain: domain.clone(), addr, ttl: zone.ttl }) + } + "AAAA" => { + let addr: Ipv6Addr = zone.value.parse() + .map_err(|e| format!("invalid AAAA record value '{}': {}", zone.value, e))?; + (QueryType::AAAA, DnsRecord::AAAA { domain: domain.clone(), addr, ttl: zone.ttl }) + } + "CNAME" => { + (QueryType::CNAME, DnsRecord::CNAME { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) + } + "NS" => { + (QueryType::NS, DnsRecord::NS { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) + } + "MX" => { + let parts: Vec<&str> = zone.value.splitn(2, ' ').collect(); + if parts.len() != 2 { + return Err(format!("MX value must be 'priority host', got '{}'", zone.value).into()); + } + let priority: u16 = parts[0].parse() + .map_err(|e| format!("invalid MX priority '{}': {}", parts[0], e))?; + (QueryType::MX, DnsRecord::MX { domain: domain.clone(), priority, host: parts[1].to_string(), ttl: zone.ttl }) + } + other => { + return Err(format!("unsupported record type '{}'", other).into()); + } + }; + + map.entry((domain, qtype)).or_default().push(record); + } + + Ok(map) +} diff --git a/src/forward.rs b/src/forward.rs new file mode 100644 index 0000000..63de394 --- /dev/null +++ b/src/forward.rs @@ -0,0 +1,27 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use tokio::net::UdpSocket; +use tokio::time::timeout; + +use crate::buffer::BytePacketBuffer; +use crate::packet::DnsPacket; +use crate::Result; + +pub async fn forward_query( + query: &DnsPacket, + upstream: SocketAddr, + timeout_duration: Duration, +) -> Result { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + + let mut send_buffer = BytePacketBuffer::new(); + query.write(&mut send_buffer)?; + + socket.send_to(send_buffer.filled(), upstream).await?; + + let mut recv_buffer = BytePacketBuffer::new(); + timeout(timeout_duration, socket.recv_from(&mut recv_buffer.buf)).await??; + + DnsPacket::from_buffer(&mut recv_buffer) +} diff --git a/src/header.rs b/src/header.rs new file mode 100644 index 0000000..2ce42a1 --- /dev/null +++ b/src/header.rs @@ -0,0 +1,133 @@ +use crate::buffer::BytePacketBuffer; +use crate::Result; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResultCode { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, +} + +impl ResultCode { + pub fn from_num(num: u8) -> ResultCode { + match num { + 1 => ResultCode::FORMERR, + 2 => ResultCode::SERVFAIL, + 3 => ResultCode::NXDOMAIN, + 4 => ResultCode::NOTIMP, + 5 => ResultCode::REFUSED, + 0 | _ => ResultCode::NOERROR, + } + } + + pub fn as_str(&self) -> &'static str { + match self { + ResultCode::NOERROR => "NOERROR", + ResultCode::FORMERR => "FORMERR", + ResultCode::SERVFAIL => "SERVFAIL", + ResultCode::NXDOMAIN => "NXDOMAIN", + ResultCode::NOTIMP => "NOTIMP", + ResultCode::REFUSED => "REFUSED", + } + } +} + +#[derive(Clone, Debug)] +pub struct DnsHeader { + pub id: u16, + + pub recursion_desired: bool, + pub truncated_message: bool, + pub authoritative_answer: bool, + pub opcode: u8, + pub response: bool, + + pub rescode: ResultCode, + pub checking_disabled: bool, + pub authed_data: bool, + pub z: bool, + pub recursion_available: bool, + + pub questions: u16, + pub answers: u16, + pub authoritative_entries: u16, + pub resource_entries: u16, +} + +impl DnsHeader { + pub fn new() -> DnsHeader { + DnsHeader { + id: 0, + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } + } + + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { + self.id = buffer.read_u16()?; + + let flags = buffer.read_u16()?; + let a = (flags >> 8) as u8; + let b = (flags & 0xFF) as u8; + self.recursion_desired = (a & (1 << 0)) > 0; + self.truncated_message = (a & (1 << 1)) > 0; + self.authoritative_answer = (a & (1 << 2)) > 0; + self.opcode = (a >> 3) & 0x0F; + self.response = (a & (1 << 7)) > 0; + + self.rescode = ResultCode::from_num(b & 0x0F); + self.checking_disabled = (b & (1 << 4)) > 0; + self.authed_data = (b & (1 << 5)) > 0; + self.z = (b & (1 << 6)) > 0; + self.recursion_available = (b & (1 << 7)) > 0; + + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; + + Ok(()) + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { + buffer.write_u16(self.id)?; + + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; + + buffer.write_u8( + (self.rescode as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; + + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..e43c565 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,12 @@ +pub mod buffer; +pub mod cache; +pub mod config; +pub mod forward; +pub mod header; +pub mod packet; +pub mod question; +pub mod record; +pub mod stats; + +pub type Error = Box; +pub type Result = std::result::Result; diff --git a/src/main.rs b/src/main.rs index 28340a8..e2624c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,801 +1,144 @@ -use std::fs::File; -use std::io::Read; -use std::net::Ipv4Addr; -use std::net::Ipv6Addr; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; -type Error = Box; -type Result = std::result::Result; +use log::{debug, error, info, warn}; +use tokio::net::UdpSocket; -pub struct BytePacketBuffer { - pub buf: [u8; 512], - pub pos: usize, +use dns_fun::buffer::BytePacketBuffer; +use dns_fun::cache::DnsCache; +use dns_fun::config::{build_zone_map, load_config}; +use dns_fun::forward::forward_query; +use dns_fun::header::ResultCode; +use dns_fun::packet::DnsPacket; +use dns_fun::question::QueryType; +use dns_fun::record::DnsRecord; +use dns_fun::stats::{QueryPath, ServerStats}; + +struct ServerCtx { + socket: Arc, + zone_map: HashMap<(String, QueryType), Vec>, + cache: Mutex, + stats: Mutex, + upstream: SocketAddr, + timeout: Duration, } -impl BytePacketBuffer { +#[tokio::main] +async fn main() -> dns_fun::Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) + .format_timestamp_millis() + .init(); + + let config_path = std::env::args().nth(1).unwrap_or_else(|| "dns_fun.toml".to_string()); + let config = load_config(&config_path)?; + + let upstream: SocketAddr = format!("{}:{}", config.upstream.address, config.upstream.port).parse()?; + let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?); + + let ctx = Arc::new(ServerCtx { + socket: Arc::clone(&socket), + zone_map: build_zone_map(&config.zones)?, + cache: Mutex::new(DnsCache::new( + config.cache.max_entries, + config.cache.min_ttl, + config.cache.max_ttl, + )), + stats: Mutex::new(ServerStats::new()), + upstream, + timeout: Duration::from_millis(config.upstream.timeout_ms), + }); + + info!( + "dns_fun starting on {}, upstream {}, {} zone records, cache max {}", + config.server.bind_addr, + upstream, + ctx.zone_map.len(), + config.cache.max_entries, + ); - /// This gives us a fresh buffer for holding the packet contents, and a - /// field for keeping track of where we are. - pub fn new() -> BytePacketBuffer { - BytePacketBuffer { - buf: [0; 512], - pos: 0, - } - } - - /// Current position within buffer - fn pos(&self) -> usize { - self.pos - } - - /// Step the buffer position forward a specific number of steps - fn step(&mut self, steps: usize) -> Result<()> { - self.pos += steps; - - Ok(()) - } - - /// Change the buffer position - fn seek(&mut self, pos: usize) -> Result<()> { - self.pos = pos; - - Ok(()) - } - - /// Read a single byte and move the position one step forward - fn read(&mut self) -> Result { - if self.pos >= 512 { - return Err("End of buffer".into()); - } - let res = self.buf[self.pos]; - self.pos += 1; - - Ok(res) - } - - /// Get a single byte, without changing the buffer position - fn get(&mut self, pos: usize) -> Result { - if pos >= 512 { - return Err("End of buffer".into()); - } - Ok(self.buf[pos]) - } - - /// Get a range of bytes - fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { - if start + len >= 512 { - return Err("End of buffer".into()); - } - Ok(&self.buf[start..start + len as usize]) - } - - /// Read two bytes, stepping two steps forward - fn read_u16(&mut self) -> Result { - let res = ((self.read()? as u16) << 8) | (self.read()? as u16); - - Ok(res) - } - - /// Read four bytes, stepping four steps forward - fn read_u32(&mut self) -> Result { - let res = ((self.read()? as u32) << 24) - | ((self.read()? as u32) << 16) - | ((self.read()? as u32) << 8) - | ((self.read()? as u32) << 0); - - Ok(res) - } - - - /// Read a qname - /// - /// The tricky part: Reading domain names, taking labels into consideration. - /// Will take something like [3]www[6]google[3]com[0] and append - /// www.google.com to outstr. - fn read_qname(&mut self, outstr: &mut String) -> Result<()> { - // Since we might encounter jumps, we'll keep track of our position - // locally as opposed to using the position within the struct. This - // allows us to move the shared position to a point past our current - // qname, while keeping track of our progress on the current qname - // using this variable. - let mut pos = self.pos(); - - // track whether or not we've jumped - let mut jumped = false; - let max_jumps = 5; - let mut jumps_performed = 0; - - // Our delimiter which we append for each label. Since we don't want a - // dot at the beginning of the domain name we'll leave it empty for now - // and set it to "." at the end of the first iteration. - let mut delim = ""; - loop { - // Dns Packets are untrusted data, so we need to be paranoid. Someone - // can craft a packet with a cycle in the jump instructions. This guards - // against such packets. - if jumps_performed > max_jumps { - return Err(format!("Limit of {} jumps exceeded", max_jumps).into()); - } - - // At this point, we're always at the beginning of a label. Recall - // that labels start with a length byte. - let len = self.get(pos)?; - - // If len has the two most significant bit are set, it represents a - // jump to some other offset in the packet: - if (len & 0xC0) == 0xC0 { - // Update the buffer position to a point past the current - // label. We don't need to touch it any further. - if !jumped { - self.seek(pos + 2)?; - } - - // Read another byte, calculate offset and perform the jump by - // updating our local position variable - let b2 = self.get(pos + 1)? as u16; - let offset = (((len as u16) ^ 0xC0) << 8) | b2; - pos = offset as usize; - - // Indicate that a jump was performed. - jumped = true; - jumps_performed += 1; - - continue; - } - // The base scenario, where we're reading a single label and - // appending it to the output: - else { - // Move a single byte forward to move past the length byte. - pos += 1; - - // Domain names are terminated by an empty label of length 0, - // so if the length is zero we're done. - if len == 0 { - break; - } - - // Append the delimiter to our output buffer first. - outstr.push_str(delim); - - // Extract the actual ASCII bytes for this label and append them - // to the output buffer. - let str_buffer = self.get_range(pos, len as usize)?; - outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); - - delim = "."; - - // Move forward the full length of the label. - pos += len as usize; - } - } - - if !jumped { - self.seek(pos)?; - } - - Ok(()) - } - fn write(&mut self, val: u8) -> Result<()> { - if self.pos >= 512 { - return Err("End of buffer".into()); - } - self.buf[self.pos] = val; - self.pos += 1; - Ok(()) - } - - fn write_u8(&mut self, val: u8) -> Result<()> { - self.write(val)?; - - Ok(()) - } - - fn write_u16(&mut self, val: u16) -> Result<()> { - self.write((val >> 8) as u8)?; - self.write((val & 0xFF) as u8)?; - - Ok(()) - } - - fn write_u32(&mut self, val: u32) -> Result<()> { - self.write(((val >> 24) & 0xFF) as u8)?; - self.write(((val >> 16) & 0xFF) as u8)?; - self.write(((val >> 8) & 0xFF) as u8)?; - self.write(((val >> 0) & 0xFF) as u8)?; - - Ok(()) - } - - fn write_qname(&mut self, qname: &str) -> Result<()> { - for label in qname.split('.') { - let len = label.len(); - if len > 0x3f { - return Err("Single label exceeds 63 characters of length".into()); - } - - self.write_u8(len as u8)?; - for b in label.as_bytes() { - self.write_u8(*b)?; - } - } - - self.write_u8(0)?; - - Ok(()) - } - - fn set(&mut self, pos: usize, val: u8) -> Result<()> { - self.buf[pos] = val; - - Ok(()) - } - - fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { - self.set(pos, (val >> 8) as u8)?; - self.set(pos + 1, (val & 0xFF) as u8)?; - - Ok(()) - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ResultCode { - NOERROR = 0, - FORMERR = 1, - SERVFAIL = 2, - NXDOMAIN = 3, - NOTIMP = 4, - REFUSED = 5, -} - -impl ResultCode { - pub fn from_num(num: u8) -> ResultCode { - match num { - 1 => ResultCode::FORMERR, - 2 => ResultCode::SERVFAIL, - 3 => ResultCode::NXDOMAIN, - 4 => ResultCode::NOTIMP, - 5 => ResultCode::REFUSED, - 0 | _ => ResultCode::NOERROR, - } - } -} - -#[derive(Clone, Debug)] -pub struct DnsHeader { - pub id: u16, // 16 bits - - pub recursion_desired: bool, // 1 bit - pub truncated_message: bool, // 1 bit - pub authoritative_answer: bool, // 1 bit - pub opcode: u8, // 4 bits - pub response: bool, // 1 bit - - pub rescode: ResultCode, // 4 bits - pub checking_disabled: bool, // 1 bit - pub authed_data: bool, // 1 bit - pub z: bool, // 1 bit - pub recursion_available: bool, // 1 bit - - pub questions: u16, // 16 bits - pub answers: u16, // 16 bits - pub authoritative_entries: u16, // 16 bits - pub resource_entries: u16, // 16 bits -} - -impl DnsHeader { - pub fn new() -> DnsHeader { - DnsHeader { - id: 0, - - recursion_desired: false, - truncated_message: false, - authoritative_answer: false, - opcode: 0, - response: false, - - rescode: ResultCode::NOERROR, - checking_disabled: false, - authed_data: false, - z: false, - recursion_available: false, - - questions: 0, - answers: 0, - authoritative_entries: 0, - resource_entries: 0, - } - } - - pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { - self.id = buffer.read_u16()?; - - let flags = buffer.read_u16()?; - let a = (flags >> 8) as u8; - let b = (flags & 0xFF) as u8; - self.recursion_desired = (a & (1 << 0)) > 0; - self.truncated_message = (a & (1 << 1)) > 0; - self.authoritative_answer = (a & (1 << 2)) > 0; - self.opcode = (a >> 3) & 0x0F; - self.response = (a & (1 << 7)) > 0; - - self.rescode = ResultCode::from_num(b & 0x0F); - self.checking_disabled = (b & (1 << 4)) > 0; - self.authed_data = (b & (1 << 5)) > 0; - self.z = (b & (1 << 6)) > 0; - self.recursion_available = (b & (1 << 7)) > 0; - - self.questions = buffer.read_u16()?; - self.answers = buffer.read_u16()?; - self.authoritative_entries = buffer.read_u16()?; - self.resource_entries = buffer.read_u16()?; - - // Return the constant header size - Ok(()) - } - - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { - buffer.write_u16(self.id)?; - - buffer.write_u8( - (self.recursion_desired as u8) - | ((self.truncated_message as u8) << 1) - | ((self.authoritative_answer as u8) << 2) - | (self.opcode << 3) - | ((self.response as u8) << 7) as u8, - )?; - - buffer.write_u8( - (self.rescode as u8) - | ((self.checking_disabled as u8) << 4) - | ((self.authed_data as u8) << 5) - | ((self.z as u8) << 6) - | ((self.recursion_available as u8) << 7), - )?; - - buffer.write_u16(self.questions)?; - buffer.write_u16(self.answers)?; - buffer.write_u16(self.authoritative_entries)?; - buffer.write_u16(self.resource_entries)?; - - Ok(()) - } -} - -#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] -pub enum QueryType { - UNKNOWN(u16), - A, // 1 - NS, // 2 - CNAME, // 5 - MX, // 15 - AAAA, // 28 -} - -impl QueryType { - pub fn to_num(&self) -> u16 { - match *self { - QueryType::UNKNOWN(x) => x, - QueryType::A => 1, - QueryType::NS => 2, - QueryType::CNAME => 5, - QueryType::MX => 15, - QueryType::AAAA => 28, - } - } - - pub fn from_num(num: u16) -> QueryType { - match num { - 1 => QueryType::A, - 2 => QueryType::NS, - 5 => QueryType::CNAME, - 15 => QueryType::MX, - 28 => QueryType::AAAA, - _ => QueryType::UNKNOWN(num), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DnsQuestion { - pub name: String, - pub qtype: QueryType, -} - -impl DnsQuestion { - pub fn new(name: String, qtype: QueryType) -> DnsQuestion { - DnsQuestion { - name: name, - qtype: qtype, - } - } - - pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { - buffer.read_qname(&mut self.name)?; - self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype - let _ = buffer.read_u16()?; // class - - Ok(()) - } - - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { - buffer.write_qname(&self.name)?; - - let typenum = self.qtype.to_num(); - buffer.write_u16(typenum)?; - buffer.write_u16(1)?; - - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[allow(dead_code)] -pub enum DnsRecord { - UNKNOWN { - domain: String, - qtype: u16, - data_len: u16, - ttl: u32, - }, // 0 - A { - domain: String, - addr: Ipv4Addr, - ttl: u32, - }, // 1 - NS { - domain: String, - host: String, - ttl: u32, - }, // 2 - CNAME { - domain: String, - host: String, - ttl: u32, - }, // 5 - MX { - domain: String, - priority: u16, - host: String, - ttl: u32, - }, // 15 - AAAA { - domain: String, - addr: Ipv6Addr, - ttl: u32, - }, // 28 -} - -impl DnsRecord { - pub fn read(buffer: &mut BytePacketBuffer) -> Result { - let mut domain = String::new(); - buffer.read_qname(&mut domain)?; - - let qtype_num = buffer.read_u16()?; - let qtype = QueryType::from_num(qtype_num); - let _ = buffer.read_u16()?; - let ttl = buffer.read_u32()?; - let data_len = buffer.read_u16()?; - - match qtype { - QueryType::A => { - let raw_addr = buffer.read_u32()?; - let addr = Ipv4Addr::new( - ((raw_addr >> 24) & 0xFF) as u8, - ((raw_addr >> 16) & 0xFF) as u8, - ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8, - ); - - Ok(DnsRecord::A { - domain: domain, - addr: addr, - ttl: ttl, - }) - } - QueryType::AAAA => { - let raw_addr1 = buffer.read_u32()?; - let raw_addr2 = buffer.read_u32()?; - let raw_addr3 = buffer.read_u32()?; - let raw_addr4 = buffer.read_u32()?; - let addr = Ipv6Addr::new( - ((raw_addr1 >> 16) & 0xFFFF) as u16, - ((raw_addr1 >> 0) & 0xFFFF) as u16, - ((raw_addr2 >> 16) & 0xFFFF) as u16, - ((raw_addr2 >> 0) & 0xFFFF) as u16, - ((raw_addr3 >> 16) & 0xFFFF) as u16, - ((raw_addr3 >> 0) & 0xFFFF) as u16, - ((raw_addr4 >> 16) & 0xFFFF) as u16, - ((raw_addr4 >> 0) & 0xFFFF) as u16, - ); - - Ok(DnsRecord::AAAA { - domain: domain, - addr: addr, - ttl: ttl, - }) - } - QueryType::NS => { - let mut ns = String::new(); - buffer.read_qname(&mut ns)?; - - Ok(DnsRecord::NS { - domain: domain, - host: ns, - ttl: ttl, - }) - } - QueryType::CNAME => { - let mut cname = String::new(); - buffer.read_qname(&mut cname)?; - - Ok(DnsRecord::CNAME { - domain: domain, - host: cname, - ttl: ttl, - }) - } - QueryType::MX => { - let priority = buffer.read_u16()?; - let mut mx = String::new(); - buffer.read_qname(&mut mx)?; - - Ok(DnsRecord::MX { - domain: domain, - priority: priority, - host: mx, - ttl: ttl, - }) - } - QueryType::UNKNOWN(_) => { - buffer.step(data_len as usize)?; - - Ok(DnsRecord::UNKNOWN { - domain: domain, - qtype: qtype_num, - data_len: data_len, - ttl: ttl, - }) - } - } - } - - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { - let start_pos = buffer.pos(); - - match *self { - DnsRecord::A { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(4)?; - - let octets = addr.octets(); - buffer.write_u8(octets[0])?; - buffer.write_u8(octets[1])?; - buffer.write_u8(octets[2])?; - buffer.write_u8(octets[3])?; - } - DnsRecord::NS { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::NS.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::CNAME { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::CNAME.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::MX { - ref domain, - priority, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::MX.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_u16(priority)?; - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::AAAA { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::AAAA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(16)?; - - for octet in &addr.segments() { - buffer.write_u16(*octet)?; - } - } - DnsRecord::UNKNOWN { .. } => { - println!("Skipping record: {:?}", self); - } - } - - Ok(buffer.pos() - start_pos) - } -} - -#[derive(Clone, Debug)] -pub struct DnsPacket { - pub header: DnsHeader, - pub questions: Vec, - pub answers: Vec, - pub authorities: Vec, - pub resources: Vec, -} - -impl DnsPacket { - pub fn new() -> DnsPacket { - DnsPacket { - header: DnsHeader::new(), - questions: Vec::new(), - answers: Vec::new(), - authorities: Vec::new(), - resources: Vec::new(), - } - } - - pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { - let mut result = DnsPacket::new(); - result.header.read(buffer)?; - - for _ in 0..result.header.questions { - let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); - question.read(buffer)?; - result.questions.push(question); - } - - for _ in 0..result.header.answers { - let rec = DnsRecord::read(buffer)?; - result.answers.push(rec); - } - for _ in 0..result.header.authoritative_entries { - let rec = DnsRecord::read(buffer)?; - result.authorities.push(rec); - } - for _ in 0..result.header.resource_entries { - let rec = DnsRecord::read(buffer)?; - result.resources.push(rec); - } - - Ok(result) - } - - pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { - self.header.questions = self.questions.len() as u16; - self.header.answers = self.answers.len() as u16; - self.header.authoritative_entries = self.authorities.len() as u16; - self.header.resource_entries = self.resources.len() as u16; - - self.header.write(buffer)?; - - for question in &self.questions { - question.write(buffer)?; - } - for rec in &self.answers { - rec.write(buffer)?; - } - for rec in &self.authorities { - rec.write(buffer)?; - } - for rec in &self.resources { - rec.write(buffer)?; - } - - Ok(()) - } - - pub fn display(&self) { - println!("{:#?}", self.header); - - for q in &self.questions { - println!("{:#?}", q); - } - for rec in &self.answers { - println!("{:#?}", rec); - } - for rec in &self.authorities { - println!("{:#?}", rec); - } - for rec in &self.resources { - println!("{:#?}", rec); - } - } -} - -/*fn main() -> Result<()> { - let mut f = File::open("response_packet.txt")?; - let mut buffer = BytePacketBuffer::new(); - f.read(&mut buffer.buf)?; - - let packet = DnsPacket::from_buffer(&mut buffer)?; - println!("{:#?}", packet.header); - - for q in packet.questions { - println!("{:#?}", q); - } - for rec in packet.answers { - println!("{:#?}", rec); - } - for rec in packet.authorities { - println!("{:#?}", rec); - } - for rec in packet.resources { - println!("{:#?}", rec); - } - - Ok(()) -}*/ - -use std::net::UdpSocket; - -fn main() -> std::io::Result<()> { - let socket = UdpSocket::bind("0.0.0.0:53")?; loop { let mut buffer = BytePacketBuffer::new(); - let (number_of_bytes, src_addr) = socket.recv_from(&mut buffer.buf)?; - print!("received: {} from {} \n", number_of_bytes, src_addr); + let (_, src_addr) = socket.recv_from(&mut buffer.buf).await?; - let packet = DnsPacket::from_buffer(&mut buffer).unwrap(); - packet.display(); + let ctx = Arc::clone(&ctx); + tokio::spawn(async move { + if let Err(e) = handle_query(buffer, src_addr, &ctx).await { + error!("{} | HANDLER ERROR | {}", src_addr, e); + } + }); + } +} - let mut resp = DnsPacket::new(); - resp.header.id = packet.header.id; - resp.header.authoritative_answer = true; - resp.header.response = true; - resp.questions = packet.questions; - resp.answers.push(DnsRecord::A{domain: "dimescu.ro".to_string(), addr: Ipv4Addr::new(3, 120, 139, 105), ttl: 30}); - resp.display(); +async fn handle_query( + mut buffer: BytePacketBuffer, + src_addr: SocketAddr, + ctx: &ServerCtx, +) -> dns_fun::Result<()> { + let start = Instant::now(); - let mut req_buffer = BytePacketBuffer::new(); - resp.write(&mut req_buffer).unwrap(); + let query = match DnsPacket::from_buffer(&mut buffer) { + Ok(packet) => packet, + Err(e) => { + warn!("{} | PARSE ERROR | {}", src_addr, e); + return Ok(()); + } + }; - socket.send_to(&req_buffer.buf[0..req_buffer.pos], src_addr).unwrap(); + let (qname, qtype) = match query.questions.first() { + Some(q) => (q.name.clone(), q.qtype), + None => return Ok(()), + }; + + // Pipeline: local zones -> cache -> upstream + // Each lock is scoped to avoid holding MutexGuard across await points. + let (response, path) = if let Some(records) = ctx.zone_map.get(&(qname.to_lowercase(), qtype)) { + let mut resp = DnsPacket::response_from(&query, ResultCode::NOERROR); + resp.answers = records.clone(); + (resp, QueryPath::Local) + } else { + let cached = ctx.cache.lock().unwrap().lookup(&qname, qtype); + if let Some(cached) = cached { + let mut resp = cached; + resp.header.id = query.header.id; + (resp, QueryPath::Cached) + } else { + match forward_query(&query, ctx.upstream, ctx.timeout).await { + Ok(resp) => { + ctx.cache.lock().unwrap().insert(&qname, qtype, &resp); + (resp, QueryPath::Forwarded) + } + Err(e) => { + error!("{} | {:?} {} | UPSTREAM ERROR | {}", src_addr, qtype, qname, e); + (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError) + } + } + } + }; + + let elapsed = start.elapsed(); + + info!( + "{} | {:?} {} | {} | {} | {}ms", + src_addr, qtype, qname, path.as_str(), + response.header.rescode.as_str(), elapsed.as_millis(), + ); + + debug!( + "response: {} answers, {} authorities, {} resources", + response.answers.len(), response.authorities.len(), response.resources.len(), + ); + + let mut resp_buffer = BytePacketBuffer::new(); + response.write(&mut resp_buffer)?; + ctx.socket.send_to(resp_buffer.filled(), src_addr).await?; + + // Record stats and log summary every 1000 queries (single lock acquisition) + let mut s = ctx.stats.lock().unwrap(); + let total = s.record(path); + if total % 1000 == 0 { + s.log_summary(); } Ok(()) diff --git a/src/packet.rs b/src/packet.rs new file mode 100644 index 0000000..2098257 --- /dev/null +++ b/src/packet.rs @@ -0,0 +1,105 @@ +use crate::buffer::BytePacketBuffer; +use crate::header::DnsHeader; +use crate::question::{DnsQuestion, QueryType}; +use crate::record::DnsRecord; +use crate::Result; + +#[derive(Clone, Debug)] +pub struct DnsPacket { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub resources: Vec, +} + +impl DnsPacket { + pub fn new() -> DnsPacket { + DnsPacket { + header: DnsHeader::new(), + questions: Vec::new(), + answers: Vec::new(), + authorities: Vec::new(), + resources: Vec::new(), + } + } + + pub fn response_from(query: &DnsPacket, rescode: crate::header::ResultCode) -> DnsPacket { + let mut resp = DnsPacket::new(); + resp.header.id = query.header.id; + resp.header.response = true; + resp.header.recursion_desired = query.header.recursion_desired; + resp.header.recursion_available = true; + resp.header.rescode = rescode; + resp.questions = query.questions.clone(); + resp + } + + pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { + let mut result = DnsPacket::new(); + result.header.read(buffer)?; + + for _ in 0..result.header.questions { + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; + result.questions.push(question); + } + + for _ in 0..result.header.answers { + let rec = DnsRecord::read(buffer)?; + result.answers.push(rec); + } + for _ in 0..result.header.authoritative_entries { + let rec = DnsRecord::read(buffer)?; + result.authorities.push(rec); + } + for _ in 0..result.header.resource_entries { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } + + Ok(result) + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { + let mut header = self.header.clone(); + header.questions = self.questions.len() as u16; + header.answers = self.answers.len() as u16; + header.authoritative_entries = self.authorities.len() as u16; + header.resource_entries = self.resources.len() as u16; + + header.write(buffer)?; + + for question in &self.questions { + question.write(buffer)?; + } + for rec in &self.answers { + rec.write(buffer)?; + } + for rec in &self.authorities { + rec.write(buffer)?; + } + for rec in &self.resources { + rec.write(buffer)?; + } + + Ok(()) + } + + pub fn display(&self) { + println!("{:#?}", self.header); + + for q in &self.questions { + println!("{:#?}", q); + } + for rec in &self.answers { + println!("{:#?}", rec); + } + for rec in &self.authorities { + println!("{:#?}", rec); + } + for rec in &self.resources { + println!("{:#?}", rec); + } + } +} diff --git a/src/question.rs b/src/question.rs new file mode 100644 index 0000000..b142153 --- /dev/null +++ b/src/question.rs @@ -0,0 +1,64 @@ +use crate::buffer::BytePacketBuffer; +use crate::Result; + +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] +pub enum QueryType { + UNKNOWN(u16), + A, // 1 + NS, // 2 + CNAME, // 5 + MX, // 15 + AAAA, // 28 +} + +impl QueryType { + pub fn to_num(&self) -> u16 { + match *self { + QueryType::UNKNOWN(x) => x, + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::MX => 15, + QueryType::AAAA => 28, + } + } + + pub fn from_num(num: u16) -> QueryType { + match num { + 1 => QueryType::A, + 2 => QueryType::NS, + 5 => QueryType::CNAME, + 15 => QueryType::MX, + 28 => QueryType::AAAA, + _ => QueryType::UNKNOWN(num), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsQuestion { + pub name: String, + pub qtype: QueryType, +} + +impl DnsQuestion { + pub fn new(name: String, qtype: QueryType) -> DnsQuestion { + DnsQuestion { name, qtype } + } + + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); + let _ = buffer.read_u16()?; // class + + Ok(()) + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { + buffer.write_qname(&self.name)?; + buffer.write_u16(self.qtype.to_num())?; + buffer.write_u16(1)?; + + Ok(()) + } +} diff --git a/src/record.rs b/src/record.rs new file mode 100644 index 0000000..ffc6ef3 --- /dev/null +++ b/src/record.rs @@ -0,0 +1,249 @@ +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; + +use crate::buffer::BytePacketBuffer; +use crate::question::QueryType; +use crate::Result; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[allow(dead_code)] +pub enum DnsRecord { + UNKNOWN { + domain: String, + qtype: u16, + data_len: u16, + ttl: u32, + }, + A { + domain: String, + addr: Ipv4Addr, + ttl: u32, + }, + NS { + domain: String, + host: String, + ttl: u32, + }, + CNAME { + domain: String, + host: String, + ttl: u32, + }, + MX { + domain: String, + priority: u16, + host: String, + ttl: u32, + }, + AAAA { + domain: String, + addr: Ipv6Addr, + ttl: u32, + }, +} + +impl DnsRecord { + pub fn ttl(&self) -> u32 { + match self { + DnsRecord::A { ttl, .. } + | DnsRecord::NS { ttl, .. } + | DnsRecord::CNAME { ttl, .. } + | DnsRecord::MX { ttl, .. } + | DnsRecord::AAAA { ttl, .. } + | DnsRecord::UNKNOWN { ttl, .. } => *ttl, + } + } + + pub fn set_ttl(&mut self, new_ttl: u32) { + match self { + DnsRecord::A { ttl, .. } + | DnsRecord::NS { ttl, .. } + | DnsRecord::CNAME { ttl, .. } + | DnsRecord::MX { ttl, .. } + | DnsRecord::AAAA { ttl, .. } + | DnsRecord::UNKNOWN { ttl, .. } => *ttl = new_ttl, + } + } + + pub fn read(buffer: &mut BytePacketBuffer) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; + + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; + + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); + + Ok(DnsRecord::A { domain, addr, ttl }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); + + Ok(DnsRecord::AAAA { domain, addr, ttl }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; + + Ok(DnsRecord::NS { + domain, + host: ns, + ttl, + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; + + Ok(DnsRecord::CNAME { + domain, + host: cname, + ttl, + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; + + Ok(DnsRecord::MX { + domain, + priority, + host: mx, + ttl, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; + + Ok(DnsRecord::UNKNOWN { + domain, + qtype: qtype_num, + data_len, + ttl, + }) + } + } + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { + let start_pos = buffer.pos(); + + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } + DnsRecord::NS { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } +} diff --git a/src/stats.rs b/src/stats.rs new file mode 100644 index 0000000..d9db0e3 --- /dev/null +++ b/src/stats.rs @@ -0,0 +1,79 @@ +use std::time::Instant; + +pub struct ServerStats { + queries_total: u64, + queries_forwarded: u64, + queries_cached: u64, + queries_blocked: u64, + queries_local: u64, + upstream_errors: u64, + started_at: Instant, +} + +pub enum QueryPath { + Local, + Cached, + Forwarded, + Blocked, + UpstreamError, +} + +impl QueryPath { + pub fn as_str(&self) -> &'static str { + match self { + QueryPath::Local => "LOCAL", + QueryPath::Cached => "CACHED", + QueryPath::Forwarded => "FORWARD", + QueryPath::Blocked => "BLOCKED", + QueryPath::UpstreamError => "SERVFAIL", + } + } +} + +impl ServerStats { + pub fn new() -> Self { + ServerStats { + queries_total: 0, + queries_forwarded: 0, + queries_cached: 0, + queries_blocked: 0, + queries_local: 0, + upstream_errors: 0, + started_at: Instant::now(), + } + } + + pub fn record(&mut self, path: QueryPath) -> u64 { + self.queries_total += 1; + match path { + QueryPath::Local => self.queries_local += 1, + QueryPath::Cached => self.queries_cached += 1, + QueryPath::Forwarded => self.queries_forwarded += 1, + QueryPath::Blocked => self.queries_blocked += 1, + QueryPath::UpstreamError => self.upstream_errors += 1, + } + self.queries_total + } + + pub fn total(&self) -> u64 { + self.queries_total + } + + pub fn log_summary(&self) { + let uptime = self.started_at.elapsed(); + let hours = uptime.as_secs() / 3600; + let mins = (uptime.as_secs() % 3600) / 60; + let secs = uptime.as_secs() % 60; + + log::info!( + "STATS | uptime {}h{}m{}s | total {} | fwd {} | cached {} | local {} | blocked {} | errors {}", + hours, mins, secs, + self.queries_total, + self.queries_forwarded, + self.queries_cached, + self.queries_local, + self.queries_blocked, + self.upstream_errors, + ); + } +} -- 2.34.1 From 967150b99162908a8d2814cf985e4de6012c6f7b Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Tue, 10 Mar 2026 04:58:11 +0200 Subject: [PATCH 2/4] gitignore CLAUDE.md and update README for async tokio Co-Authored-By: Claude Opus 4.6 --- .gitignore | 1 + CLAUDE.md | 59 ------------------------------------------------------ README.md | 7 ++++--- 3 files changed, 5 insertions(+), 62 deletions(-) delete mode 100644 CLAUDE.md diff --git a/.gitignore b/.gitignore index ea8c4bf..cfa6940 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 6079f31..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,59 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Overview - -A DNS forwarding/caching proxy in Rust. Serves local zone records from TOML config, caches upstream responses with TTL-based expiration, forwards unknown queries to an upstream resolver, and logs all queries with structured output. - -## Build & Run - -```bash -cargo build # compile -sudo cargo run # run with default config (dns_fun.toml) -sudo cargo run -- path/to/config # run with custom config path -RUST_LOG=debug sudo cargo run # verbose logging -``` - -Test with: `dig @127.0.0.1 google.com` - -No tests or linter configured. - -## Architecture - -``` -src/ - lib.rs # module declarations, Error/Result type aliases - main.rs # startup, config load, UDP listen loop, request pipeline - buffer.rs # BytePacketBuffer — 512-byte DNS wire format read/write - header.rs # DnsHeader, ResultCode — 12-byte header bitfield parsing - question.rs # DnsQuestion, QueryType — query section (A, NS, CNAME, MX, AAAA) - record.rs # DnsRecord — resource record variants with read/write - packet.rs # DnsPacket — top-level: header + questions + answers + authorities + resources - config.rs # Config loading from TOML, zone map builder - cache.rs # DnsCache — TTL-aware cache with lazy eviction - forward.rs # forward_query() — sends query to upstream, build_servfail() — error response - stats.rs # ServerStats — query counters and periodic summary -``` - -## Request Pipeline - -``` -Query → Parse → Log → Local Zones → Cache → Upstream Forward (+ cache result) → Log → Respond -``` - -## Config - -`dns_fun.toml` at project root. Sections: `[server]`, `[upstream]`, `[cache]`, `[[zones]]`. Falls back to sensible defaults if file is missing. - -## Logging - -Controlled via `RUST_LOG` env var. Default level: `info` (one structured line per query). `debug` adds response details. Stats summary every 1000 queries. - -## Key Details - -- Rust 2018 edition, deps: `serde`, `toml`, `log`, `env_logger` -- DNS packet size limited to 512 bytes (standard UDP DNS) -- `BytePacketBuffer::read_qname` handles label compression (pointer jumps) -- `type Error = Box` / `type Result` aliased in `lib.rs` -- Cache: TTL clamped between `min_ttl` and `max_ttl`, lazy eviction every 1000 queries diff --git a/README.md b/README.md index 26234fe..a0bc077 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A DNS forwarding/caching proxy written from scratch in Rust. Parses and serializes DNS wire protocol (RFC 1035), serves local zone records from TOML config, caches upstream responses with TTL-aware expiration, and logs every query with structured output. -No async runtime, no DNS libraries — just `std::net::UdpSocket` and manual packet parsing. +No DNS libraries — just `tokio::net::UdpSocket` and manual packet parsing. Each query is handled concurrently via `tokio::spawn`. ## Record Types @@ -99,7 +99,7 @@ Stats summary (total, forwarded, cached, local, blocked, errors) logged every 10 ``` src/ - main.rs # startup, config load, UDP listen loop, request pipeline + main.rs # async startup, tokio event loop, ServerCtx, per-query task spawn lib.rs # module declarations, Error/Result type aliases buffer.rs # BytePacketBuffer — 512-byte DNS wire format read/write header.rs # DnsHeader, ResultCode @@ -108,13 +108,14 @@ src/ packet.rs # DnsPacket — full DNS message parse/serialize config.rs # TOML config loading, zone map builder cache.rs # TTL-aware DNS response cache with lazy eviction - forward.rs # upstream forwarding, SERVFAIL builder + forward.rs # async upstream forwarding stats.rs # query counters and periodic summary ``` ## Dependencies ```toml +tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "time"] } toml = "0.8" serde = { version = "1", features = ["derive"] } log = "0.4" -- 2.34.1 From 2c6133344ae46798b10700ee9b99f1667b160620 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Tue, 10 Mar 2026 05:04:31 +0200 Subject: [PATCH 3/4] add Makefile with clippy/rustfmt linting, fix all warnings Co-Authored-By: Claude Opus 4.6 --- Makefile | 20 +++++++++++++++ src/buffer.rs | 12 ++++++--- src/cache.rs | 18 ++++++++------ src/config.rs | 69 ++++++++++++++++++++++++++++++++++++++++----------- src/header.rs | 10 ++++++-- src/main.rs | 31 +++++++++++++++++------ src/packet.rs | 6 +++++ src/record.rs | 10 ++++---- src/stats.rs | 6 +++++ 9 files changed, 143 insertions(+), 39 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..540f041 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: all build lint fmt check test clean + +all: lint build + +build: + cargo build + +lint: fmt check + +fmt: + cargo fmt --check + +check: + cargo clippy -- -D warnings + +test: + cargo test + +clean: + cargo clean diff --git a/src/buffer.rs b/src/buffer.rs index 5c82f0e..c9378d0 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,10 +1,16 @@ -use crate::{Result}; +use crate::Result; pub struct BytePacketBuffer { pub buf: [u8; 512], pub pos: usize, } +impl Default for BytePacketBuffer { + fn default() -> Self { + Self::new() + } +} + impl BytePacketBuffer { pub fn new() -> BytePacketBuffer { BytePacketBuffer { @@ -63,7 +69,7 @@ impl BytePacketBuffer { let res = ((self.read()? as u32) << 24) | ((self.read()? as u32) << 16) | ((self.read()? as u32) << 8) - | ((self.read()? as u32) << 0); + | (self.read()? as u32); Ok(res) } @@ -144,7 +150,7 @@ impl BytePacketBuffer { self.write(((val >> 24) & 0xFF) as u8)?; self.write(((val >> 16) & 0xFF) as u8)?; self.write(((val >> 8) & 0xFF) as u8)?; - self.write(((val >> 0) & 0xFF) as u8)?; + self.write((val & 0xFF) as u8)?; Ok(()) } diff --git a/src/cache.rs b/src/cache.rs index 6dd2e45..65629fc 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -34,7 +34,7 @@ impl DnsCache { self.query_count += 1; // Periodic eviction every 1000 queries - if self.query_count % 1000 == 0 { + if self.query_count.is_multiple_of(1000) { self.evict_expired(); } @@ -72,15 +72,19 @@ impl DnsCache { .clamp(self.min_ttl, self.max_ttl); let key = (domain.to_string(), qtype); - self.entries.insert(key, CacheEntry { - packet: packet.clone(), - inserted_at: Instant::now(), - ttl: Duration::from_secs(min_ttl as u64), - }); + self.entries.insert( + key, + CacheEntry { + packet: packet.clone(), + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + }, + ); } fn evict_expired(&mut self) { - self.entries.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl); + self.entries + .retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl); } } diff --git a/src/config.rs b/src/config.rs index 3d13f23..1cd5c61 100644 --- a/src/config.rs +++ b/src/config.rs @@ -126,36 +126,77 @@ pub fn load_config(path: &str) -> Result { Ok(config) } -pub fn build_zone_map(zones: &[ZoneRecord]) -> Result>> { +pub fn build_zone_map( + zones: &[ZoneRecord], +) -> Result>> { let mut map: HashMap<(String, QueryType), Vec> = HashMap::new(); for zone in zones { let domain = zone.domain.to_lowercase(); let (qtype, record) = match zone.record_type.to_uppercase().as_str() { "A" => { - let addr: Ipv4Addr = zone.value.parse() + let addr: Ipv4Addr = zone + .value + .parse() .map_err(|e| format!("invalid A record value '{}': {}", zone.value, e))?; - (QueryType::A, DnsRecord::A { domain: domain.clone(), addr, ttl: zone.ttl }) + ( + QueryType::A, + DnsRecord::A { + domain: domain.clone(), + addr, + ttl: zone.ttl, + }, + ) } "AAAA" => { - let addr: Ipv6Addr = zone.value.parse() + let addr: Ipv6Addr = zone + .value + .parse() .map_err(|e| format!("invalid AAAA record value '{}': {}", zone.value, e))?; - (QueryType::AAAA, DnsRecord::AAAA { domain: domain.clone(), addr, ttl: zone.ttl }) - } - "CNAME" => { - (QueryType::CNAME, DnsRecord::CNAME { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) - } - "NS" => { - (QueryType::NS, DnsRecord::NS { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) + ( + QueryType::AAAA, + DnsRecord::AAAA { + domain: domain.clone(), + addr, + ttl: zone.ttl, + }, + ) } + "CNAME" => ( + QueryType::CNAME, + DnsRecord::CNAME { + domain: domain.clone(), + host: zone.value.clone(), + ttl: zone.ttl, + }, + ), + "NS" => ( + QueryType::NS, + DnsRecord::NS { + domain: domain.clone(), + host: zone.value.clone(), + ttl: zone.ttl, + }, + ), "MX" => { let parts: Vec<&str> = zone.value.splitn(2, ' ').collect(); if parts.len() != 2 { - return Err(format!("MX value must be 'priority host', got '{}'", zone.value).into()); + return Err( + format!("MX value must be 'priority host', got '{}'", zone.value).into(), + ); } - let priority: u16 = parts[0].parse() + let priority: u16 = parts[0] + .parse() .map_err(|e| format!("invalid MX priority '{}': {}", parts[0], e))?; - (QueryType::MX, DnsRecord::MX { domain: domain.clone(), priority, host: parts[1].to_string(), ttl: zone.ttl }) + ( + QueryType::MX, + DnsRecord::MX { + domain: domain.clone(), + priority, + host: parts[1].to_string(), + ttl: zone.ttl, + }, + ) } other => { return Err(format!("unsupported record type '{}'", other).into()); diff --git a/src/header.rs b/src/header.rs index 2ce42a1..837e1ea 100644 --- a/src/header.rs +++ b/src/header.rs @@ -19,7 +19,7 @@ impl ResultCode { 3 => ResultCode::NXDOMAIN, 4 => ResultCode::NOTIMP, 5 => ResultCode::REFUSED, - 0 | _ => ResultCode::NOERROR, + _ => ResultCode::NOERROR, } } @@ -57,6 +57,12 @@ pub struct DnsHeader { pub resource_entries: u16, } +impl Default for DnsHeader { + fn default() -> Self { + Self::new() + } +} + impl DnsHeader { pub fn new() -> DnsHeader { DnsHeader { @@ -112,7 +118,7 @@ impl DnsHeader { | ((self.truncated_message as u8) << 1) | ((self.authoritative_answer as u8) << 2) | (self.opcode << 3) - | ((self.response as u8) << 7) as u8, + | ((self.response as u8) << 7), )?; buffer.write_u8( diff --git a/src/main.rs b/src/main.rs index e2624c5..39e7811 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,10 +31,13 @@ async fn main() -> dns_fun::Result<()> { .format_timestamp_millis() .init(); - let config_path = std::env::args().nth(1).unwrap_or_else(|| "dns_fun.toml".to_string()); + let config_path = std::env::args() + .nth(1) + .unwrap_or_else(|| "dns_fun.toml".to_string()); let config = load_config(&config_path)?; - let upstream: SocketAddr = format!("{}:{}", config.upstream.address, config.upstream.port).parse()?; + let upstream: SocketAddr = + format!("{}:{}", config.upstream.address, config.upstream.port).parse()?; let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?); let ctx = Arc::new(ServerCtx { @@ -110,8 +113,14 @@ async fn handle_query( (resp, QueryPath::Forwarded) } Err(e) => { - error!("{} | {:?} {} | UPSTREAM ERROR | {}", src_addr, qtype, qname, e); - (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError) + error!( + "{} | {:?} {} | UPSTREAM ERROR | {}", + src_addr, qtype, qname, e + ); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + ) } } } @@ -121,13 +130,19 @@ async fn handle_query( info!( "{} | {:?} {} | {} | {} | {}ms", - src_addr, qtype, qname, path.as_str(), - response.header.rescode.as_str(), elapsed.as_millis(), + src_addr, + qtype, + qname, + path.as_str(), + response.header.rescode.as_str(), + elapsed.as_millis(), ); debug!( "response: {} answers, {} authorities, {} resources", - response.answers.len(), response.authorities.len(), response.resources.len(), + response.answers.len(), + response.authorities.len(), + response.resources.len(), ); let mut resp_buffer = BytePacketBuffer::new(); @@ -137,7 +152,7 @@ async fn handle_query( // Record stats and log summary every 1000 queries (single lock acquisition) let mut s = ctx.stats.lock().unwrap(); let total = s.record(path); - if total % 1000 == 0 { + if total.is_multiple_of(1000) { s.log_summary(); } diff --git a/src/packet.rs b/src/packet.rs index 2098257..f6845aa 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -13,6 +13,12 @@ pub struct DnsPacket { pub resources: Vec, } +impl Default for DnsPacket { + fn default() -> Self { + Self::new() + } +} + impl DnsPacket { pub fn new() -> DnsPacket { DnsPacket { diff --git a/src/record.rs b/src/record.rs index ffc6ef3..b138b79 100644 --- a/src/record.rs +++ b/src/record.rs @@ -82,7 +82,7 @@ impl DnsRecord { ((raw_addr >> 24) & 0xFF) as u8, ((raw_addr >> 16) & 0xFF) as u8, ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8, + (raw_addr & 0xFF) as u8, ); Ok(DnsRecord::A { domain, addr, ttl }) @@ -94,13 +94,13 @@ impl DnsRecord { let raw_addr4 = buffer.read_u32()?; let addr = Ipv6Addr::new( ((raw_addr1 >> 16) & 0xFFFF) as u16, - ((raw_addr1 >> 0) & 0xFFFF) as u16, + (raw_addr1 & 0xFFFF) as u16, ((raw_addr2 >> 16) & 0xFFFF) as u16, - ((raw_addr2 >> 0) & 0xFFFF) as u16, + (raw_addr2 & 0xFFFF) as u16, ((raw_addr3 >> 16) & 0xFFFF) as u16, - ((raw_addr3 >> 0) & 0xFFFF) as u16, + (raw_addr3 & 0xFFFF) as u16, ((raw_addr4 >> 16) & 0xFFFF) as u16, - ((raw_addr4 >> 0) & 0xFFFF) as u16, + (raw_addr4 & 0xFFFF) as u16, ); Ok(DnsRecord::AAAA { domain, addr, ttl }) diff --git a/src/stats.rs b/src/stats.rs index d9db0e3..3f50e85 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -30,6 +30,12 @@ impl QueryPath { } } +impl Default for ServerStats { + fn default() -> Self { + Self::new() + } +} + impl ServerStats { pub fn new() -> Self { ServerStats { -- 2.34.1 From f627b03e8068147c8e289c2cad31a617ee15d704 Mon Sep 17 00:00:00 2001 From: Razvan Dimescu Date: Tue, 10 Mar 2026 05:18:59 +0200 Subject: [PATCH 4/4] gitignore docs/ directory Co-Authored-By: Claude Opus 4.6 --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index cfa6940..1b715be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target CLAUDE.md +docs/ -- 2.34.1