Async tokio runtime with modular architecture #1

Merged
razvandimescu merged 4 commits from feat/async-tokio into main 2026-03-10 11:35:26 +08:00
9 changed files with 143 additions and 39 deletions
Showing only changes of commit 2c6133344a - Show all commits

20
Makefile Normal file
View File

@@ -0,0 +1,20 @@
.PHONY: all build lint fmt check test clean
all: lint build
build:
cargo build
lint: fmt check
fmt:
cargo fmt --check
check:
cargo clippy -- -D warnings
test:
cargo test
clean:
cargo clean

View File

@@ -1,10 +1,16 @@
use crate::{Result};
use crate::Result;
pub struct BytePacketBuffer {
pub buf: [u8; 512],
pub pos: usize,
}
impl Default for BytePacketBuffer {
fn default() -> Self {
Self::new()
}
}
impl BytePacketBuffer {
pub fn new() -> BytePacketBuffer {
BytePacketBuffer {
@@ -63,7 +69,7 @@ impl BytePacketBuffer {
let res = ((self.read()? as u32) << 24)
| ((self.read()? as u32) << 16)
| ((self.read()? as u32) << 8)
| ((self.read()? as u32) << 0);
| (self.read()? as u32);
Ok(res)
}
@@ -144,7 +150,7 @@ impl BytePacketBuffer {
self.write(((val >> 24) & 0xFF) as u8)?;
self.write(((val >> 16) & 0xFF) as u8)?;
self.write(((val >> 8) & 0xFF) as u8)?;
self.write(((val >> 0) & 0xFF) as u8)?;
self.write((val & 0xFF) as u8)?;
Ok(())
}

View File

@@ -34,7 +34,7 @@ impl DnsCache {
self.query_count += 1;
// Periodic eviction every 1000 queries
if self.query_count % 1000 == 0 {
if self.query_count.is_multiple_of(1000) {
self.evict_expired();
}
@@ -72,15 +72,19 @@ impl DnsCache {
.clamp(self.min_ttl, self.max_ttl);
let key = (domain.to_string(), qtype);
self.entries.insert(key, CacheEntry {
packet: packet.clone(),
inserted_at: Instant::now(),
ttl: Duration::from_secs(min_ttl as u64),
});
self.entries.insert(
key,
CacheEntry {
packet: packet.clone(),
inserted_at: Instant::now(),
ttl: Duration::from_secs(min_ttl as u64),
},
);
}
fn evict_expired(&mut self) {
self.entries.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
self.entries
.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
}
}

View File

@@ -126,36 +126,77 @@ pub fn load_config(path: &str) -> Result<Config> {
Ok(config)
}
pub fn build_zone_map(zones: &[ZoneRecord]) -> Result<HashMap<(String, QueryType), Vec<DnsRecord>>> {
pub fn build_zone_map(
zones: &[ZoneRecord],
) -> Result<HashMap<(String, QueryType), Vec<DnsRecord>>> {
let mut map: HashMap<(String, QueryType), Vec<DnsRecord>> = HashMap::new();
for zone in zones {
let domain = zone.domain.to_lowercase();
let (qtype, record) = match zone.record_type.to_uppercase().as_str() {
"A" => {
let addr: Ipv4Addr = zone.value.parse()
let addr: Ipv4Addr = zone
.value
.parse()
.map_err(|e| format!("invalid A record value '{}': {}", zone.value, e))?;
(QueryType::A, DnsRecord::A { domain: domain.clone(), addr, ttl: zone.ttl })
(
QueryType::A,
DnsRecord::A {
domain: domain.clone(),
addr,
ttl: zone.ttl,
},
)
}
"AAAA" => {
let addr: Ipv6Addr = zone.value.parse()
let addr: Ipv6Addr = zone
.value
.parse()
.map_err(|e| format!("invalid AAAA record value '{}': {}", zone.value, e))?;
(QueryType::AAAA, DnsRecord::AAAA { domain: domain.clone(), addr, ttl: zone.ttl })
}
"CNAME" => {
(QueryType::CNAME, DnsRecord::CNAME { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl })
}
"NS" => {
(QueryType::NS, DnsRecord::NS { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl })
(
QueryType::AAAA,
DnsRecord::AAAA {
domain: domain.clone(),
addr,
ttl: zone.ttl,
},
)
}
"CNAME" => (
QueryType::CNAME,
DnsRecord::CNAME {
domain: domain.clone(),
host: zone.value.clone(),
ttl: zone.ttl,
},
),
"NS" => (
QueryType::NS,
DnsRecord::NS {
domain: domain.clone(),
host: zone.value.clone(),
ttl: zone.ttl,
},
),
"MX" => {
let parts: Vec<&str> = zone.value.splitn(2, ' ').collect();
if parts.len() != 2 {
return Err(format!("MX value must be 'priority host', got '{}'", zone.value).into());
return Err(
format!("MX value must be 'priority host', got '{}'", zone.value).into(),
);
}
let priority: u16 = parts[0].parse()
let priority: u16 = parts[0]
.parse()
.map_err(|e| format!("invalid MX priority '{}': {}", parts[0], e))?;
(QueryType::MX, DnsRecord::MX { domain: domain.clone(), priority, host: parts[1].to_string(), ttl: zone.ttl })
(
QueryType::MX,
DnsRecord::MX {
domain: domain.clone(),
priority,
host: parts[1].to_string(),
ttl: zone.ttl,
},
)
}
other => {
return Err(format!("unsupported record type '{}'", other).into());

View File

@@ -19,7 +19,7 @@ impl ResultCode {
3 => ResultCode::NXDOMAIN,
4 => ResultCode::NOTIMP,
5 => ResultCode::REFUSED,
0 | _ => ResultCode::NOERROR,
_ => ResultCode::NOERROR,
}
}
@@ -57,6 +57,12 @@ pub struct DnsHeader {
pub resource_entries: u16,
}
impl Default for DnsHeader {
fn default() -> Self {
Self::new()
}
}
impl DnsHeader {
pub fn new() -> DnsHeader {
DnsHeader {
@@ -112,7 +118,7 @@ impl DnsHeader {
| ((self.truncated_message as u8) << 1)
| ((self.authoritative_answer as u8) << 2)
| (self.opcode << 3)
| ((self.response as u8) << 7) as u8,
| ((self.response as u8) << 7),
)?;
buffer.write_u8(

View File

@@ -31,10 +31,13 @@ async fn main() -> dns_fun::Result<()> {
.format_timestamp_millis()
.init();
let config_path = std::env::args().nth(1).unwrap_or_else(|| "dns_fun.toml".to_string());
let config_path = std::env::args()
.nth(1)
.unwrap_or_else(|| "dns_fun.toml".to_string());
let config = load_config(&config_path)?;
let upstream: SocketAddr = format!("{}:{}", config.upstream.address, config.upstream.port).parse()?;
let upstream: SocketAddr =
format!("{}:{}", config.upstream.address, config.upstream.port).parse()?;
let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?);
let ctx = Arc::new(ServerCtx {
@@ -110,8 +113,14 @@ async fn handle_query(
(resp, QueryPath::Forwarded)
}
Err(e) => {
error!("{} | {:?} {} | UPSTREAM ERROR | {}", src_addr, qtype, qname, e);
(DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError)
error!(
"{} | {:?} {} | UPSTREAM ERROR | {}",
src_addr, qtype, qname, e
);
(
DnsPacket::response_from(&query, ResultCode::SERVFAIL),
QueryPath::UpstreamError,
)
}
}
}
@@ -121,13 +130,19 @@ async fn handle_query(
info!(
"{} | {:?} {} | {} | {} | {}ms",
src_addr, qtype, qname, path.as_str(),
response.header.rescode.as_str(), elapsed.as_millis(),
src_addr,
qtype,
qname,
path.as_str(),
response.header.rescode.as_str(),
elapsed.as_millis(),
);
debug!(
"response: {} answers, {} authorities, {} resources",
response.answers.len(), response.authorities.len(), response.resources.len(),
response.answers.len(),
response.authorities.len(),
response.resources.len(),
);
let mut resp_buffer = BytePacketBuffer::new();
@@ -137,7 +152,7 @@ async fn handle_query(
// Record stats and log summary every 1000 queries (single lock acquisition)
let mut s = ctx.stats.lock().unwrap();
let total = s.record(path);
if total % 1000 == 0 {
if total.is_multiple_of(1000) {
s.log_summary();
}

View File

@@ -13,6 +13,12 @@ pub struct DnsPacket {
pub resources: Vec<DnsRecord>,
}
impl Default for DnsPacket {
fn default() -> Self {
Self::new()
}
}
impl DnsPacket {
pub fn new() -> DnsPacket {
DnsPacket {

View File

@@ -82,7 +82,7 @@ impl DnsRecord {
((raw_addr >> 24) & 0xFF) as u8,
((raw_addr >> 16) & 0xFF) as u8,
((raw_addr >> 8) & 0xFF) as u8,
((raw_addr >> 0) & 0xFF) as u8,
(raw_addr & 0xFF) as u8,
);
Ok(DnsRecord::A { domain, addr, ttl })
@@ -94,13 +94,13 @@ impl DnsRecord {
let raw_addr4 = buffer.read_u32()?;
let addr = Ipv6Addr::new(
((raw_addr1 >> 16) & 0xFFFF) as u16,
((raw_addr1 >> 0) & 0xFFFF) as u16,
(raw_addr1 & 0xFFFF) as u16,
((raw_addr2 >> 16) & 0xFFFF) as u16,
((raw_addr2 >> 0) & 0xFFFF) as u16,
(raw_addr2 & 0xFFFF) as u16,
((raw_addr3 >> 16) & 0xFFFF) as u16,
((raw_addr3 >> 0) & 0xFFFF) as u16,
(raw_addr3 & 0xFFFF) as u16,
((raw_addr4 >> 16) & 0xFFFF) as u16,
((raw_addr4 >> 0) & 0xFFFF) as u16,
(raw_addr4 & 0xFFFF) as u16,
);
Ok(DnsRecord::AAAA { domain, addr, ttl })

View File

@@ -30,6 +30,12 @@ impl QueryPath {
}
}
impl Default for ServerStats {
fn default() -> Self {
Self::new()
}
}
impl ServerStats {
pub fn new() -> Self {
ServerStats {