diff --git a/src/buffer.rs b/src/buffer.rs index 79473d8..2308813 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -131,7 +131,12 @@ impl BytePacketBuffer { b'.' => outstr.push_str("\\."), b'\\' => outstr.push_str("\\\\"), 0x21..=0x7E => outstr.push(c as char), - _ => outstr.push_str(&format!("\\{:03}", c)), + _ => { + outstr.push('\\'); + outstr.push((b'0' + c / 100) as char); + outstr.push((b'0' + (c / 10) % 10) as char); + outstr.push((b'0' + c % 10) as char); + } } } @@ -178,20 +183,70 @@ impl BytePacketBuffer { /// /// Dots are label separators unless escaped as `\.`. `\\` yields a literal /// backslash, and `\DDD` (three decimal digits) yields an arbitrary byte. + /// + /// Streams directly into the buffer by reserving a length byte, writing + /// the label body, then backpatching the length. Zero intermediate + /// allocations on the common path. pub fn write_qname(&mut self, qname: &str) -> Result<()> { if qname.is_empty() || qname == "." { self.write_u8(0)?; return Ok(()); } - let labels = parse_escaped_labels(qname)?; - for label in &labels { - if label.len() > 0x3f { - return Err("Single label exceeds 63 characters of length".into()); + let bytes = qname.as_bytes(); + let mut i = 0; + while i < bytes.len() { + let len_pos = self.pos; + self.write_u8(0)?; // placeholder length byte, backpatched below + let body_start = self.pos; + + while i < bytes.len() && bytes[i] != b'.' { + let b = bytes[i]; + if b == b'\\' { + i += 1; + let c1 = *bytes.get(i).ok_or("trailing backslash in qname")?; + if c1.is_ascii_digit() { + let c2 = *bytes + .get(i + 1) + .ok_or("invalid \\DDD escape: expected 3 digits")?; + let c3 = *bytes + .get(i + 2) + .ok_or("invalid \\DDD escape: expected 3 digits")?; + if !c2.is_ascii_digit() || !c3.is_ascii_digit() { + return Err("invalid \\DDD escape: expected 3 digits".into()); + } + let val = + (c1 - b'0') as u16 * 100 + (c2 - b'0') as u16 * 10 + (c3 - b'0') as u16; + if val > 255 { + return Err(format!("\\DDD escape out of range: {}", val).into()); + } + self.write_u8(val as u8)?; + i += 3; + } else { + // \. \\ and any other \X → literal next byte + self.write_u8(c1)?; + i += 1; + } + } else { + self.write_u8(b)?; + i += 1; + } + + if self.pos - body_start > 0x3f { + return Err("Single label exceeds 63 characters of length".into()); + } } - self.write_u8(label.len() as u8)?; - for b in label { - self.write_u8(*b)?; + + let label_len = self.pos - body_start; + if label_len == 0 && i < bytes.len() { + // Empty label from leading/consecutive dots — roll back the placeholder. + self.pos = len_pos; + } else { + self.set(len_pos, label_len as u8)?; + } + + if i < bytes.len() && bytes[i] == b'.' { + i += 1; } } @@ -224,50 +279,6 @@ impl BytePacketBuffer { } } -fn parse_escaped_labels(qname: &str) -> Result>> { - let mut labels: Vec> = Vec::new(); - let mut current: Vec = Vec::new(); - let mut chars = qname.chars(); - - while let Some(c) = chars.next() { - if c == '\\' { - match chars.next() { - Some(d1) if d1.is_ascii_digit() => { - let d2 = chars - .next() - .and_then(|c| c.to_digit(10)) - .ok_or("invalid \\DDD escape: expected 3 digits")?; - let d3 = chars - .next() - .and_then(|c| c.to_digit(10)) - .ok_or("invalid \\DDD escape: expected 3 digits")?; - let val = d1.to_digit(10).unwrap() * 100 + d2 * 10 + d3; - if val > 255 { - return Err(format!("\\DDD escape out of range: {}", val).into()); - } - current.push(val as u8); - } - Some(other) => { - let mut buf = [0u8; 4]; - current.extend_from_slice(other.encode_utf8(&mut buf).as_bytes()); - } - None => return Err("trailing backslash in qname".into()), - } - } else if c == '.' { - if !current.is_empty() { - labels.push(std::mem::take(&mut current)); - } - } else { - let mut buf = [0u8; 4]; - current.extend_from_slice(c.encode_utf8(&mut buf).as_bytes()); - } - } - if !current.is_empty() { - labels.push(current); - } - Ok(labels) -} - #[cfg(test)] mod tests { use super::*;