diff --git a/src/buffer.rs b/src/buffer.rs index c9378d0..0c358e7 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,7 +1,9 @@ use crate::Result; +const BUF_SIZE: usize = 4096; + pub struct BytePacketBuffer { - pub buf: [u8; 512], + pub buf: [u8; BUF_SIZE], pub pos: usize, } @@ -14,7 +16,7 @@ impl Default for BytePacketBuffer { impl BytePacketBuffer { pub fn new() -> BytePacketBuffer { BytePacketBuffer { - buf: [0; 512], + buf: [0; BUF_SIZE], pos: 0, } } @@ -38,7 +40,7 @@ impl BytePacketBuffer { } pub fn read(&mut self) -> Result { - if self.pos >= 512 { + if self.pos >= BUF_SIZE { return Err("End of buffer".into()); } let res = self.buf[self.pos]; @@ -47,14 +49,14 @@ impl BytePacketBuffer { } pub fn get(&self, pos: usize) -> Result { - if pos >= 512 { + if pos >= BUF_SIZE { return Err("End of buffer".into()); } Ok(self.buf[pos]) } pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8]> { - if start + len > 512 { + if start + len > BUF_SIZE { return Err("End of buffer".into()); } Ok(&self.buf[start..start + len]) @@ -128,7 +130,7 @@ impl BytePacketBuffer { } pub fn write(&mut self, val: u8) -> Result<()> { - if self.pos >= 512 { + if self.pos >= BUF_SIZE { return Err("End of buffer".into()); } self.buf[self.pos] = val; @@ -172,7 +174,7 @@ impl BytePacketBuffer { } pub fn set(&mut self, pos: usize, val: u8) -> Result<()> { - if pos >= 512 { + if pos >= BUF_SIZE { return Err("End of buffer".into()); } self.buf[pos] = val; diff --git a/src/packet.rs b/src/packet.rs index f6845aa..9158a77 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -68,24 +68,29 @@ impl DnsPacket { } pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { + // Filter out UNKNOWN records (e.g. EDNS OPT) that we can't re-serialize + let answers: Vec<_> = self.answers.iter().filter(|r| !r.is_unknown()).collect(); + let authorities: Vec<_> = self.authorities.iter().filter(|r| !r.is_unknown()).collect(); + let resources: Vec<_> = self.resources.iter().filter(|r| !r.is_unknown()).collect(); + let mut header = self.header.clone(); header.questions = self.questions.len() as u16; - header.answers = self.answers.len() as u16; - header.authoritative_entries = self.authorities.len() as u16; - header.resource_entries = self.resources.len() as u16; + header.answers = answers.len() as u16; + header.authoritative_entries = authorities.len() as u16; + header.resource_entries = resources.len() as u16; header.write(buffer)?; for question in &self.questions { question.write(buffer)?; } - for rec in &self.answers { + for rec in answers { rec.write(buffer)?; } - for rec in &self.authorities { + for rec in authorities { rec.write(buffer)?; } - for rec in &self.resources { + for rec in resources { rec.write(buffer)?; } diff --git a/src/record.rs b/src/record.rs index 8d65879..ba8e3d0 100644 --- a/src/record.rs +++ b/src/record.rs @@ -43,6 +43,10 @@ pub enum DnsRecord { } impl DnsRecord { + pub fn is_unknown(&self) -> bool { + matches!(self, DnsRecord::UNKNOWN { .. }) + } + pub fn ttl(&self) -> u32 { match self { DnsRecord::A { ttl, .. }