Files
wifi-densepose/crates/ruvector-fpga-transformer/src/artifact/pack.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

305 lines
10 KiB
Rust

//! Artifact packing and unpacking
use std::io::{Read, Write};
use std::path::Path;
use crate::artifact::{ModelArtifact, TestVector};
use crate::error::{Error, Result};
/// Magic bytes for artifact file format
const ARTIFACT_MAGIC: &[u8; 4] = b"RVAT"; // RuVector ArTifact
/// On-disk format version written by `pack_artifact`; `unpack_artifact` rejects any other value.
const ARTIFACT_VERSION: u16 = 1;
// Security: Maximum size limits to prevent DoS via unbounded allocations.
// `unpack_artifact` compares each length prefix it reads against these limits
// before allocating. Sizes are binary units (1 MB here means 1024 * 1024 bytes).
/// Maximum manifest size (1 MiB)
const MAX_MANIFEST_SIZE: usize = 1024 * 1024;
/// Maximum weights size (1 GiB)
const MAX_WEIGHTS_SIZE: usize = 1024 * 1024 * 1024;
/// Maximum bitstream/calibration size (100 MiB)
const MAX_BLOB_SIZE: usize = 100 * 1024 * 1024;
/// Maximum number of test vectors
const MAX_TEST_VECTORS: usize = 10_000;
/// Maximum tokens per test vector.
///
/// NOTE: the token count is stored on disk as a `u16`, so the effective ceiling is
/// 65_535; a count of exactly 65_536 cannot be encoded.
const MAX_TOKENS_PER_VECTOR: usize = 65_536;
/// Maximum expected values per test vector
const MAX_EXPECTED_PER_VECTOR: usize = 1_000_000;
/// Pack an artifact to bytes
pub fn pack_artifact(artifact: &ModelArtifact) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
// Write magic and version
buffer.extend_from_slice(ARTIFACT_MAGIC);
buffer.extend_from_slice(&ARTIFACT_VERSION.to_le_bytes());
// Write manifest as JSON with length prefix
let manifest_json = serde_json::to_string(&artifact.manifest)?;
let manifest_bytes = manifest_json.as_bytes();
buffer.extend_from_slice(&(manifest_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(manifest_bytes);
// Write weights with length prefix
buffer.extend_from_slice(&(artifact.weights.len() as u64).to_le_bytes());
buffer.extend_from_slice(&artifact.weights);
// Write optional bitstream
if let Some(ref bitstream) = artifact.bitstream {
buffer.push(1); // Present flag
buffer.extend_from_slice(&(bitstream.len() as u64).to_le_bytes());
buffer.extend_from_slice(bitstream);
} else {
buffer.push(0); // Not present
}
// Write optional calibration
if let Some(ref calibration) = artifact.calibration {
buffer.push(1);
buffer.extend_from_slice(&(calibration.len() as u64).to_le_bytes());
buffer.extend_from_slice(calibration);
} else {
buffer.push(0);
}
// Write test vectors
buffer.extend_from_slice(&(artifact.test_vectors.len() as u32).to_le_bytes());
for vector in &artifact.test_vectors {
// Write tokens
buffer.extend_from_slice(&(vector.tokens.len() as u16).to_le_bytes());
for &token in &vector.tokens {
buffer.extend_from_slice(&token.to_le_bytes());
}
// Write expected
buffer.extend_from_slice(&(vector.expected.len() as u32).to_le_bytes());
for &exp in &vector.expected {
buffer.extend_from_slice(&exp.to_le_bytes());
}
// Write max_abs_err
buffer.extend_from_slice(&vector.max_abs_err.to_le_bytes());
}
// Write signature and pubkey
buffer.extend_from_slice(&artifact.signature);
buffer.extend_from_slice(&artifact.pubkey);
Ok(buffer)
}
/// Unpack an artifact from bytes
pub fn unpack_artifact(data: &[u8]) -> Result<ModelArtifact> {
let mut cursor = std::io::Cursor::new(data);
let mut read_buf = [0u8; 8];
// Read and verify magic
cursor.read_exact(&mut read_buf[..4])?;
if &read_buf[..4] != ARTIFACT_MAGIC {
return Err(Error::InvalidArtifact("Invalid magic bytes".into()));
}
// Read version
cursor.read_exact(&mut read_buf[..2])?;
let version = u16::from_le_bytes([read_buf[0], read_buf[1]]);
if version != ARTIFACT_VERSION {
return Err(Error::InvalidArtifact(format!(
"Unsupported version: {}",
version
)));
}
// Read manifest
cursor.read_exact(&mut read_buf[..4])?;
let manifest_len = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
if manifest_len > MAX_MANIFEST_SIZE {
return Err(Error::InvalidArtifact(format!(
"Manifest size {} exceeds maximum {}",
manifest_len, MAX_MANIFEST_SIZE
)));
}
let mut manifest_bytes = vec![0u8; manifest_len];
cursor.read_exact(&mut manifest_bytes)?;
let manifest = serde_json::from_slice(&manifest_bytes)?;
// Read weights
cursor.read_exact(&mut read_buf)?;
let weights_len = u64::from_le_bytes(read_buf) as usize;
if weights_len > MAX_WEIGHTS_SIZE {
return Err(Error::InvalidArtifact(format!(
"Weights size {} exceeds maximum {}",
weights_len, MAX_WEIGHTS_SIZE
)));
}
let mut weights = vec![0u8; weights_len];
cursor.read_exact(&mut weights)?;
// Read optional bitstream
cursor.read_exact(&mut read_buf[..1])?;
let bitstream = if read_buf[0] == 1 {
cursor.read_exact(&mut read_buf)?;
let len = u64::from_le_bytes(read_buf) as usize;
if len > MAX_BLOB_SIZE {
return Err(Error::InvalidArtifact(format!(
"Bitstream size {} exceeds maximum {}",
len, MAX_BLOB_SIZE
)));
}
let mut data = vec![0u8; len];
cursor.read_exact(&mut data)?;
Some(data)
} else {
None
};
// Read optional calibration
cursor.read_exact(&mut read_buf[..1])?;
let calibration = if read_buf[0] == 1 {
cursor.read_exact(&mut read_buf)?;
let len = u64::from_le_bytes(read_buf) as usize;
if len > MAX_BLOB_SIZE {
return Err(Error::InvalidArtifact(format!(
"Calibration size {} exceeds maximum {}",
len, MAX_BLOB_SIZE
)));
}
let mut data = vec![0u8; len];
cursor.read_exact(&mut data)?;
Some(data)
} else {
None
};
// Read test vectors
cursor.read_exact(&mut read_buf[..4])?;
let num_vectors = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
if num_vectors > MAX_TEST_VECTORS {
return Err(Error::InvalidArtifact(format!(
"Test vector count {} exceeds maximum {}",
num_vectors, MAX_TEST_VECTORS
)));
}
let mut test_vectors = Vec::with_capacity(num_vectors);
for _ in 0..num_vectors {
// Read tokens
cursor.read_exact(&mut read_buf[..2])?;
let num_tokens = u16::from_le_bytes([read_buf[0], read_buf[1]]) as usize;
if num_tokens > MAX_TOKENS_PER_VECTOR {
return Err(Error::InvalidArtifact(format!(
"Token count {} exceeds maximum {}",
num_tokens, MAX_TOKENS_PER_VECTOR
)));
}
let mut tokens = Vec::with_capacity(num_tokens);
for _ in 0..num_tokens {
cursor.read_exact(&mut read_buf[..2])?;
tokens.push(u16::from_le_bytes([read_buf[0], read_buf[1]]));
}
// Read expected
cursor.read_exact(&mut read_buf[..4])?;
let num_expected = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
if num_expected > MAX_EXPECTED_PER_VECTOR {
return Err(Error::InvalidArtifact(format!(
"Expected values count {} exceeds maximum {}",
num_expected, MAX_EXPECTED_PER_VECTOR
)));
}
let mut expected = Vec::with_capacity(num_expected);
for _ in 0..num_expected {
cursor.read_exact(&mut read_buf[..2])?;
expected.push(i16::from_le_bytes([read_buf[0], read_buf[1]]));
}
// Read max_abs_err
cursor.read_exact(&mut read_buf[..4])?;
let max_abs_err = i32::from_le_bytes(read_buf[..4].try_into().unwrap());
test_vectors.push(TestVector {
tokens,
expected,
max_abs_err,
});
}
// Read signature and pubkey
let mut signature = [0u8; 64];
cursor.read_exact(&mut signature)?;
let mut pubkey = [0u8; 32];
cursor.read_exact(&mut pubkey)?;
Ok(ModelArtifact {
manifest,
weights,
bitstream,
calibration,
test_vectors,
signature,
pubkey,
})
}
/// Encode `artifact` with [`pack_artifact`] and write the bytes to `path`.
///
/// Uses `std::fs::write`, which creates the file or truncates an existing one.
///
/// # Errors
///
/// Propagates any packing error, or the I/O error from writing the file.
pub fn save_artifact(artifact: &ModelArtifact, path: impl AsRef<Path>) -> Result<()> {
    pack_artifact(artifact).and_then(|encoded| std::fs::write(path, encoded).map_err(Into::into))
}
/// Read the entire file at `path` into memory and decode it with [`unpack_artifact`].
///
/// # Errors
///
/// Propagates the I/O error from reading the file, or any decoding error.
pub fn load_artifact(path: impl AsRef<Path>) -> Result<ModelArtifact> {
    let path = path.as_ref();
    let raw = std::fs::read(path)?;
    unpack_artifact(raw.as_slice())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};

    /// Build a small artifact that exercises every section of the container:
    /// weights, a present bitstream, an absent calibration blob and one test vector.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_pack".into(),
            model_hash: "abc123".into(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: (0..5000).map(|i| (i % 256) as u8).collect(),
            bitstream: Some(vec![0xFF; 100]),
            calibration: None,
            test_vectors: vec![TestVector {
                tokens: vec![1, 2, 3],
                expected: vec![100, 200, 300],
                max_abs_err: 5,
            }],
            signature: [0x42u8; 64],
            pubkey: [0x24u8; 32],
        }
    }

    /// Encode only the fixed header (magic + `version`) so malformed-input tests can
    /// append hand-crafted fields after it.
    fn header(version: u16) -> Vec<u8> {
        let mut data = ARTIFACT_MAGIC.to_vec();
        data.extend_from_slice(&version.to_le_bytes());
        data
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        let original = create_test_artifact();
        let packed = pack_artifact(&original).unwrap();
        let unpacked = unpack_artifact(&packed).unwrap();
        assert_eq!(original.manifest.name, unpacked.manifest.name);
        assert_eq!(original.weights, unpacked.weights);
        assert_eq!(original.bitstream, unpacked.bitstream);
        assert_eq!(original.calibration, unpacked.calibration);
        assert_eq!(original.test_vectors.len(), unpacked.test_vectors.len());
        // Compare vector contents, not just the count.
        for (want, got) in original.test_vectors.iter().zip(&unpacked.test_vectors) {
            assert_eq!(want.tokens, got.tokens);
            assert_eq!(want.expected, got.expected);
            assert_eq!(want.max_abs_err, got.max_abs_err);
        }
        assert_eq!(original.signature, unpacked.signature);
        assert_eq!(original.pubkey, unpacked.pubkey);
    }

    #[test]
    fn test_invalid_magic() {
        let data = b"XXXX0000";
        assert!(unpack_artifact(data).is_err());
    }

    #[test]
    fn test_unsupported_version() {
        let data = header(ARTIFACT_VERSION + 1);
        assert!(matches!(
            unpack_artifact(&data),
            Err(Error::InvalidArtifact(_))
        ));
    }

    #[test]
    fn test_oversized_manifest_rejected() {
        let mut data = header(ARTIFACT_VERSION);
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            unpack_artifact(&data),
            Err(Error::InvalidArtifact(_))
        ));
    }

    #[test]
    fn test_truncated_input_rejected() {
        let packed = pack_artifact(&create_test_artifact()).unwrap();
        for len in 0..packed.len() {
            assert!(
                unpack_artifact(&packed[..len]).is_err(),
                "accepted a {}-byte prefix of a {}-byte artifact",
                len,
                packed.len()
            );
        }
    }
}