diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..540f041 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: all build lint fmt check test clean + +all: lint build + +build: + cargo build + +lint: fmt check + +fmt: + cargo fmt --check + +check: + cargo clippy -- -D warnings + +test: + cargo test + +clean: + cargo clean diff --git a/src/buffer.rs b/src/buffer.rs index 5c82f0e..c9378d0 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,10 +1,16 @@ -use crate::{Result}; +use crate::Result; pub struct BytePacketBuffer { pub buf: [u8; 512], pub pos: usize, } +impl Default for BytePacketBuffer { + fn default() -> Self { + Self::new() + } +} + impl BytePacketBuffer { pub fn new() -> BytePacketBuffer { BytePacketBuffer { @@ -63,7 +69,7 @@ impl BytePacketBuffer { let res = ((self.read()? as u32) << 24) | ((self.read()? as u32) << 16) | ((self.read()? as u32) << 8) - | ((self.read()? as u32) << 0); + | (self.read()? as u32); Ok(res) } @@ -144,7 +150,7 @@ impl BytePacketBuffer { self.write(((val >> 24) & 0xFF) as u8)?; self.write(((val >> 16) & 0xFF) as u8)?; self.write(((val >> 8) & 0xFF) as u8)?; - self.write(((val >> 0) & 0xFF) as u8)?; + self.write((val & 0xFF) as u8)?; Ok(()) } diff --git a/src/cache.rs b/src/cache.rs index 6dd2e45..65629fc 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -34,7 +34,7 @@ impl DnsCache { self.query_count += 1; // Periodic eviction every 1000 queries - if self.query_count % 1000 == 0 { + if self.query_count.is_multiple_of(1000) { self.evict_expired(); } @@ -72,15 +72,19 @@ impl DnsCache { .clamp(self.min_ttl, self.max_ttl); let key = (domain.to_string(), qtype); - self.entries.insert(key, CacheEntry { - packet: packet.clone(), - inserted_at: Instant::now(), - ttl: Duration::from_secs(min_ttl as u64), - }); + self.entries.insert( + key, + CacheEntry { + packet: packet.clone(), + inserted_at: Instant::now(), + ttl: Duration::from_secs(min_ttl as u64), + }, + ); } fn evict_expired(&mut self) { - self.entries.retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl); + self.entries + .retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl); } } diff --git a/src/config.rs b/src/config.rs index 3d13f23..1cd5c61 100644 --- a/src/config.rs +++ b/src/config.rs @@ -126,36 +126,77 @@ pub fn load_config(path: &str) -> Result { Ok(config) } -pub fn build_zone_map(zones: &[ZoneRecord]) -> Result>> { +pub fn build_zone_map( + zones: &[ZoneRecord], +) -> Result>> { let mut map: HashMap<(String, QueryType), Vec> = HashMap::new(); for zone in zones { let domain = zone.domain.to_lowercase(); let (qtype, record) = match zone.record_type.to_uppercase().as_str() { "A" => { - let addr: Ipv4Addr = zone.value.parse() + let addr: Ipv4Addr = zone + .value + .parse() .map_err(|e| format!("invalid A record value '{}': {}", zone.value, e))?; - (QueryType::A, DnsRecord::A { domain: domain.clone(), addr, ttl: zone.ttl }) + ( + QueryType::A, + DnsRecord::A { + domain: domain.clone(), + addr, + ttl: zone.ttl, + }, + ) } "AAAA" => { - let addr: Ipv6Addr = zone.value.parse() + let addr: Ipv6Addr = zone + .value + .parse() .map_err(|e| format!("invalid AAAA record value '{}': {}", zone.value, e))?; - (QueryType::AAAA, DnsRecord::AAAA { domain: domain.clone(), addr, ttl: zone.ttl }) - } - "CNAME" => { - (QueryType::CNAME, DnsRecord::CNAME { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) - } - "NS" => { - (QueryType::NS, DnsRecord::NS { domain: domain.clone(), host: zone.value.clone(), ttl: zone.ttl }) + ( + QueryType::AAAA, + DnsRecord::AAAA { + domain: domain.clone(), + addr, + ttl: zone.ttl, + }, + ) } + "CNAME" => ( + QueryType::CNAME, + DnsRecord::CNAME { + domain: domain.clone(), + host: zone.value.clone(), + ttl: zone.ttl, + }, + ), + "NS" => ( + QueryType::NS, + DnsRecord::NS { + domain: domain.clone(), + host: zone.value.clone(), + ttl: zone.ttl, + }, + ), "MX" => { let parts: Vec<&str> = zone.value.splitn(2, ' ').collect(); if parts.len() != 2 { - return Err(format!("MX value must be 'priority host', got '{}'", zone.value).into()); + return Err( + format!("MX value must be 'priority host', got '{}'", zone.value).into(), + ); } - let priority: u16 = parts[0].parse() + let priority: u16 = parts[0] + .parse() .map_err(|e| format!("invalid MX priority '{}': {}", parts[0], e))?; - (QueryType::MX, DnsRecord::MX { domain: domain.clone(), priority, host: parts[1].to_string(), ttl: zone.ttl }) + ( + QueryType::MX, + DnsRecord::MX { + domain: domain.clone(), + priority, + host: parts[1].to_string(), + ttl: zone.ttl, + }, + ) } other => { return Err(format!("unsupported record type '{}'", other).into()); diff --git a/src/header.rs b/src/header.rs index 2ce42a1..837e1ea 100644 --- a/src/header.rs +++ b/src/header.rs @@ -19,7 +19,7 @@ impl ResultCode { 3 => ResultCode::NXDOMAIN, 4 => ResultCode::NOTIMP, 5 => ResultCode::REFUSED, - 0 | _ => ResultCode::NOERROR, + _ => ResultCode::NOERROR, } } @@ -57,6 +57,12 @@ pub struct DnsHeader { pub resource_entries: u16, } +impl Default for DnsHeader { + fn default() -> Self { + Self::new() + } +} + impl DnsHeader { pub fn new() -> DnsHeader { DnsHeader { @@ -112,7 +118,7 @@ impl DnsHeader { | ((self.truncated_message as u8) << 1) | ((self.authoritative_answer as u8) << 2) | (self.opcode << 3) - | ((self.response as u8) << 7) as u8, + | ((self.response as u8) << 7), )?; buffer.write_u8( diff --git a/src/main.rs b/src/main.rs index e2624c5..39e7811 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,10 +31,13 @@ async fn main() -> dns_fun::Result<()> { .format_timestamp_millis() .init(); - let config_path = std::env::args().nth(1).unwrap_or_else(|| "dns_fun.toml".to_string()); + let config_path = std::env::args() + .nth(1) + .unwrap_or_else(|| "dns_fun.toml".to_string()); let config = load_config(&config_path)?; - let upstream: SocketAddr = format!("{}:{}", config.upstream.address, config.upstream.port).parse()?; + let upstream: SocketAddr = + format!("{}:{}", config.upstream.address, config.upstream.port).parse()?; let socket = Arc::new(UdpSocket::bind(&config.server.bind_addr).await?); let ctx = Arc::new(ServerCtx { @@ -110,8 +113,14 @@ async fn handle_query( (resp, QueryPath::Forwarded) } Err(e) => { - error!("{} | {:?} {} | UPSTREAM ERROR | {}", src_addr, qtype, qname, e); - (DnsPacket::response_from(&query, ResultCode::SERVFAIL), QueryPath::UpstreamError) + error!( + "{} | {:?} {} | UPSTREAM ERROR | {}", + src_addr, qtype, qname, e + ); + ( + DnsPacket::response_from(&query, ResultCode::SERVFAIL), + QueryPath::UpstreamError, + ) } } } @@ -121,13 +130,19 @@ async fn handle_query( info!( "{} | {:?} {} | {} | {} | {}ms", - src_addr, qtype, qname, path.as_str(), - response.header.rescode.as_str(), elapsed.as_millis(), + src_addr, + qtype, + qname, + path.as_str(), + response.header.rescode.as_str(), + elapsed.as_millis(), ); debug!( "response: {} answers, {} authorities, {} resources", - response.answers.len(), response.authorities.len(), response.resources.len(), + response.answers.len(), + response.authorities.len(), + response.resources.len(), ); let mut resp_buffer = BytePacketBuffer::new(); @@ -137,7 +152,7 @@ async fn handle_query( // Record stats and log summary every 1000 queries (single lock acquisition) let mut s = ctx.stats.lock().unwrap(); let total = s.record(path); - if total % 1000 == 0 { + if total.is_multiple_of(1000) { s.log_summary(); } diff --git a/src/packet.rs b/src/packet.rs index 2098257..f6845aa 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -13,6 +13,12 @@ pub struct DnsPacket { pub resources: Vec, } +impl Default for DnsPacket { + fn default() -> Self { + Self::new() + } +} + impl DnsPacket { pub fn new() -> DnsPacket { DnsPacket { diff --git a/src/record.rs b/src/record.rs index ffc6ef3..b138b79 100644 --- a/src/record.rs +++ b/src/record.rs @@ -82,7 +82,7 @@ impl DnsRecord { ((raw_addr >> 24) & 0xFF) as u8, ((raw_addr >> 16) & 0xFF) as u8, ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8, + (raw_addr & 0xFF) as u8, ); Ok(DnsRecord::A { domain, addr, ttl }) @@ -94,13 +94,13 @@ impl DnsRecord { let raw_addr4 = buffer.read_u32()?; let addr = Ipv6Addr::new( ((raw_addr1 >> 16) & 0xFFFF) as u16, - ((raw_addr1 >> 0) & 0xFFFF) as u16, + (raw_addr1 & 0xFFFF) as u16, ((raw_addr2 >> 16) & 0xFFFF) as u16, - ((raw_addr2 >> 0) & 0xFFFF) as u16, + (raw_addr2 & 0xFFFF) as u16, ((raw_addr3 >> 16) & 0xFFFF) as u16, - ((raw_addr3 >> 0) & 0xFFFF) as u16, + (raw_addr3 & 0xFFFF) as u16, ((raw_addr4 >> 16) & 0xFFFF) as u16, - ((raw_addr4 >> 0) & 0xFFFF) as u16, + (raw_addr4 & 0xFFFF) as u16, ); Ok(DnsRecord::AAAA { domain, addr, ttl }) diff --git a/src/stats.rs b/src/stats.rs index d9db0e3..3f50e85 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -30,6 +30,12 @@ impl QueryPath { } } +impl Default for ServerStats { + fn default() -> Self { + Self::new() + } +} + impl ServerStats { pub fn new() -> Self { ServerStats {