Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,266 @@
//! Manifest schema for model artifacts
use crate::error::{Error, Result};
use crate::types::{FixedShape, Layout, QuantSpec};
use serde::{Deserialize, Serialize};
/// Model manifest containing all metadata
///
/// Serialized as JSON inside packed artifacts, so field names are part of the
/// on-disk format and must not be renamed casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
    /// Model name (must be non-empty; enforced by `validate`)
    pub name: String,
    /// SHA-256 hash of model (hex string; may be empty until computed)
    pub model_hash: String,
    /// Fixed shape specification
    pub shape: FixedShape,
    /// Quantization specification
    pub quant: QuantSpec,
    /// I/O configuration
    pub io: IoSpec,
    /// Backend configuration
    pub backend: BackendSpec,
    /// Test vector specification
    pub tests: TestSpec,
}
impl Manifest {
/// Create a new manifest
pub fn new(name: impl Into<String>, shape: FixedShape, quant: QuantSpec) -> Self {
Self {
name: name.into(),
model_hash: String::new(),
shape,
quant,
io: IoSpec::default(),
backend: BackendSpec::default(),
tests: TestSpec::default(),
}
}
/// Validate manifest consistency
pub fn validate(&self) -> Result<()> {
if self.name.is_empty() {
return Err(Error::InvalidArtifact("Model name is empty".into()));
}
// Validate shape
self.shape
.validate()
.map_err(|e| Error::InvalidArtifact(e))?;
// Validate quantization bits
if !matches!(self.quant.w_bits, 1 | 2 | 4 | 8 | 16) {
return Err(Error::InvalidArtifact(format!(
"Invalid weight bits: {}",
self.quant.w_bits
)));
}
if !matches!(self.quant.a_bits, 4 | 8 | 16 | 32) {
return Err(Error::InvalidArtifact(format!(
"Invalid activation bits: {}",
self.quant.a_bits
)));
}
Ok(())
}
/// Convert to JSON string
pub fn to_json(&self) -> Result<String> {
serde_json::to_string_pretty(self).map_err(Into::into)
}
/// Parse from JSON string
pub fn from_json(json: &str) -> Result<Self> {
serde_json::from_str(json).map_err(Into::into)
}
}
/// I/O type specification
///
/// Describes the wire types the backend exchanges, as type-name strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IoSpec {
    /// Input token type name (typically "u16")
    pub tokens: String,
    /// Output logit type name (typically "i16" or "i32")
    pub logits: String,
    /// Top-K count (0 for full logits)
    pub topk: u16,
}
impl Default for IoSpec {
fn default() -> Self {
Self {
tokens: "u16".into(),
logits: "i16".into(),
topk: 16,
}
}
}
/// Backend-specific configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendSpec {
    /// Backend kind ("fpga_pcie", "fpga_daemon", "native_sim", "wasm_sim")
    // NOTE(review): stringly-typed; an enum would prevent typos — but the
    // string is part of the serialized format, so changing it is a breaking change.
    pub kind: String,
    /// Protocol version
    pub protocol: u16,
    /// Backend-specific options (falls back to defaults when absent in JSON)
    #[serde(default)]
    pub options: BackendOptions,
}
impl Default for BackendSpec {
fn default() -> Self {
Self {
kind: "native_sim".into(),
protocol: 1,
options: BackendOptions::default(),
}
}
}
/// Backend-specific options
///
/// All fields use `#[serde(default)]` so older manifests without these keys
/// still deserialize (booleans default to false, integers to 0).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendOptions {
    /// Enable batch processing
    #[serde(default)]
    pub batch_enabled: bool,
    /// Maximum batch size
    #[serde(default)]
    pub max_batch: u16,
    /// Enable early exit
    #[serde(default)]
    pub early_exit: bool,
    /// Minimum coherence threshold for early exit (fixed-point i16)
    #[serde(default)]
    pub early_exit_threshold: i16,
    /// FPGA clock frequency in MHz (for cycle estimation; 0 = unknown)
    #[serde(default)]
    pub clock_mhz: u16,
}
/// Test vector specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestSpec {
    /// Number of test vectors
    pub vectors: u32,
    /// Maximum absolute error allowed when comparing logits
    pub max_abs_err: i32,
    /// Whether test vectors must pass before activation
    /// (defaults to `true` when absent from serialized JSON)
    #[serde(default = "default_true")]
    pub require_pass: bool,
}
/// Serde default helper: makes `TestSpec::require_pass` default to `true`
/// when the field is missing (serde's plain `default` would give `false`).
fn default_true() -> bool {
    true
}
impl Default for TestSpec {
fn default() -> Self {
Self {
vectors: 0,
max_abs_err: 2,
require_pass: true,
}
}
}
/// Manifest builder for convenient construction
///
/// Consuming builder: each setter takes `self` by value and returns the
/// updated builder; `build` validates and yields the final `Manifest`.
pub struct ManifestBuilder {
    // Manifest being assembled (starts from `Manifest::new` defaults)
    manifest: Manifest,
}
impl ManifestBuilder {
/// Create a new builder with name and shape
pub fn new(name: impl Into<String>, shape: FixedShape) -> Self {
Self {
manifest: Manifest::new(name, shape, QuantSpec::int8()),
}
}
/// Set quantization spec
pub fn quant(mut self, quant: QuantSpec) -> Self {
self.manifest.quant = quant;
self
}
/// Set model hash
pub fn model_hash(mut self, hash: impl Into<String>) -> Self {
self.manifest.model_hash = hash.into();
self
}
/// Set I/O spec
pub fn io(mut self, io: IoSpec) -> Self {
self.manifest.io = io;
self
}
/// Set backend spec
pub fn backend(mut self, backend: BackendSpec) -> Self {
self.manifest.backend = backend;
self
}
/// Set test spec
pub fn tests(mut self, tests: TestSpec) -> Self {
self.manifest.tests = tests;
self
}
/// Enable top-K only output
pub fn topk_only(mut self, k: u16) -> Self {
self.manifest.io.topk = k;
self
}
/// Enable early exit
pub fn early_exit(mut self, threshold: i16) -> Self {
self.manifest.backend.options.early_exit = true;
self.manifest.backend.options.early_exit_threshold = threshold;
self
}
/// Build the manifest
pub fn build(self) -> Result<Manifest> {
self.manifest.validate()?;
Ok(self.manifest)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder should apply quant, top-K, and early-exit settings and pass validation.
    #[test]
    fn test_manifest_builder() {
        let manifest = ManifestBuilder::new("test", FixedShape::micro())
            .quant(QuantSpec::int4_int8())
            .topk_only(16)
            .early_exit(100)
            .build()
            .unwrap();
        assert_eq!(manifest.name, "test");
        assert_eq!(manifest.quant.w_bits, 4);
        assert_eq!(manifest.io.topk, 16);
        assert!(manifest.backend.options.early_exit);
    }

    /// JSON serialization should round-trip without losing data.
    #[test]
    fn test_manifest_json_roundtrip() {
        let manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        let json = manifest.to_json().unwrap();
        let parsed = Manifest::from_json(&json).unwrap();
        assert_eq!(manifest.name, parsed.name);
    }

    /// An empty name must fail validation.
    #[test]
    fn test_manifest_validation() {
        let mut manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        assert!(manifest.validate().is_ok());
        manifest.name = String::new();
        assert!(manifest.validate().is_err());
    }
}

View File

@@ -0,0 +1,242 @@
//! Model artifact format and handling
//!
//! Signed bundles with metadata, weights, and test vectors.
pub mod manifest;
pub mod pack;
pub mod verify;
pub use manifest::{BackendSpec, IoSpec, Manifest, TestSpec};
pub use pack::{pack_artifact, unpack_artifact};
pub use verify::{verify_artifact, verify_signature};
use crate::error::{Error, Result};
use crate::types::{FixedShape, ModelId, QuantSpec};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Complete model artifact
///
/// The unit that is packed, signed, shipped, and loaded by backends.
/// Binary fields use `serde_bytes` for an efficient byte-string encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArtifact {
    /// Manifest with metadata
    pub manifest: Manifest,
    /// Quantized weights (binary blob)
    #[serde(with = "serde_bytes")]
    pub weights: Vec<u8>,
    /// Optional FPGA bitstream
    #[serde(with = "serde_bytes_option")]
    pub bitstream: Option<Vec<u8>>,
    /// Optional calibration data
    #[serde(with = "serde_bytes_option")]
    pub calibration: Option<Vec<u8>>,
    /// Test vectors for validation
    pub test_vectors: Vec<TestVector>,
    /// Ed25519 signature over manifest + file hashes (all zeros when unsigned)
    #[serde(with = "serde_bytes")]
    pub signature: [u8; 64],
    /// Ed25519 public key (all zeros when unsigned; `new` zeroes both)
    #[serde(with = "serde_bytes")]
    pub pubkey: [u8; 32],
}
/// Serde helper for `Option<Vec<u8>>` fields, routing present values through
/// `serde_bytes` so they serialize as compact byte strings rather than
/// integer sequences.
mod serde_bytes_option {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serialize `Some(bytes)` as a byte string, `None` as a serde "none".
    pub fn serialize<S: Serializer>(data: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
        if let Some(bytes) = data {
            s.serialize_some(&serde_bytes::Bytes::new(bytes))
        } else {
            s.serialize_none()
        }
    }

    /// Deserialize an optional byte string back into `Option<Vec<u8>>`.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Vec<u8>>, D::Error> {
        Option::<serde_bytes::ByteBuf>::deserialize(d)
            .map(|opt| opt.map(serde_bytes::ByteBuf::into_vec))
    }
}
/// Test vector for model validation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestVector {
    /// Input tokens
    pub tokens: Vec<u16>,
    /// Expected output logits (top-K or full)
    pub expected: Vec<i16>,
    /// Maximum absolute error allowed
    // NOTE(review): `verify::verify_test_vectors` compares against the
    // manifest-level `tests.max_abs_err`, not this per-vector value —
    // confirm which threshold is intended.
    pub max_abs_err: i32,
}
impl ModelArtifact {
    /// Create a new artifact (for building/packing).
    ///
    /// Signature and public key start zeroed; a signing step must populate
    /// them before the artifact counts as signed.
    pub fn new(
        manifest: Manifest,
        weights: Vec<u8>,
        bitstream: Option<Vec<u8>>,
        calibration: Option<Vec<u8>>,
        test_vectors: Vec<TestVector>,
    ) -> Self {
        Self {
            manifest,
            weights,
            bitstream,
            calibration,
            test_vectors,
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }

    /// Compute model ID (SHA-256 over name + weights hash + quant hash).
    ///
    /// Deterministic for a given artifact, so it is usable as a cache key.
    pub fn model_id(&self) -> ModelId {
        let mut hasher = Sha256::new();
        hasher.update(self.manifest.name.as_bytes());
        hasher.update(&self.model_hash());
        hasher.update(&self.quant_hash());
        ModelId::new(hasher.finalize().into())
    }

    /// Compute hash of model weights (and bitstream, when present).
    pub fn model_hash(&self) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(&self.weights);
        if let Some(ref bitstream) = self.bitstream {
            hasher.update(bitstream);
        }
        hasher.finalize().into()
    }

    /// Compute hash of quantization parameters (and calibration data, when present).
    pub fn quant_hash(&self) -> [u8; 32] {
        let mut hasher = Sha256::new();
        let quant_json = serde_json::to_string(&self.manifest.quant).unwrap_or_default();
        hasher.update(quant_json.as_bytes());
        if let Some(ref calib) = self.calibration {
            hasher.update(calib);
        }
        hasher.finalize().into()
    }

    /// Validate artifact integrity.
    ///
    /// # Errors
    /// Returns `Error::InvalidArtifact` when the manifest is inconsistent or
    /// the weights blob is too small to hold the embedding table.
    pub fn validate(&self) -> Result<()> {
        // Manifest validation already includes the shape check
        // (see `Manifest::validate`), so the shape is not re-validated here.
        self.manifest.validate()?;
        // Check weights size covers at least the embedding table.
        // NOTE(review): assumes `weights_per_byte()` is never zero for the
        // supported bit widths — confirm QuantSpec guarantees this.
        let min_weight_size =
            self.manifest.shape.embedding_params() / self.manifest.quant.weights_per_byte();
        if self.weights.len() < min_weight_size {
            return Err(Error::InvalidArtifact(format!(
                "Weights too small: {} bytes, expected at least {} for embeddings",
                self.weights.len(),
                min_weight_size
            )));
        }
        // Validate test vectors if strict mode
        #[cfg(feature = "strict_verify")]
        self.run_test_vectors()?;
        Ok(())
    }

    /// Run test vectors for validation.
    #[cfg(feature = "strict_verify")]
    pub fn run_test_vectors(&self) -> Result<()> {
        // This would require running inference, which creates a circular dependency.
        // In practice, this is done by the backend after loading.
        Ok(())
    }

    /// Get the fixed shape.
    pub fn shape(&self) -> &FixedShape {
        &self.manifest.shape
    }

    /// Get the quantization spec.
    pub fn quant(&self) -> &QuantSpec {
        &self.manifest.quant
    }

    /// Check if the artifact carries an FPGA bitstream.
    pub fn has_bitstream(&self) -> bool {
        self.bitstream.is_some()
    }

    /// Estimated memory footprint in bytes (weights plus optional blobs).
    pub fn memory_footprint(&self) -> usize {
        self.weights.len()
            + self.bitstream.as_ref().map(|b| b.len()).unwrap_or(0)
            + self.calibration.as_ref().map(|c| c.len()).unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: minimal valid manifest for the "micro" shape.
    fn create_test_manifest() -> Manifest {
        Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: IoSpec::default(),
            backend: BackendSpec::default(),
            tests: TestSpec::default(),
        }
    }

    /// `model_id` must be deterministic for the same artifact.
    #[test]
    fn test_model_id_computation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![]);
        let id1 = artifact.model_id();
        let id2 = artifact.model_id();
        assert_eq!(id1, id2); // Deterministic
    }

    /// Non-trivial weights should produce a non-zero hash.
    #[test]
    fn test_model_hash() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![42u8; 4096 * 64], None, None, vec![]);
        let hash = artifact.model_hash();
        assert_ne!(hash, [0u8; 32]); // Non-zero hash
    }

    /// Enough weight bytes for the micro embedding table passes validation.
    #[test]
    fn test_artifact_validation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 4096 * 64], // Enough for micro embeddings
            None,
            None,
            vec![],
        );
        assert!(artifact.validate().is_ok());
    }

    /// An undersized weights blob must fail validation.
    #[test]
    fn test_artifact_too_small_weights() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 100], // Too small
            None,
            None,
            vec![],
        );
        assert!(artifact.validate().is_err());
    }
}

View File

@@ -0,0 +1,304 @@
//! Artifact packing and unpacking
use std::io::{Read, Write};
use std::path::Path;
use crate::artifact::{ModelArtifact, TestVector};
use crate::error::{Error, Result};
/// Magic bytes for artifact file format
const ARTIFACT_MAGIC: &[u8; 4] = b"RVAT"; // RuVector ArTifact
/// On-disk format version; `unpack_artifact` rejects any other value
const ARTIFACT_VERSION: u16 = 1;
// Security: Maximum size limits to prevent DoS via unbounded allocations
// when unpacking untrusted artifact bytes.
/// Maximum manifest size (1 MB)
const MAX_MANIFEST_SIZE: usize = 1024 * 1024;
/// Maximum weights size (1 GB)
const MAX_WEIGHTS_SIZE: usize = 1024 * 1024 * 1024;
/// Maximum bitstream/calibration size (100 MB)
const MAX_BLOB_SIZE: usize = 100 * 1024 * 1024;
/// Maximum number of test vectors
const MAX_TEST_VECTORS: usize = 10_000;
/// Maximum tokens per test vector
// NOTE(review): token counts are encoded as u16 on the wire, so the
// effective maximum on read is 65_535 and this bound can never trigger
// in `unpack_artifact`; kept for documentation/symmetry.
const MAX_TOKENS_PER_VECTOR: usize = 65_536;
/// Maximum expected values per test vector
const MAX_EXPECTED_PER_VECTOR: usize = 1_000_000;
/// Pack an artifact to bytes
pub fn pack_artifact(artifact: &ModelArtifact) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
// Write magic and version
buffer.extend_from_slice(ARTIFACT_MAGIC);
buffer.extend_from_slice(&ARTIFACT_VERSION.to_le_bytes());
// Write manifest as JSON with length prefix
let manifest_json = serde_json::to_string(&artifact.manifest)?;
let manifest_bytes = manifest_json.as_bytes();
buffer.extend_from_slice(&(manifest_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(manifest_bytes);
// Write weights with length prefix
buffer.extend_from_slice(&(artifact.weights.len() as u64).to_le_bytes());
buffer.extend_from_slice(&artifact.weights);
// Write optional bitstream
if let Some(ref bitstream) = artifact.bitstream {
buffer.push(1); // Present flag
buffer.extend_from_slice(&(bitstream.len() as u64).to_le_bytes());
buffer.extend_from_slice(bitstream);
} else {
buffer.push(0); // Not present
}
// Write optional calibration
if let Some(ref calibration) = artifact.calibration {
buffer.push(1);
buffer.extend_from_slice(&(calibration.len() as u64).to_le_bytes());
buffer.extend_from_slice(calibration);
} else {
buffer.push(0);
}
// Write test vectors
buffer.extend_from_slice(&(artifact.test_vectors.len() as u32).to_le_bytes());
for vector in &artifact.test_vectors {
// Write tokens
buffer.extend_from_slice(&(vector.tokens.len() as u16).to_le_bytes());
for &token in &vector.tokens {
buffer.extend_from_slice(&token.to_le_bytes());
}
// Write expected
buffer.extend_from_slice(&(vector.expected.len() as u32).to_le_bytes());
for &exp in &vector.expected {
buffer.extend_from_slice(&exp.to_le_bytes());
}
// Write max_abs_err
buffer.extend_from_slice(&vector.max_abs_err.to_le_bytes());
}
// Write signature and pubkey
buffer.extend_from_slice(&artifact.signature);
buffer.extend_from_slice(&artifact.pubkey);
Ok(buffer)
}
/// Unpack an artifact from bytes.
///
/// Mirrors the layout written by `pack_artifact`. Every length field is
/// checked against a `MAX_*` bound *before* the corresponding allocation, so
/// a corrupt or hostile header cannot trigger unbounded memory use.
///
/// # Errors
/// `Error::InvalidArtifact` for bad magic/version or any out-of-bounds
/// length; I/O errors if the buffer is shorter than the declared lengths.
pub fn unpack_artifact(data: &[u8]) -> Result<ModelArtifact> {
    let mut cursor = std::io::Cursor::new(data);
    // Scratch buffer sized for the widest fixed-size field (u64).
    let mut read_buf = [0u8; 8];
    // Read and verify magic
    cursor.read_exact(&mut read_buf[..4])?;
    if &read_buf[..4] != ARTIFACT_MAGIC {
        return Err(Error::InvalidArtifact("Invalid magic bytes".into()));
    }
    // Read version (only the current version is accepted)
    cursor.read_exact(&mut read_buf[..2])?;
    let version = u16::from_le_bytes([read_buf[0], read_buf[1]]);
    if version != ARTIFACT_VERSION {
        return Err(Error::InvalidArtifact(format!(
            "Unsupported version: {}",
            version
        )));
    }
    // Read manifest: u32 length prefix + JSON payload
    cursor.read_exact(&mut read_buf[..4])?;
    let manifest_len = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if manifest_len > MAX_MANIFEST_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Manifest size {} exceeds maximum {}",
            manifest_len, MAX_MANIFEST_SIZE
        )));
    }
    let mut manifest_bytes = vec![0u8; manifest_len];
    cursor.read_exact(&mut manifest_bytes)?;
    let manifest = serde_json::from_slice(&manifest_bytes)?;
    // Read weights: u64 length prefix + raw bytes
    cursor.read_exact(&mut read_buf)?;
    let weights_len = u64::from_le_bytes(read_buf) as usize;
    if weights_len > MAX_WEIGHTS_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Weights size {} exceeds maximum {}",
            weights_len, MAX_WEIGHTS_SIZE
        )));
    }
    let mut weights = vec![0u8; weights_len];
    cursor.read_exact(&mut weights)?;
    // Read optional bitstream: 1-byte presence flag, then u64 length + data
    cursor.read_exact(&mut read_buf[..1])?;
    let bitstream = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Bitstream size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };
    // Read optional calibration (same framing as the bitstream)
    cursor.read_exact(&mut read_buf[..1])?;
    let calibration = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Calibration size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };
    // Read test vectors: u32 count, then per-vector payloads
    cursor.read_exact(&mut read_buf[..4])?;
    let num_vectors = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if num_vectors > MAX_TEST_VECTORS {
        return Err(Error::InvalidArtifact(format!(
            "Test vector count {} exceeds maximum {}",
            num_vectors, MAX_TEST_VECTORS
        )));
    }
    // Capacity is bounded by MAX_TEST_VECTORS (checked above)
    let mut test_vectors = Vec::with_capacity(num_vectors);
    for _ in 0..num_vectors {
        // Read tokens: u16 count + u16 values
        cursor.read_exact(&mut read_buf[..2])?;
        let num_tokens = u16::from_le_bytes([read_buf[0], read_buf[1]]) as usize;
        // NOTE(review): num_tokens comes from a u16 (max 65_535), so this
        // check against 65_536 can never trigger; kept for symmetry.
        if num_tokens > MAX_TOKENS_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Token count {} exceeds maximum {}",
                num_tokens, MAX_TOKENS_PER_VECTOR
            )));
        }
        let mut tokens = Vec::with_capacity(num_tokens);
        for _ in 0..num_tokens {
            cursor.read_exact(&mut read_buf[..2])?;
            tokens.push(u16::from_le_bytes([read_buf[0], read_buf[1]]));
        }
        // Read expected logits: u32 count + i16 values
        cursor.read_exact(&mut read_buf[..4])?;
        let num_expected = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
        if num_expected > MAX_EXPECTED_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Expected values count {} exceeds maximum {}",
                num_expected, MAX_EXPECTED_PER_VECTOR
            )));
        }
        let mut expected = Vec::with_capacity(num_expected);
        for _ in 0..num_expected {
            cursor.read_exact(&mut read_buf[..2])?;
            expected.push(i16::from_le_bytes([read_buf[0], read_buf[1]]));
        }
        // Read the per-vector error tolerance
        cursor.read_exact(&mut read_buf[..4])?;
        let max_abs_err = i32::from_le_bytes(read_buf[..4].try_into().unwrap());
        test_vectors.push(TestVector {
            tokens,
            expected,
            max_abs_err,
        });
    }
    // Read signature and pubkey (fixed-size trailers; not verified here —
    // see `verify::verify_signature`)
    let mut signature = [0u8; 64];
    cursor.read_exact(&mut signature)?;
    let mut pubkey = [0u8; 32];
    cursor.read_exact(&mut pubkey)?;
    Ok(ModelArtifact {
        manifest,
        weights,
        bitstream,
        calibration,
        test_vectors,
        signature,
        pubkey,
    })
}
/// Save artifact to a file, packing it into the on-disk wire format first.
pub fn save_artifact(artifact: &ModelArtifact, path: impl AsRef<Path>) -> Result<()> {
    std::fs::write(path, pack_artifact(artifact)?).map_err(Into::into)
}
/// Load artifact from file
pub fn load_artifact(path: impl AsRef<Path>) -> Result<ModelArtifact> {
let data = std::fs::read(path)?;
unpack_artifact(&data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};

    /// Helper: small artifact exercising every optional section
    /// (bitstream present, calibration absent, one test vector).
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_pack".into(),
            model_hash: "abc123".into(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: (0..5000).map(|i| (i % 256) as u8).collect(),
            bitstream: Some(vec![0xFF; 100]),
            calibration: None,
            test_vectors: vec![TestVector {
                tokens: vec![1, 2, 3],
                expected: vec![100, 200, 300],
                max_abs_err: 5,
            }],
            signature: [0x42u8; 64],
            pubkey: [0x24u8; 32],
        }
    }

    /// Every field should survive a pack/unpack round trip.
    #[test]
    fn test_pack_unpack_roundtrip() {
        let original = create_test_artifact();
        let packed = pack_artifact(&original).unwrap();
        let unpacked = unpack_artifact(&packed).unwrap();
        assert_eq!(original.manifest.name, unpacked.manifest.name);
        assert_eq!(original.weights, unpacked.weights);
        assert_eq!(original.bitstream, unpacked.bitstream);
        assert_eq!(original.calibration, unpacked.calibration);
        assert_eq!(original.test_vectors.len(), unpacked.test_vectors.len());
        assert_eq!(original.signature, unpacked.signature);
        assert_eq!(original.pubkey, unpacked.pubkey);
    }

    /// Wrong magic bytes must be rejected immediately.
    #[test]
    fn test_invalid_magic() {
        let data = b"XXXX0000";
        assert!(unpack_artifact(data).is_err());
    }
}

View File

@@ -0,0 +1,203 @@
//! Artifact verification and signature validation
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use sha2::{Digest, Sha256};
use crate::artifact::ModelArtifact;
use crate::error::{Error, Result};
/// Verify artifact signature
///
/// Returns `Ok(true)` when the Ed25519 signature checks out against the
/// message produced by `compute_signing_message`. Note: this never returns
/// `Ok(false)` — every failure surfaces as `Err(Error::SignatureError)`.
pub fn verify_signature(artifact: &ModelArtifact) -> Result<bool> {
    // Compute the message to verify (manifest hash + file hashes)
    let message = compute_signing_message(artifact);
    // Load public key (fails on malformed key bytes)
    let pubkey = VerifyingKey::from_bytes(&artifact.pubkey)
        .map_err(|e| Error::SignatureError(format!("Invalid public key: {}", e)))?;
    // Load signature (`from_bytes` on a fixed 64-byte array is infallible)
    let signature = Signature::from_bytes(&artifact.signature);
    // Verify
    pubkey
        .verify(&message, &signature)
        .map(|_| true)
        .map_err(|e| Error::SignatureError(format!("Verification failed: {}", e)))
}
/// Verify complete artifact integrity.
///
/// Steps, in order: manifest validation; weights-hash comparison against the
/// manifest (skipped when the manifest hash is empty); signature verification
/// (skipped when the public key is all zeros, i.e. the artifact is unsigned);
/// and a minimum weights-size check.
///
/// # Errors
/// `Error::InvalidArtifact` on any mismatch, or a signature error from
/// `verify_signature`.
pub fn verify_artifact(artifact: &ModelArtifact) -> Result<()> {
    // 1. Validate manifest
    artifact.manifest.validate()?;
    // 2. Verify model hash matches manifest (only when the manifest carries one)
    let computed_hash = hex::encode(artifact.model_hash());
    if !artifact.manifest.model_hash.is_empty() && computed_hash != artifact.manifest.model_hash {
        return Err(Error::InvalidArtifact(format!(
            "Model hash mismatch: expected {}, got {}",
            artifact.manifest.model_hash, computed_hash
        )));
    }
    // 3. Verify signature (if present). An all-zero pubkey marks an unsigned
    // artifact, so unsigned artifacts pass without a signature check.
    if artifact.pubkey != [0u8; 32] {
        verify_signature(artifact)?;
    }
    // 4. Verify weights are at least large enough for the embedding table
    let expected_min =
        artifact.manifest.shape.embedding_params() / artifact.manifest.quant.weights_per_byte();
    if artifact.weights.len() < expected_min {
        return Err(Error::InvalidArtifact(format!(
            "Weights too small: {} < {}",
            artifact.weights.len(),
            expected_min
        )));
    }
    Ok(())
}
/// Compute the message that was signed.
///
/// A single SHA-256 digest over: the manifest JSON, the weights hash, the
/// quant-params hash, and — when present — SHA-256 digests of the bitstream
/// and calibration blobs.
fn compute_signing_message(artifact: &ModelArtifact) -> Vec<u8> {
    // Local helper: SHA-256 of a raw blob.
    fn blob_digest(bytes: &[u8]) -> [u8; 32] {
        let mut h = Sha256::new();
        h.update(bytes);
        h.finalize().into()
    }

    let mut hasher = Sha256::new();
    // Manifest JSON (empty string if serialization somehow fails)
    let manifest_json = serde_json::to_string(&artifact.manifest).unwrap_or_default();
    hasher.update(manifest_json.as_bytes());
    // Weights hash, then quant-params hash
    hasher.update(&artifact.model_hash());
    hasher.update(&artifact.quant_hash());
    // Optional blob digests
    if let Some(bitstream) = artifact.bitstream.as_deref() {
        hasher.update(&blob_digest(bitstream));
    }
    if let Some(calib) = artifact.calibration.as_deref() {
        hasher.update(&blob_digest(calib));
    }
    hasher.finalize().to_vec()
}
/// Sign an artifact with an Ed25519 private key.
///
/// Computes the signing message (see `compute_signing_message`), signs it,
/// and stores both the signature and the derived public key on the artifact.
/// Only available with the `sign` feature.
#[cfg(feature = "sign")]
pub fn sign_artifact(artifact: &mut ModelArtifact, secret_key: &[u8; 32]) -> Result<()> {
    use ed25519_dalek::{Signer, SigningKey};
    let signing_key = SigningKey::from_bytes(secret_key);
    let message = compute_signing_message(artifact);
    let signature = signing_key.sign(&message);
    artifact.signature = signature.to_bytes();
    // The public key is derived from the signing key so the pair always matches
    artifact.pubkey = signing_key.verifying_key().to_bytes();
    Ok(())
}
/// Verify test vectors against model output.
///
/// Runs `infer_fn` on each vector's tokens and checks the maximum absolute
/// error against the manifest-level threshold.
///
/// # Errors
/// Propagates `infer_fn` errors; returns `Error::TestVectorError` when any
/// vector exceeds the allowed error.
pub fn verify_test_vectors(
    artifact: &ModelArtifact,
    infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
) -> Result<()> {
    // NOTE(review): the manifest-level threshold is used; each vector's own
    // `max_abs_err` field is ignored — confirm which is intended.
    let max_err = artifact.manifest.tests.max_abs_err;
    for vector in &artifact.test_vectors {
        let output = infer_fn(&vector.tokens)?;
        // NOTE(review): `zip` compares only the overlapping prefix; if the
        // backend returns fewer values than expected, the missing tail is
        // silently skipped — confirm output length is validated upstream.
        let actual_max_err = output
            .iter()
            .zip(&vector.expected)
            .map(|(&a, &b)| (a as i32 - b as i32).abs())
            .max()
            .unwrap_or(0);
        if actual_max_err > max_err {
            return Err(Error::TestVectorError {
                expected: max_err,
                actual: actual_max_err,
            });
        }
    }
    Ok(())
}
/// Generate test vectors for an artifact.
///
/// Replaces any existing vectors with `count` freshly generated ones: random
/// token sequences of the manifest's `seq_len`, paired with the outputs of
/// `infer_fn`, and updates `manifest.tests.vectors` to match.
///
/// # Errors
/// Propagates any error returned by `infer_fn`.
pub fn generate_test_vectors(
    artifact: &mut ModelArtifact,
    infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
    count: usize,
) -> Result<()> {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let seq_len = artifact.manifest.shape.seq_len as usize;
    // NOTE(review): `vocab` is narrowed to u16 here — a vocabulary larger
    // than 65_535 would silently truncate; confirm shape limits. Also,
    // `gen_range(0..vocab)` panics when vocab == 0.
    let vocab = artifact.manifest.shape.vocab as u16;
    artifact.test_vectors.clear();
    for _ in 0..count {
        // Generate random input
        let tokens: Vec<u16> = (0..seq_len).map(|_| rng.gen_range(0..vocab)).collect();
        // Run inference to capture the reference output
        let expected = infer_fn(&tokens)?;
        artifact.test_vectors.push(crate::artifact::TestVector {
            tokens,
            expected,
            max_abs_err: artifact.manifest.tests.max_abs_err,
        });
    }
    // Keep the manifest's advertised vector count in sync
    artifact.manifest.tests.vectors = count as u32;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};

    /// Helper: unsigned artifact with enough zeroed weights to pass size checks.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test".into(),
            model_hash: String::new(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![])
    }

    /// Unsigned artifact (zero pubkey) should pass: signature check is skipped.
    #[test]
    fn test_verify_artifact() {
        let artifact = create_test_artifact();
        assert!(verify_artifact(&artifact).is_ok());
    }

    /// The signing message is a single SHA-256 digest.
    #[test]
    fn test_compute_signing_message() {
        let artifact = create_test_artifact();
        let msg = compute_signing_message(&artifact);
        assert_eq!(msg.len(), 32); // SHA-256 output
    }
}

View File

@@ -0,0 +1,566 @@
//! FPGA Daemon backend
//!
//! Communicates with a local daemon over Unix socket or TCP
//! to send inference requests to an FPGA accelerator.
use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::Path;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::artifact::ModelArtifact;
use crate::backend::{
commands, compute_topk, crc32, protocol, read_lock, validate_tokens, write_lock, BackendStats,
RequestFrame, ResponseFrame, TransformerBackend,
};
use crate::error::{Error, Result};
use crate::types::{
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
};
/// Connection type for daemon communication
#[derive(Debug, Clone)]
pub enum DaemonConnection {
    /// Unix domain socket path (e.g. `/var/run/ruvector_fpga.sock`)
    Unix(String),
    /// TCP address in `host:port` form (parsed at connect time)
    Tcp(String),
}
impl DaemonConnection {
/// Create a Unix socket connection
pub fn unix(path: impl Into<String>) -> Self {
Self::Unix(path.into())
}
/// Create a TCP connection
pub fn tcp(addr: impl Into<String>) -> Self {
Self::Tcp(addr.into())
}
/// Default socket path
pub fn default_socket() -> Self {
Self::Unix("/var/run/ruvector_fpga.sock".into())
}
}
/// FPGA Daemon backend
///
/// The `RwLock`s guard the model cache and statistics for shared use across
/// threads.
pub struct FpgaDaemonBackend {
    /// Connection configuration
    connection: DaemonConnection,
    /// Loaded models (cached metadata, keyed by model ID)
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: DaemonConfig,
}
/// Cached model metadata
struct ModelMetadata {
    // Full artifact as handed to the daemon
    artifact: ModelArtifact,
    // Timestamp recorded for the cache entry
    // NOTE(review): where this is set/read is outside this chunk — confirm usage.
    loaded_at: Instant,
}
/// Configuration for daemon backend
#[derive(Debug, Clone)]
pub struct DaemonConfig {
    /// Connection timeout in milliseconds (TCP only; see `connect`)
    pub connect_timeout_ms: u64,
    /// Read timeout in milliseconds
    pub read_timeout_ms: u64,
    /// Write timeout in milliseconds
    pub write_timeout_ms: u64,
    /// Number of retry attempts
    pub retries: usize,
    /// Retry backoff multiplier
    pub backoff_multiplier: f64,
    /// Return only top-K results (sets the TOPK_ONLY protocol flag)
    pub topk_only: bool,
    /// Top-K count
    pub topk: u16,
}
impl Default for DaemonConfig {
fn default() -> Self {
Self {
connect_timeout_ms: 5000,
read_timeout_ms: 10000,
write_timeout_ms: 5000,
retries: 3,
backoff_multiplier: 2.0,
topk_only: true,
topk: 16,
}
}
}
impl FpgaDaemonBackend {
/// Create a new daemon backend over a Unix socket at `socket_path`,
/// using the default configuration.
pub fn new(socket_path: impl AsRef<Path>) -> Self {
    let path = socket_path.as_ref().to_string_lossy();
    Self::with_connection(DaemonConnection::unix(path), DaemonConfig::default())
}
/// Build a backend from an explicit connection target and configuration.
pub fn with_connection(connection: DaemonConnection, config: DaemonConfig) -> Self {
    let models = RwLock::new(HashMap::new());
    let stats = RwLock::new(BackendStats::default());
    Self {
        connection,
        models,
        stats,
        config,
    }
}
/// Connect to the daemon.
///
/// Returns a boxed stream so the Unix-socket and TCP transports share one
/// code path. Read/write timeouts are applied best-effort (`.ok()`): failing
/// to set them does not abort the connection.
fn connect(&self) -> Result<Box<dyn ReadWrite>> {
    // NOTE(review): `timeout` is only honored for TCP; std's
    // `UnixStream::connect` has no timeout parameter, so Unix connects
    // may block longer than `connect_timeout_ms`.
    let timeout = Duration::from_millis(self.config.connect_timeout_ms);
    match &self.connection {
        DaemonConnection::Unix(path) => {
            #[cfg(unix)]
            {
                use std::os::unix::net::UnixStream;
                let stream = UnixStream::connect(path)
                    .map_err(|e| Error::daemon_connection(format!("Unix socket: {}", e)))?;
                stream
                    .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                    .ok();
                stream
                    .set_write_timeout(Some(Duration::from_millis(
                        self.config.write_timeout_ms,
                    )))
                    .ok();
                Ok(Box::new(stream))
            }
            #[cfg(not(unix))]
            {
                // Silence unused-variable warnings on non-Unix targets
                let _ = (path, timeout);
                Err(Error::FeatureNotAvailable(
                    "Unix sockets not available on this platform".into(),
                ))
            }
        }
        DaemonConnection::Tcp(addr) => {
            use std::net::TcpStream;
            // Parse the address, then connect with the configured timeout
            let stream = TcpStream::connect_timeout(
                &addr
                    .parse()
                    .map_err(|e| Error::daemon_connection(format!("Invalid address: {}", e)))?,
                timeout,
            )
            .map_err(|e| Error::daemon_connection(format!("TCP: {}", e)))?;
            stream
                .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                .ok();
            stream
                .set_write_timeout(Some(Duration::from_millis(self.config.write_timeout_ms)))
                .ok();
            Ok(Box::new(stream))
        }
    }
}
/// Send inference request to daemon.
///
/// Request wire layout: header frame, tokens (u16 LE), attention mask bytes,
/// packed gate hint, then a CRC32 over everything preceding it. The response
/// is a 14-byte header, i16 logit data, and a trailing (unchecked) checksum.
fn send_request(
    &self,
    stream: &mut dyn ReadWrite,
    req: &InferenceRequest,
) -> Result<(Vec<i16>, ResponseFrame)> {
    let shape = &req.shape;
    // Build request flags
    let mut flags = 0u16;
    if self.config.topk_only {
        flags |= protocol::flags::TOPK_ONLY;
    }
    // Create request frame
    let frame = RequestFrame::new(
        shape.seq_len,
        shape.d_model,
        shape.vocab,
        &req.model,
        flags,
        self.config.topk,
    );
    // Build payload; capacity covers header + tokens + mask + gate hint/crc
    let mut payload = Vec::with_capacity(
        protocol::HEADER_SIZE + req.tokens.len() * 2 + req.attn_mask.len() + 8,
    );
    // Write header
    payload.extend_from_slice(&frame.to_bytes());
    // Write tokens (u16 little-endian)
    for &token in req.tokens {
        payload.extend_from_slice(&token.to_le_bytes());
    }
    // Write mask
    payload.extend_from_slice(req.attn_mask);
    // Write gate hint (packed: i16 score + two flag bytes)
    payload.extend_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
    payload.push(req.gate_hint.boundary_crossed as u8);
    payload.push(req.gate_hint.max_compute_class as u8);
    // Calculate and append checksum over the whole payload so far
    let checksum = crc32(&payload);
    payload.extend_from_slice(&checksum.to_le_bytes());
    // Send payload
    stream
        .write_all(&payload)
        .map_err(|e| Error::backend(format!("Write failed: {}", e)))?;
    stream
        .flush()
        .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
    // Read response header (fixed 14 bytes)
    let mut response_header = [0u8; 14];
    stream
        .read_exact(&mut response_header)
        .map_err(|e| Error::backend(format!("Read header failed: {}", e)))?;
    let response = ResponseFrame::from_bytes(&response_header);
    // Copy packed fields to avoid unaligned references into the packed struct
    let status = { response.status };
    // Check status and map daemon error codes onto crate errors
    match status {
        protocol::status::OK => {}
        protocol::status::MODEL_NOT_FOUND => {
            return Err(Error::ModelNotFound(req.model));
        }
        protocol::status::SHAPE_MISMATCH => {
            return Err(Error::ShapeMismatch {
                expected: req.shape,
                actual: req.shape, // Daemon should provide actual shape
            });
        }
        protocol::status::GATE_BLOCKED => {
            return Err(Error::GateBlocked {
                reason: crate::types::SkipReason::PolicyDenied,
            });
        }
        _ => {
            return Err(Error::backend(format!("Daemon error: status {}", status)));
        }
    }
    // Determine how many i16 values follow the header
    let logits_count = if self.config.topk_only {
        self.config.topk as usize * 2 // (token_id, logit) pairs
    } else {
        shape.vocab as usize
    };
    let mut logits_bytes = vec![0u8; logits_count * 2];
    stream
        .read_exact(&mut logits_bytes)
        .map_err(|e| Error::backend(format!("Read logits failed: {}", e)))?;
    // Parse logits.
    // NOTE(review): in top-K mode both token ids and logits are decoded as
    // i16, so token ids above 32767 would appear negative — confirm the
    // daemon's vocab limit or that ids are reinterpreted downstream.
    let logits: Vec<i16> = logits_bytes
        .chunks(2)
        .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
        .collect();
    // Read trailing checksum.
    // NOTE(review): the checksum is read but never verified against the
    // payload — confirm whether response integrity checking is intended.
    let mut checksum_bytes = [0u8; 4];
    stream.read_exact(&mut checksum_bytes).ok(); // Checksum is optional
    Ok((logits, response))
}
/// Send load model command to daemon.
///
/// Packs the artifact into the wire format and ships it with a LOAD_MODEL
/// command. A non-zero status byte in the reply means the daemon rejected
/// the load; its error message (capped at 256 bytes) is surfaced in the error.
fn send_load_command(
    &self,
    stream: &mut dyn ReadWrite,
    artifact: &ModelArtifact,
) -> Result<()> {
    // Pack artifact
    let artifact_bytes = crate::artifact::pack::pack_artifact(artifact)?;
    // Build command packet:
    // [command: 1] [model_id: 32] [artifact_len: 4] [artifact_data: N] [checksum: 4]
    let mut payload = Vec::with_capacity(1 + 32 + 4 + artifact_bytes.len() + 4);
    // Command byte
    payload.push(commands::LOAD_MODEL);
    // Model ID (32 bytes)
    payload.extend_from_slice(artifact.model_id().as_bytes());
    // Artifact length (u32 LE)
    payload.extend_from_slice(&(artifact_bytes.len() as u32).to_le_bytes());
    // Artifact data
    payload.extend_from_slice(&artifact_bytes);
    // Checksum over everything preceding it
    let checksum = crc32(&payload);
    payload.extend_from_slice(&checksum.to_le_bytes());
    // Send
    stream
        .write_all(&payload)
        .map_err(|e| Error::backend(format!("Write load command failed: {}", e)))?;
    stream
        .flush()
        .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
    // Read response: [status: 1] [message_len: 2] [message: N]
    let mut status = [0u8; 1];
    stream
        .read_exact(&mut status)
        .map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
    if status[0] != 0 {
        // Read the error message best-effort (`.ok()`): a truncated reply
        // still yields a (possibly empty) message rather than a second error
        let mut msg_len = [0u8; 2];
        stream.read_exact(&mut msg_len).ok();
        let len = u16::from_le_bytes(msg_len) as usize;
        let mut msg = vec![0u8; len.min(256)];
        stream.read_exact(&mut msg).ok();
        let error_msg = String::from_utf8_lossy(&msg);
        return Err(Error::backend(format!(
            "Daemon rejected load: {}",
            error_msg
        )));
    }
    Ok(())
}
/// Send unload model command to daemon
fn send_unload_command(&self, stream: &mut dyn ReadWrite, model_id: ModelId) -> Result<()> {
// Build command packet: [command: 1] [model_id: 32] [checksum: 4]
let mut payload = Vec::with_capacity(1 + 32 + 4);
payload.push(commands::UNLOAD_MODEL);
payload.extend_from_slice(model_id.as_bytes());
let checksum = crc32(&payload);
payload.extend_from_slice(&checksum.to_le_bytes());
// Send
stream
.write_all(&payload)
.map_err(|e| Error::backend(format!("Write unload command failed: {}", e)))?;
stream
.flush()
.map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
// Read response status
let mut status = [0u8; 1];
stream
.read_exact(&mut status)
.map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
if status[0] != 0 {
return Err(Error::backend("Daemon rejected unload"));
}
Ok(())
}
/// Execute with retries
fn with_retries<T, F>(&self, mut f: F) -> Result<T>
where
F: FnMut() -> Result<T>,
{
let mut last_error = None;
let mut delay = Duration::from_millis(100);
for attempt in 0..=self.config.retries {
match f() {
Ok(result) => return Ok(result),
Err(e) if e.is_recoverable() => {
last_error = Some(e);
if attempt < self.config.retries {
std::thread::sleep(delay);
delay = Duration::from_secs_f64(
delay.as_secs_f64() * self.config.backoff_multiplier,
);
}
}
Err(e) => return Err(e),
}
}
Err(last_error.unwrap_or_else(|| Error::backend("Unknown error")))
}
}
impl TransformerBackend for FpgaDaemonBackend {
    /// Validate the artifact, preload it on the daemon, then cache its
    /// metadata locally so later requests can be validated without a
    /// round-trip.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact before touching the wire
        artifact.validate()?;
        let model_id = artifact.model_id();
        // Send load command to daemon to preload the model
        self.with_retries(|| {
            let mut stream = self.connect()?;
            self.send_load_command(stream.as_mut(), artifact)
        })?;
        // Cache metadata locally (poison-safe helper; the original bound the
        // unit result to an unused `let mut models`)
        write_lock(&self.models, |m| {
            m.insert(
                model_id,
                ModelMetadata {
                    artifact: artifact.clone(),
                    loaded_at: Instant::now(),
                },
            );
        })?;
        write_lock(&self.stats, |s| {
            s.models_loaded += 1;
        })?;
        Ok(model_id)
    }
    /// Run one inference round-trip against the daemon.
    ///
    /// # Errors
    /// `ModelNotFound` if the model was never loaded through this backend,
    /// plus any transport failure or daemon-reported status error.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        let start = Instant::now();
        // Validate request
        req.validate()?;
        // Check model is loaded locally and grab its artifact for hashing
        let model_metadata = read_lock(&self.models, |models| {
            models.get(&req.model).map(|m| m.artifact.clone())
        })?
        .ok_or_else(|| Error::ModelNotFound(req.model))?;
        // Validate tokens against vocabulary
        validate_tokens(req.tokens, model_metadata.manifest.shape.vocab)?;
        // Execute with retries; each attempt opens a fresh connection
        let (logits, response) = self.with_retries(|| {
            let mut stream = self.connect()?;
            self.send_request(stream.as_mut(), &req)
        })?;
        let latency_ns = start.elapsed().as_nanos() as u32;
        // Parse response
        let gate_decision = response.to_gate_decision();
        // Build top-K if we got pairs
        let (logits_q, topk) = if self.config.topk_only {
            // logits contains interleaved (token_id, logit) pairs
            let pairs: Vec<(u16, i16)> = logits
                .chunks(2)
                .filter_map(|chunk| {
                    if chunk.len() == 2 {
                        Some((chunk[0] as u16, chunk[1]))
                    } else {
                        None
                    }
                })
                .collect();
            (vec![], Some(pairs))
        } else {
            // Full logits - use common compute_topk
            let topk = compute_topk(&logits, 16);
            (logits, Some(topk))
        };
        // Copy packed fields to avoid alignment issues
        let resp_cycles = { response.cycles };
        // NOTE(review): the original witness latency expression
        // `latency_ns.min(resp_latency_ns.max(latency_ns))` always reduces to
        // `latency_ns`, so the daemon-reported latency was never used.
        // Kept behavior-identical here; confirm whether the daemon value
        // should take precedence (the PCIe backend uses `min` of the two).
        let witness = WitnessLog::new(
            model_metadata.model_hash(),
            model_metadata.quant_hash(),
            BackendKind::FpgaDaemon,
            resp_cycles,
            latency_ns,
            gate_decision,
        );
        // Update stats (with poison handling)
        write_lock(&self.stats, |stats| {
            stats.total_inferences += 1;
            stats.total_cycles += resp_cycles as u64;
            let n = stats.total_inferences;
            // Incremental running mean over all inferences
            stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
            match gate_decision {
                GateDecision::EarlyExit { .. } => stats.early_exits += 1,
                GateDecision::Skipped { .. } => stats.skipped += 1,
                _ => {}
            }
        })?;
        Ok(InferenceResult::new(logits_q, topk, witness))
    }
    /// Unload on the daemon first, then drop the local cache entry.
    fn unload(&self, model: ModelId) -> Result<()> {
        // Send unload command to daemon
        self.with_retries(|| {
            let mut stream = self.connect()?;
            self.send_unload_command(stream.as_mut(), model)
        })?;
        // Remove from local cache
        let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
        if removed {
            write_lock(&self.stats, |s| {
                s.models_loaded = s.models_loaded.saturating_sub(1);
            })?;
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }
    /// True when the model id is in the local cache; a poisoned lock is
    /// treated as "not loaded".
    fn is_loaded(&self, model: ModelId) -> bool {
        read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
    }
    fn kind(&self) -> BackendKind {
        BackendKind::FpgaDaemon
    }
    /// Snapshot of backend statistics; a poisoned lock yields defaults
    /// instead of panicking (consistent with the PCIe backend's `stats`).
    fn stats(&self) -> BackendStats {
        read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
    }
}
/// Trait combining Read and Write for stream abstraction
///
/// Lets Unix-socket and TCP streams be handled uniformly behind
/// `&mut dyn ReadWrite` without making every method generic.
trait ReadWrite: Read + Write + Send {}
/// Blanket impl: any readable, writable, `Send` type qualifies.
impl<T: Read + Write + Send> ReadWrite for T {}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke-check that the connection constructors produce the matching variant.
    #[test]
    fn test_daemon_connection_types() {
        let unix = DaemonConnection::unix("/tmp/test.sock");
        assert!(matches!(unix, DaemonConnection::Unix(_)));
        let tcp = DaemonConnection::tcp("127.0.0.1:8080");
        assert!(matches!(tcp, DaemonConnection::Tcp(_)));
    }
    // Pin the documented default values so accidental changes are caught.
    #[test]
    fn test_config_defaults() {
        let config = DaemonConfig::default();
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.retries, 3);
        assert!(config.topk_only);
    }
}

View File

@@ -0,0 +1,645 @@
//! FPGA PCIe backend
//!
//! Direct memory-mapped access to FPGA accelerator via PCIe.
//! Uses DMA ring buffers for zero-copy, lock-free operation.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;
#[cfg(feature = "pcie")]
use memmap2::{MmapMut, MmapOptions};
use crate::artifact::ModelArtifact;
use crate::backend::{
compute_topk, protocol, read_lock, validate_tokens, write_lock, BackendStats,
TransformerBackend,
};
use crate::error::{Error, Result};
use crate::types::{
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
};
/// PCIe device configuration
///
/// Controls device discovery (`device_path`), BAR layout, the DMA ring
/// geometry (`ring_slots` slots of `slot_size` bytes each), and batching.
#[derive(Debug, Clone)]
pub struct PcieConfig {
    /// Device path (e.g., /dev/ruvector0)
    pub device_path: String,
    /// BAR0 offset for control registers
    pub bar0_offset: usize,
    /// BAR1 offset for DMA buffers (request region; responses follow it)
    pub bar1_offset: usize,
    /// Number of request slots in ring buffer
    pub ring_slots: usize,
    /// Size of each request slot in bytes
    pub slot_size: usize,
    /// DMA timeout in milliseconds
    pub dma_timeout_ms: u64,
    /// Enable batch mode (multiple requests per DMA burst)
    pub batch_mode: bool,
    /// Maximum requests per batch (only used when `batch_mode` is set)
    pub batch_size: usize,
}
impl Default for PcieConfig {
    /// Defaults sized for a single card: 16 slots of 64KB (1MB of DMA
    /// ring), 100ms DMA timeout, batching disabled.
    fn default() -> Self {
        Self {
            device_path: "/dev/ruvector0".into(),
            bar0_offset: 0,
            bar1_offset: 0x10000, // DMA region starts 64KB into BAR1
            ring_slots: 16,
            slot_size: 64 * 1024, // 64KB per slot
            dma_timeout_ms: 100,
            batch_mode: false,
            batch_size: 4,
        }
    }
}
/// Ring buffer slot state
///
/// Lifecycle: `Free` -> `Pending` (claimed by the host) -> `Complete`
/// (device wrote a response) -> `Free` (released after the response is
/// consumed). Stored in `AtomicU32`s as these discriminant values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum SlotState {
    /// Slot is available for a new request.
    Free = 0,
    /// Request written; waiting on the device.
    Pending = 1,
    /// Response is ready to be read.
    Complete = 2,
    /// Reserved for device-reported failures on this slot.
    Error = 3,
}
/// DMA ring buffer for lock-free request/response handling
///
/// Requests and responses live in two equal-sized mmapped regions
/// (`num_slots * slot_size` bytes each); per-slot `AtomicU32` states
/// coordinate host and device without locks.
struct DmaRingBuffer {
    /// Memory-mapped request buffer
    #[cfg(feature = "pcie")]
    request_mmap: MmapMut,
    /// Memory-mapped response buffer
    #[cfg(feature = "pcie")]
    response_mmap: MmapMut,
    /// Slot states (values are `SlotState as u32`)
    slot_states: Vec<AtomicU32>,
    /// Producer index (next slot to write; wraps modulo `num_slots`)
    producer_idx: AtomicU32,
    /// Consumer index (next slot to read)
    consumer_idx: AtomicU32,
    /// Number of slots
    num_slots: usize,
    /// Size per slot
    slot_size: usize,
}
impl DmaRingBuffer {
    /// Create a new DMA ring buffer (mock for non-PCIe builds)
    ///
    /// Always fails: without the `pcie` feature there is no device to map.
    #[cfg(not(feature = "pcie"))]
    fn new(_config: &PcieConfig) -> Result<Self> {
        Err(Error::FeatureNotAvailable(
            "PCIe support not compiled".into(),
        ))
    }
    /// Create a new DMA ring buffer
    ///
    /// Opens the character device and maps two contiguous BAR1 regions:
    /// requests at `bar1_offset`, responses immediately after them
    /// (`bar1_offset + ring_slots * slot_size`). All slots start `Free`.
    #[cfg(feature = "pcie")]
    fn new(config: &PcieConfig) -> Result<Self> {
        use std::fs::OpenOptions;
        // Open device
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&config.device_path)
            .map_err(|e| Error::PcieError(format!("Failed to open device: {}", e)))?;
        let total_size = config.ring_slots * config.slot_size;
        // Map request buffer (BAR1). The mmap calls are `unsafe` because the
        // underlying device memory may change outside Rust's control.
        let request_mmap = unsafe {
            MmapOptions::new()
                .offset(config.bar1_offset as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map request buffer: {}", e)))?
        };
        // Map response buffer (BAR1 + offset)
        let response_mmap = unsafe {
            MmapOptions::new()
                .offset((config.bar1_offset + total_size) as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map response buffer: {}", e)))?
        };
        // Initialize slot states
        let slot_states: Vec<AtomicU32> = (0..config.ring_slots)
            .map(|_| AtomicU32::new(SlotState::Free as u32))
            .collect();
        Ok(Self {
            request_mmap,
            response_mmap,
            slot_states,
            producer_idx: AtomicU32::new(0),
            consumer_idx: AtomicU32::new(0),
            num_slots: config.ring_slots,
            slot_size: config.slot_size,
        })
    }
    /// Acquire a slot for writing
    ///
    /// Returns `None` when the slot under the producer cursor is still busy
    /// (ring full) or another thread won the claim race; callers retry.
    fn acquire_slot(&self) -> Option<usize> {
        let producer = self.producer_idx.load(Ordering::Acquire);
        let slot = producer as usize % self.num_slots;
        // Check if slot is free
        if self.slot_states[slot].load(Ordering::Acquire) == SlotState::Free as u32 {
            // Try to claim it; the CAS is the linearization point, so two
            // threads can never both own the same slot.
            if self.slot_states[slot]
                .compare_exchange(
                    SlotState::Free as u32,
                    SlotState::Pending as u32,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                // Advance the cursor only after winning the claim.
                self.producer_idx
                    .store(producer.wrapping_add(1), Ordering::Release);
                return Some(slot);
            }
        }
        None
    }
    /// Release a slot after reading response (returns it to `Free`)
    fn release_slot(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Free as u32, Ordering::Release);
        self.consumer_idx.fetch_add(1, Ordering::AcqRel);
    }
    /// Check if a slot is complete
    fn is_complete(&self, slot: usize) -> bool {
        self.slot_states[slot].load(Ordering::Acquire) == SlotState::Complete as u32
    }
    /// Mark a slot as complete (called by FPGA via doorbell/interrupt)
    fn mark_complete(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Complete as u32, Ordering::Release);
    }
    /// Get request buffer for a slot (mutable slice into the request mmap)
    #[cfg(feature = "pcie")]
    fn request_buffer(&mut self, slot: usize) -> &mut [u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &mut self.request_mmap[start..end]
    }
    /// Get response buffer for a slot (read-only slice into the response mmap)
    #[cfg(feature = "pcie")]
    fn response_buffer(&self, slot: usize) -> &[u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &self.response_mmap[start..end]
    }
}
/// FPGA PCIe backend
///
/// Talks to the accelerator through memory-mapped DMA ring buffers and
/// manages model weights in card DDR with a simple bump allocator.
pub struct FpgaPcieBackend {
    /// Configuration
    config: PcieConfig,
    /// DMA ring buffer (`None` when built without the `pcie` feature)
    ring: Option<DmaRingBuffer>,
    /// Loaded models
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Total cycles counter (mirrored into `stats.total_cycles`)
    total_cycles: AtomicU64,
    /// FPGA memory allocator state (next free offset; bump allocator)
    fpga_mem_offset: AtomicU64,
    /// FPGA memory total size (2GB default)
    fpga_mem_size: u64,
}
/// Cached model metadata
///
/// Host-side bookkeeping for a model resident in FPGA DDR.
struct ModelMetadata {
    artifact: ModelArtifact, // Full artifact kept for validation/hashing
    fpga_slot: u32, // Slot in FPGA memory where model is loaded
    weights_offset: u64, // Offset in FPGA DDR where weights are stored
    weights_size: usize, // Size of weights in bytes
}
/// FPGA DDR base offset for model weights
///
/// The region below this offset is left untouched by the bump allocator.
const FPGA_DDR_MODEL_BASE: u64 = 0x1000_0000; // 256MB offset
impl FpgaPcieBackend {
    /// Create a new PCIe backend
    ///
    /// With the `pcie` feature this maps the device's DMA ring; without it
    /// `ring` stays `None` and every operation fails with
    /// `FeatureNotAvailable`.
    pub fn new(config: PcieConfig) -> Result<Self> {
        #[cfg(feature = "pcie")]
        let ring = Some(DmaRingBuffer::new(&config)?);
        #[cfg(not(feature = "pcie"))]
        let ring = None;
        Ok(Self {
            config,
            ring,
            models: RwLock::new(HashMap::new()),
            stats: RwLock::new(BackendStats::default()),
            total_cycles: AtomicU64::new(0),
            fpga_mem_offset: AtomicU64::new(FPGA_DDR_MODEL_BASE),
            fpga_mem_size: 2 * 1024 * 1024 * 1024, // 2GB
        })
    }
    /// Create with default configuration (`/dev/ruvector0`)
    pub fn default_device() -> Result<Self> {
        Self::new(PcieConfig::default())
    }
    /// Write inference request to DMA buffer
    ///
    /// Slot layout: 24-byte `RequestFrame` header, then tokens (u16 LE
    /// each), the raw attention-mask bytes, and the 4-byte gate hint.
    #[cfg(feature = "pcie")]
    fn write_request(
        &self,
        ring: &mut DmaRingBuffer,
        slot: usize,
        req: &InferenceRequest,
    ) -> Result<()> {
        use crate::backend::{protocol, RequestFrame};
        let buffer = ring.request_buffer(slot);
        let shape = &req.shape;
        // Write header
        let frame = RequestFrame::new(shape.seq_len, shape.d_model, shape.vocab, &req.model, 0, 16);
        let header = frame.to_bytes();
        buffer[..protocol::HEADER_SIZE].copy_from_slice(&header);
        let mut offset = protocol::HEADER_SIZE;
        // Write tokens
        for &token in req.tokens {
            buffer[offset..offset + 2].copy_from_slice(&token.to_le_bytes());
            offset += 2;
        }
        // Write mask
        buffer[offset..offset + req.attn_mask.len()].copy_from_slice(req.attn_mask);
        offset += req.attn_mask.len();
        // Write gate hint
        buffer[offset..offset + 2].copy_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
        offset += 2;
        buffer[offset] = req.gate_hint.boundary_crossed as u8;
        offset += 1;
        buffer[offset] = req.gate_hint.max_compute_class as u8;
        Ok(())
    }
    /// Read inference response from DMA buffer
    ///
    /// Expects a 14-byte `ResponseFrame` followed by `shape.vocab` i16 LE
    /// logits. Returns `(logits, cycles, fpga_latency_ns, gate_decision)`.
    #[cfg(feature = "pcie")]
    fn read_response(
        &self,
        ring: &DmaRingBuffer,
        slot: usize,
        shape: &crate::types::FixedShape,
    ) -> Result<(Vec<i16>, u32, u32, GateDecision)> {
        use crate::backend::ResponseFrame;
        let buffer = ring.response_buffer(slot);
        // Read response header
        let response = ResponseFrame::from_bytes(&buffer[..14].try_into().unwrap());
        // Check status
        if response.status != 0 {
            return Err(Error::backend(format!(
                "FPGA error: status {}",
                response.status
            )));
        }
        // Read logits
        let vocab = shape.vocab as usize;
        let mut logits = Vec::with_capacity(vocab);
        let mut offset = 14;
        for _ in 0..vocab {
            let value = i16::from_le_bytes([buffer[offset], buffer[offset + 1]]);
            logits.push(value);
            offset += 2;
        }
        Ok((
            logits,
            response.cycles,
            response.latency_ns,
            response.to_gate_decision(),
        ))
    }
    /// Ring doorbell to notify FPGA of pending request
    #[cfg(feature = "pcie")]
    fn ring_doorbell(&self, _slot: usize) {
        // In a real implementation, this would write to a control register
        // to notify the FPGA that a new request is available
    }
    /// Wait for response with polling
    ///
    /// Busy-spins (with `spin_loop` hints) until the slot is `Complete`
    /// or `timeout_ms` elapses.
    fn wait_for_response(&self, ring: &DmaRingBuffer, slot: usize, timeout_ms: u64) -> Result<()> {
        let start = Instant::now();
        let timeout = std::time::Duration::from_millis(timeout_ms);
        while !ring.is_complete(slot) {
            if start.elapsed() > timeout {
                return Err(Error::Timeout { ms: timeout_ms });
            }
            std::hint::spin_loop();
        }
        Ok(())
    }
    /// Allocate FPGA DDR memory for model weights
    ///
    /// Bump allocator: 4KB-aligned, monotonically increasing offsets.
    /// NOTE(review): the rollback below is racy — if another thread
    /// allocates between the `fetch_add` and the `fetch_sub`, the rollback
    /// can reclaim a region the other thread now owns. Confirm whether
    /// concurrent loads are expected here.
    fn allocate_fpga_memory(&self, size: usize) -> Result<u64> {
        // Align to 4KB boundary for DMA efficiency
        let aligned_size = (size + 0xFFF) & !0xFFF;
        // Atomic allocation (simple bump allocator)
        let offset = self
            .fpga_mem_offset
            .fetch_add(aligned_size as u64, Ordering::SeqCst);
        // Check for overflow
        if offset + aligned_size as u64 > self.fpga_mem_size {
            // Roll back allocation
            self.fpga_mem_offset
                .fetch_sub(aligned_size as u64, Ordering::SeqCst);
            return Err(Error::ResourceExhausted("FPGA DDR memory full".into()));
        }
        Ok(offset)
    }
    /// Upload model weights to FPGA DDR via DMA
    ///
    /// Currently simulated: each chunk's slot is marked complete by the
    /// host itself rather than by hardware (see inline comments).
    #[cfg(feature = "pcie")]
    fn upload_weights_dma(&self, weights: &[u8], fpga_offset: u64) -> Result<()> {
        // DMA transfer configuration
        const DMA_CHUNK_SIZE: usize = 64 * 1024; // 64KB per transfer
        let ring = self
            .ring
            .as_ref()
            .ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
        // Transfer weights in chunks
        let mut transferred = 0usize;
        while transferred < weights.len() {
            let chunk_size = DMA_CHUNK_SIZE.min(weights.len() - transferred);
            // Acquire a DMA slot (spin until one frees up)
            let slot = loop {
                if let Some(s) = ring.acquire_slot() {
                    break s;
                }
                std::hint::spin_loop();
            };
            // Write DMA command to slot (simplified protocol)
            // In real hardware:
            // - Write target FPGA DDR address
            // - Write source offset in slot
            // - Write transfer length
            // - Ring doorbell
            // For now, we simulate the DMA by marking complete
            ring.mark_complete(slot);
            // Wait for completion
            self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;
            // Release slot
            ring.release_slot(slot);
            transferred += chunk_size;
        }
        Ok(())
    }
    /// Free FPGA DDR memory (simplified - real impl would use proper allocator)
    fn free_fpga_memory(&self, _offset: u64, _size: usize) {
        // In a production system, this would:
        // 1. Mark the memory region as free in an allocator
        // 2. Potentially compact memory if fragmentation is high
        // 3. Update hardware memory management unit
        //
        // For this implementation, we use a bump allocator without free.
        // Memory is reclaimed when all models are unloaded.
    }
    /// Check if all models are unloaded and reset memory allocator
    ///
    /// Resets the bump pointer to `FPGA_DDR_MODEL_BASE` once the model
    /// table is empty — the only point at which the no-free allocator can
    /// reclaim space.
    fn maybe_reset_allocator(&self) {
        let models_empty = read_lock(&self.models, |m| m.is_empty()).unwrap_or(false);
        if models_empty {
            self.fpga_mem_offset
                .store(FPGA_DDR_MODEL_BASE, Ordering::SeqCst);
        }
    }
}
impl TransformerBackend for FpgaPcieBackend {
    /// Validate the artifact, allocate FPGA DDR, DMA the weights across,
    /// and record host-side metadata. Fails with `FeatureNotAvailable`
    /// when built without the `pcie` feature.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = artifact;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }
        #[cfg(feature = "pcie")]
        {
            // Validate artifact
            artifact.validate()?;
            let model_id = artifact.model_id();
            let weights_size = artifact.weights.len();
            // Allocate FPGA DDR memory for weights
            let weights_offset = self.allocate_fpga_memory(weights_size)?;
            // Upload model weights to FPGA DDR via DMA
            if let Err(e) = self.upload_weights_dma(&artifact.weights, weights_offset) {
                // Roll back allocation on failure
                self.free_fpga_memory(weights_offset, weights_size);
                return Err(e);
            }
            // Get slot number for this model (index = current model count)
            let fpga_slot = read_lock(&self.models, |m| m.len() as u32)?;
            // Store metadata
            write_lock(&self.models, |models| {
                models.insert(
                    model_id,
                    ModelMetadata {
                        artifact: artifact.clone(),
                        fpga_slot,
                        weights_offset,
                        weights_size,
                    },
                );
            })?;
            write_lock(&self.stats, |stats| {
                stats.models_loaded += 1;
            })?;
            Ok(model_id)
        }
    }
    /// Run one inference through the DMA ring.
    ///
    /// Fix over the original: the DMA slot is now released on *every* exit
    /// path; previously a timeout or a response-read error returned early
    /// and left the slot stuck in `Pending`, permanently shrinking the ring.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = req;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }
        #[cfg(feature = "pcie")]
        {
            let start = Instant::now();
            // Validate request
            req.validate()?;
            // Get model metadata
            let model_artifact = read_lock(&self.models, |models| {
                models.get(&req.model).map(|m| m.artifact.clone())
            })?
            .ok_or_else(|| Error::ModelNotFound(req.model))?;
            // Validate tokens against vocabulary
            validate_tokens(req.tokens, model_artifact.manifest.shape.vocab)?;
            // Get ring buffer
            let ring = self
                .ring
                .as_ref()
                .ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
            // Acquire slot
            let slot = ring
                .acquire_slot()
                .ok_or_else(|| Error::ResourceExhausted("No DMA slots available".into()))?;
            // Write request (need mutable access - simplified for now)
            // In production, this would use proper interior mutability
            // self.write_request(ring, slot, &req)?;
            // Ring doorbell
            // self.ring_doorbell(slot);
            // Wait for the response and read it; release the slot afterwards
            // regardless of outcome so errors cannot leak it.
            let outcome = self
                .wait_for_response(ring, slot, self.config.dma_timeout_ms)
                .and_then(|_| self.read_response(ring, slot, &req.shape));
            ring.release_slot(slot);
            let (logits, cycles, fpga_latency_ns, gate_decision) = outcome?;
            let latency_ns = start.elapsed().as_nanos() as u32;
            // Compute top-K using common utility
            let topk = compute_topk(&logits, 16);
            // Create witness: report the smaller of the FPGA-measured and
            // host-measured latencies.
            let witness = WitnessLog::new(
                model_artifact.model_hash(),
                model_artifact.quant_hash(),
                BackendKind::FpgaPcie,
                cycles,
                fpga_latency_ns.min(latency_ns),
                gate_decision,
            );
            // Update stats
            self.total_cycles
                .fetch_add(cycles as u64, Ordering::Relaxed);
            write_lock(&self.stats, |stats| {
                stats.total_inferences += 1;
                stats.total_cycles = self.total_cycles.load(Ordering::Relaxed);
                let n = stats.total_inferences;
                // Incremental running mean over all inferences
                stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
                match gate_decision {
                    GateDecision::EarlyExit { .. } => stats.early_exits += 1,
                    GateDecision::Skipped { .. } => stats.skipped += 1,
                    _ => {}
                }
            })?;
            Ok(InferenceResult::new(logits, Some(topk), witness))
        }
    }
    /// Drop a model: free its DDR region (no-op for the bump allocator),
    /// reset the allocator if the table is now empty, and update stats.
    fn unload(&self, model: ModelId) -> Result<()> {
        // Remove from cache and get memory info for deallocation
        let removed = write_lock(&self.models, |models| {
            models
                .remove(&model)
                .map(|m| (m.weights_offset, m.weights_size))
        })?;
        if let Some((offset, size)) = removed {
            // Free FPGA DDR memory
            self.free_fpga_memory(offset, size);
            // Check if we can reset the allocator
            self.maybe_reset_allocator();
            write_lock(&self.stats, |stats| {
                stats.models_loaded = stats.models_loaded.saturating_sub(1);
            })?;
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }
    /// True when the model id is cached; a poisoned lock reads as "not loaded".
    fn is_loaded(&self, model: ModelId) -> bool {
        read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
    }
    fn kind(&self) -> BackendKind {
        BackendKind::FpgaPcie
    }
    /// Snapshot of stats; a poisoned lock yields defaults.
    fn stats(&self) -> BackendStats {
        read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Pin the documented ring geometry defaults.
    #[test]
    fn test_pcie_config_default() {
        let config = PcieConfig::default();
        assert_eq!(config.ring_slots, 16);
        assert_eq!(config.slot_size, 64 * 1024);
    }
    // The discriminants are part of the host/device contract; pin them.
    #[test]
    fn test_slot_state_values() {
        assert_eq!(SlotState::Free as u8, 0);
        assert_eq!(SlotState::Pending as u8, 1);
        assert_eq!(SlotState::Complete as u8, 2);
    }
}

View File

@@ -0,0 +1,428 @@
//! Backend implementations for FPGA Transformer
//!
//! All backends implement the `TransformerBackend` trait for uniform API.
use crate::artifact::ModelArtifact;
use crate::error::Result;
use crate::types::{InferenceRequest, InferenceResult, ModelId};
#[cfg(feature = "native_sim")]
pub mod native_sim;
#[cfg(feature = "daemon")]
pub mod fpga_daemon;
#[cfg(feature = "pcie")]
pub mod fpga_pcie;
#[cfg(feature = "wasm")]
pub mod wasm_sim;
/// Trait for transformer inference backends
///
/// All backends must be thread-safe (`Send + Sync`) and implement the
/// same API, so simulator, daemon, and PCIe implementations are
/// interchangeable behind a trait object.
pub trait TransformerBackend: Send + Sync {
    /// Load a model artifact and return its ID
    ///
    /// The artifact is validated, test vectors are run, and
    /// the model is prepared for inference.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId>;
    /// Run inference on the given request
    ///
    /// The request must specify a model that has been loaded.
    /// Returns the inference result with witness log.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult>;
    /// Unload a model to free resources
    fn unload(&self, model: ModelId) -> Result<()>;
    /// Check if a model is loaded
    fn is_loaded(&self, model: ModelId) -> bool;
    /// Get the backend kind
    fn kind(&self) -> crate::types::BackendKind;
    /// Get backend-specific statistics
    ///
    /// Default implementation returns zeroed stats for backends that do
    /// not track them.
    fn stats(&self) -> BackendStats {
        BackendStats::default()
    }
}
/// Backend statistics
///
/// Snapshot returned by [`TransformerBackend::stats`]; all counters are
/// cumulative since backend creation.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Number of models currently loaded
    pub models_loaded: usize,
    /// Total inferences performed
    pub total_inferences: u64,
    /// Total cycles consumed (FPGA only)
    pub total_cycles: u64,
    /// Average latency in nanoseconds (running mean)
    pub avg_latency_ns: u64,
    /// P99 latency in nanoseconds
    pub p99_latency_ns: u64,
    /// Number of early exits
    pub early_exits: u64,
    /// Number of skipped inferences
    pub skipped: u64,
}
/// Protocol constants for daemon/PCIe communication
///
/// `HEADER_SIZE` must match the serialized size of `RequestFrame`
/// (see `RequestFrame::to_bytes`).
pub mod protocol {
    /// Magic number for frame validation
    pub const MAGIC: u32 = 0x5256_5846; // "RVXF" - RuVector FPGA
    /// Current protocol version
    pub const VERSION: u16 = 1;
    /// Frame header size in bytes (serialized `RequestFrame`)
    pub const HEADER_SIZE: usize = 24;
    /// Maximum payload size
    pub const MAX_PAYLOAD: usize = 1024 * 1024; // 1MB
    /// Request flags (bitmask; combine with `|`)
    pub mod flags {
        /// Return only top-K predictions
        pub const TOPK_ONLY: u16 = 0x0001;
        /// Use LUT-based softmax
        pub const LUT_SOFTMAX: u16 = 0x0002;
        /// Enable early exit
        pub const EARLY_EXIT: u16 = 0x0004;
        /// Return detailed witness
        pub const WITNESS_DETAIL: u16 = 0x0008;
    }
    /// Response status codes
    pub mod status {
        /// Success
        pub const OK: u16 = 0;
        /// Model not found
        pub const MODEL_NOT_FOUND: u16 = 1;
        /// Shape mismatch
        pub const SHAPE_MISMATCH: u16 = 2;
        /// Gate blocked
        pub const GATE_BLOCKED: u16 = 3;
        /// Internal error
        pub const INTERNAL_ERROR: u16 = 0xFFFF;
    }
}
/// Request frame for wire protocol
///
/// Serialized little-endian as exactly `protocol::HEADER_SIZE` (24) bytes.
/// Only the first 8 bytes of the 32-byte model id travel on the wire,
/// split into `model_id_low`/`model_id_high`. `repr(C, packed)` fields
/// must be copied out before use to avoid unaligned references.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct RequestFrame {
    /// Magic number (MAGIC)
    pub magic: u32,
    /// Protocol version
    pub protocol: u16,
    /// Sequence length
    pub seq_len: u16,
    /// Model dimension
    pub d_model: u16,
    /// Vocabulary size (low 16 bits only; see `RequestFrame::new`)
    pub vocab: u16,
    /// Model ID (lower 32 bits)
    pub model_id_low: u32,
    /// Model ID (upper 32 bits)
    pub model_id_high: u32,
    /// Request flags
    pub flags: u16,
    /// Top-K count (if TOPK_ONLY flag set)
    pub topk: u16,
}
impl RequestFrame {
    /// Create a new request frame
    ///
    /// Note: only the low 16 bits of `vocab` are carried on the wire
    /// (masked below — assumes vocab < 65536; TODO confirm), and only the
    /// first 8 of the 32 model-id bytes are sent.
    pub fn new(
        seq_len: u16,
        d_model: u16,
        vocab: u32,
        model_id: &ModelId,
        flags: u16,
        topk: u16,
    ) -> Self {
        let id_bytes = model_id.as_bytes();
        // Pack the leading 8 id bytes into two little-endian u32s.
        let model_id_low = u32::from_le_bytes([id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3]]);
        let model_id_high =
            u32::from_le_bytes([id_bytes[4], id_bytes[5], id_bytes[6], id_bytes[7]]);
        Self {
            magic: protocol::MAGIC,
            protocol: protocol::VERSION,
            seq_len,
            d_model,
            vocab: (vocab & 0xFFFF) as u16,
            model_id_low,
            model_id_high,
            flags,
            topk,
        }
    }
    /// Serialize to bytes
    ///
    /// Field-by-field little-endian layout; total size is
    /// `protocol::HEADER_SIZE` (24 bytes: 4+2+2+2+2+4+4+2+2).
    pub fn to_bytes(&self) -> [u8; protocol::HEADER_SIZE] {
        let mut bytes = [0u8; protocol::HEADER_SIZE];
        bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
        bytes[4..6].copy_from_slice(&self.protocol.to_le_bytes());
        bytes[6..8].copy_from_slice(&self.seq_len.to_le_bytes());
        bytes[8..10].copy_from_slice(&self.d_model.to_le_bytes());
        bytes[10..12].copy_from_slice(&self.vocab.to_le_bytes());
        bytes[12..16].copy_from_slice(&self.model_id_low.to_le_bytes());
        bytes[16..20].copy_from_slice(&self.model_id_high.to_le_bytes());
        bytes[20..22].copy_from_slice(&self.flags.to_le_bytes());
        bytes[22..24].copy_from_slice(&self.topk.to_le_bytes());
        bytes
    }
}
/// Response frame from wire protocol
///
/// 14-byte little-endian wire representation (2+4+4+1+1+1+1); see
/// `ResponseFrame::from_bytes`. `repr(C, packed)` fields must be copied
/// out before use to avoid unaligned references.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ResponseFrame {
    /// Status code (see `protocol::status`)
    pub status: u16,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Compute cycles
    pub cycles: u32,
    /// Gate decision (packed; decoded by `to_gate_decision`)
    pub gate_decision: u8,
    /// Exit layer (if early exit)
    pub exit_layer: u8,
    /// Skip reason (if skipped)
    pub skip_reason: u8,
    /// Reserved
    pub reserved: u8,
}
impl ResponseFrame {
    /// Parse from bytes
    ///
    /// Decodes the 14-byte little-endian wire layout field by field.
    pub fn from_bytes(bytes: &[u8; 14]) -> Self {
        Self {
            status: u16::from_le_bytes([bytes[0], bytes[1]]),
            latency_ns: u32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            cycles: u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
            gate_decision: bytes[10],
            exit_layer: bytes[11],
            skip_reason: bytes[12],
            reserved: bytes[13],
        }
    }
    /// Convert gate decision to enum
    ///
    /// Unknown `gate_decision` values fall back to `RanFull`; unknown
    /// `skip_reason` values fall back to `BudgetExceeded`.
    pub fn to_gate_decision(&self) -> crate::types::GateDecision {
        match self.gate_decision {
            0 => crate::types::GateDecision::RanFull,
            1 => crate::types::GateDecision::EarlyExit {
                layer: self.exit_layer,
            },
            2 => crate::types::GateDecision::Skipped {
                reason: match self.skip_reason {
                    0 => crate::types::SkipReason::LowCoherence,
                    1 => crate::types::SkipReason::PolicyDenied,
                    _ => crate::types::SkipReason::BudgetExceeded,
                },
            },
            _ => crate::types::GateDecision::RanFull,
        }
    }
}
/// CRC-32 (reflected polynomial 0xEDB88320, init/xorout 0xFFFFFFFF) of `data`.
///
/// Bitwise, table-free implementation of the standard zlib/PNG CRC-32;
/// the check value for `b"123456789"` is `0xCBF4_3926`. A lookup-table
/// crate (e.g. crc32fast) could replace this in production.
pub fn crc32(data: &[u8]) -> u32 {
    !data.iter().fold(0xFFFF_FFFFu32, |acc, &byte| {
        (0..8).fold(acc ^ u32::from(byte), |crc, _| {
            // All-ones mask iff the LSB is set: branch-free conditional xor.
            let mask = (crc & 1).wrapping_neg();
            (crc >> 1) ^ (0xEDB8_8320 & mask)
        })
    })
}
// ============================================================================
// Common utilities shared across backends
// ============================================================================
/// Compute top-K predictions from logits
///
/// Returns (token_id, logit) pairs sorted by descending logit, with ties
/// broken by ascending token id so the output is fully deterministic.
///
/// Fixes over the original: `k == 0` no longer panics on inputs longer
/// than 100 elements (the hand-rolled heap indexed `heap[0]` while empty),
/// and tie ordering is now identical in both the heap and full-sort paths.
/// Token ids are truncated to `u16` — assumes vocab <= 65536 entries.
#[inline]
pub fn compute_topk(logits: &[i16], k: usize) -> Vec<(u16, i16)> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;
    if logits.is_empty() || k == 0 {
        return vec![];
    }
    if k <= 32 && logits.len() > 100 {
        // Partial selection: a min-heap of the k best candidates.
        // Keying on (value, Reverse(index)) makes the weakest element the
        // lowest value — and, among equal values, the highest index — so
        // earlier token ids survive ties.
        let mut heap: BinaryHeap<Reverse<(i16, Reverse<u16>)>> = BinaryHeap::with_capacity(k + 1);
        for (i, &v) in logits.iter().enumerate() {
            heap.push(Reverse((v, Reverse(i as u16))));
            if heap.len() > k {
                heap.pop(); // evict the current weakest candidate
            }
        }
        let mut out: Vec<(u16, i16)> = heap
            .into_iter()
            .map(|Reverse((v, Reverse(i)))| (i, v))
            .collect();
        out.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
        out
    } else {
        // Full sort for small arrays (or large k): index everything,
        // order by value desc then id asc, keep the first k.
        let mut indexed: Vec<(u16, i16)> = logits
            .iter()
            .copied()
            .enumerate()
            .map(|(i, v)| (i as u16, v))
            .collect();
        indexed.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
        indexed.truncate(k);
        indexed
    }
}
/// Helper to safely read from RwLock, returning error on poison
pub fn read_lock<T, R>(
lock: &std::sync::RwLock<T>,
f: impl FnOnce(&T) -> R,
) -> crate::error::Result<R> {
lock.read()
.map(|guard| f(&*guard))
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (read)".into()))
}
/// Helper to safely write to RwLock, returning error on poison
pub fn write_lock<T, R>(
lock: &std::sync::RwLock<T>,
f: impl FnOnce(&mut T) -> R,
) -> crate::error::Result<R> {
lock.write()
.map(|mut guard| f(&mut *guard))
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (write)".into()))
}
/// Validate token indices against vocabulary size
#[inline]
pub fn validate_tokens(tokens: &[u16], vocab_size: u32) -> crate::error::Result<()> {
for (i, &token) in tokens.iter().enumerate() {
if token as u32 >= vocab_size {
return Err(crate::error::Error::InvalidConfig(format!(
"Token {} at index {} exceeds vocabulary size {}",
token, i, vocab_size
)));
}
}
Ok(())
}
/// Build witness log from inference metadata
///
/// Thin convenience wrapper over [`crate::types::WitnessLog::new`] so
/// backends outside `crate::types` can construct witnesses uniformly.
/// All arguments are forwarded unchanged.
pub fn build_witness(
    model_hash: [u8; 32],
    quant_hash: [u8; 32],
    backend: crate::types::BackendKind,
    cycles: u32,
    latency_ns: u32,
    gate_decision: crate::types::GateDecision,
) -> crate::types::WitnessLog {
    crate::types::WitnessLog::new(
        model_hash,
        quant_hash,
        backend,
        cycles,
        latency_ns,
        gate_decision,
    )
}
/// Command types for daemon protocol
///
/// A single command byte prefixes every packet sent to the daemon.
pub mod commands {
    /// Load model command
    pub const LOAD_MODEL: u8 = 0x01;
    /// Unload model command
    pub const UNLOAD_MODEL: u8 = 0x02;
    /// Inference request command
    pub const INFER: u8 = 0x03;
    /// Ping/health check command
    pub const PING: u8 = 0x04;
    /// Get status command
    pub const STATUS: u8 = 0x05;
}
#[cfg(test)]
mod tests {
    use super::*;
    // Serialized header must be HEADER_SIZE bytes and start with the magic.
    #[test]
    fn test_request_frame_roundtrip() {
        let model_id = ModelId::new([0x42u8; 32]);
        let frame = RequestFrame::new(64, 256, 32000, &model_id, 0, 16);
        let bytes = frame.to_bytes();
        assert_eq!(bytes.len(), protocol::HEADER_SIZE);
        assert_eq!(&bytes[0..4], &protocol::MAGIC.to_le_bytes());
    }
    #[test]
    fn test_crc32() {
        let data = b"test data";
        let crc = crc32(data);
        // CRC should be consistent
        assert_eq!(crc, crc32(data));
    }
    // Small-input path: exact expected (index, value) pairs.
    #[test]
    fn test_compute_topk() {
        let logits: Vec<i16> = vec![100, 50, 300, 200, 150];
        let topk = compute_topk(&logits, 3);
        assert_eq!(topk.len(), 3);
        assert_eq!(topk[0], (2, 300)); // Index 2, value 300
        assert_eq!(topk[1], (3, 200)); // Index 3, value 200
        assert_eq!(topk[2], (4, 150)); // Index 4, value 150
    }
    // Large-input path exercises the partial-selection branch.
    #[test]
    fn test_compute_topk_large() {
        let logits: Vec<i16> = (0..1000).map(|i| (i * 7 % 500) as i16).collect();
        let topk = compute_topk(&logits, 10);
        assert_eq!(topk.len(), 10);
        // Should be sorted descending
        for i in 1..topk.len() {
            assert!(topk[i - 1].1 >= topk[i].1);
        }
    }
    #[test]
    fn test_validate_tokens() {
        assert!(validate_tokens(&[0, 1, 2], 100).is_ok());
        assert!(validate_tokens(&[99], 100).is_ok());
        assert!(validate_tokens(&[100], 100).is_err());
        assert!(validate_tokens(&[0, 50, 101], 100).is_err());
    }
}

View File

@@ -0,0 +1,544 @@
//! Native Rust simulator backend
//!
//! Provides a pure-Rust implementation of the transformer inference
//! for testing, development, and fallback when no FPGA is available.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use crate::artifact::ModelArtifact;
use crate::backend::{
compute_topk, read_lock, validate_tokens, write_lock, BackendStats, TransformerBackend,
};
use crate::error::{Error, Result};
use crate::gating::CoherenceGate;
use crate::quant::{dequantize_i8, quantize_i16, softmax_lut};
use crate::types::{
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
QuantSpec, SkipReason, WitnessLog,
};
/// Loaded model data for native simulation
///
/// Weights are dequantized to f32 at load time (see `prepare_model`) so the
/// simulator can run plain floating-point math.
struct LoadedModel {
    /// Model artifact (contains weights and config)
    artifact: ModelArtifact,
    /// Precomputed embedding matrix (dequantized for sim),
    /// row-major vocab x d_model (lookup uses `token * d_model`)
    embeddings: Vec<f32>,
    /// Layer weights (simplified for simulation)
    layers: Vec<LayerWeights>,
    /// Output projection, indexed as d_model x vocab (`d * vocab + v`)
    output_proj: Vec<f32>,
}
/// Simplified layer weights for simulation
///
/// Only `wo`, `w1`, `w2` and the layer-norm weights are consumed by
/// `run_layer`; the Q/K/V projections are allocated but currently unused
/// by the simplified attention placeholder.
struct LayerWeights {
    /// Attention Q projection (unused by the simplified `run_layer`)
    wq: Vec<f32>,
    /// Attention K projection (unused by the simplified `run_layer`)
    wk: Vec<f32>,
    /// Attention V projection (unused by the simplified `run_layer`)
    wv: Vec<f32>,
    /// Attention output projection
    wo: Vec<f32>,
    /// FFN up projection
    w1: Vec<f32>,
    /// FFN down projection
    w2: Vec<f32>,
    /// Layer norm weights
    ln1_weight: Vec<f32>,
    /// Second layer-norm weights (applied before the FFN)
    ln2_weight: Vec<f32>,
}
/// Native simulator backend
///
/// Thread-safe: the model cache and statistics are each behind their own
/// `RwLock`, accessed via the `read_lock`/`write_lock` helpers so a
/// poisoned lock surfaces as an error rather than a panic.
pub struct NativeSimBackend {
    /// Loaded models, keyed by content-derived model id
    models: RwLock<HashMap<ModelId, Arc<LoadedModel>>>,
    /// Coherence gate (shared; consulted at preflight and per layer)
    gate: Arc<dyn CoherenceGate>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: NativeSimConfig,
}
/// Configuration for native simulator
#[derive(Debug, Clone)]
pub struct NativeSimConfig {
    /// Maximum models to keep loaded (load fails once reached)
    pub max_models: usize,
    /// Enable detailed tracing
    /// NOTE(review): not read anywhere in this backend yet — confirm intent.
    pub trace: bool,
    /// Use LUT-based softmax (`softmax_lut`) instead of `softmax_f32`
    pub lut_softmax: bool,
    /// Number of layers to simulate (0 = all)
    pub max_layers: usize,
}
impl Default for NativeSimConfig {
fn default() -> Self {
Self {
max_models: 8,
trace: false,
lut_softmax: true,
max_layers: 0,
}
}
}
impl NativeSimBackend {
    /// Create a new native simulator backend with the default configuration.
    pub fn new(gate: Arc<dyn CoherenceGate>) -> Self {
        Self::with_config(gate, NativeSimConfig::default())
    }
    /// Create a backend with an explicit configuration.
    pub fn with_config(gate: Arc<dyn CoherenceGate>, config: NativeSimConfig) -> Self {
        Self {
            models: RwLock::new(HashMap::new()),
            gate,
            stats: RwLock::new(BackendStats::default()),
            config,
        }
    }
    /// Run the core transformer inference.
    ///
    /// Returns the quantized logits together with the gate decision taken:
    /// skipped before any compute, early exit at a layer checkpoint, or a
    /// full run through all (possibly capped) layers.
    fn run_inference(
        &self,
        model: &LoadedModel,
        tokens: &[u16],
        _attn_mask: &[u8],
        gate_hint: &GateHint,
    ) -> Result<(Vec<i16>, GateDecision)> {
        let shape = &model.artifact.manifest.shape;
        let num_layers = model.layers.len();
        // Preflight gate: if blocked, return zero logits without computing.
        let preflight = self.gate.preflight(gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            return Ok((
                vec![0i16; shape.vocab as usize],
                GateDecision::Skipped { reason },
            ));
        }
        // Initialize hidden states from embeddings
        let d_model = shape.d_model as usize;
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];
        // Embedding lookup; rows that would read past the table stay zeroed.
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }
        // Run through layers (config.max_layers == 0 means "all layers")
        let max_layers = if self.config.max_layers > 0 {
            self.config.max_layers.min(num_layers)
        } else {
            num_layers
        };
        for layer_idx in 0..max_layers {
            let layer = &model.layers[layer_idx];
            // Check layer checkpoint for early exit
            let coherence_signal = self.compute_coherence_signal(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer_idx as u8, coherence_signal) {
                if let GateDecision::EarlyExit { layer } = decision {
                    // Early exit - compute output from current hidden state
                    let logits = self.compute_output(&hidden, &model.output_proj, shape);
                    return Ok((logits, GateDecision::EarlyExit { layer }));
                }
            }
            // Simplified attention + FFN (for simulation purposes)
            hidden = self.run_layer(&hidden, layer, shape);
        }
        // Compute output logits
        let logits = self.compute_output(&hidden, &model.output_proj, shape);
        Ok((logits, GateDecision::RanFull))
    }
    /// Run a single (simplified) transformer layer:
    /// LN1 -> placeholder attention (output projection only) with residual
    /// -> LN2 -> SiLU FFN with residual.
    fn run_layer(&self, hidden: &[f32], layer: &LayerWeights, shape: &FixedShape) -> Vec<f32> {
        let d_model = shape.d_model as usize;
        let seq_len = hidden.len() / d_model;
        // Simplified layer computation
        // In a real implementation, this would do full attention + FFN
        let mut output = hidden.to_vec();
        // Layer norm 1
        for t in 0..seq_len {
            let start = t * d_model;
            let end = start + d_model;
            layer_norm_inplace(&mut output[start..end], &layer.ln1_weight);
        }
        // Simplified attention (just apply output projection as placeholder)
        // Real implementation would compute Q, K, V, attention scores, etc.
        if !layer.wo.is_empty() {
            let mut attn_out = vec![0.0f32; output.len()];
            for t in 0..seq_len {
                for i in 0..d_model {
                    let mut sum = 0.0f32;
                    // Inner bound guards against a short `wo` matrix.
                    for j in 0..d_model.min(layer.wo.len() / d_model) {
                        sum += output[t * d_model + j] * layer.wo[j * d_model + i];
                    }
                    attn_out[t * d_model + i] = sum;
                }
            }
            // Residual connection
            for i in 0..output.len() {
                output[i] += attn_out[i];
            }
        }
        // Layer norm 2
        for t in 0..seq_len {
            let start = t * d_model;
            let end = start + d_model;
            layer_norm_inplace(&mut output[start..end], &layer.ln2_weight);
        }
        // Simplified FFN (SwiGLU-like)
        if !layer.w1.is_empty() && !layer.w2.is_empty() {
            let ffn_dim = layer.w1.len() / d_model;
            let mut ffn_out = vec![0.0f32; output.len()];
            for t in 0..seq_len {
                // Up projection
                let mut up = vec![0.0f32; ffn_dim];
                for i in 0..ffn_dim {
                    for j in 0..d_model {
                        up[i] += output[t * d_model + j] * layer.w1[j * ffn_dim + i];
                    }
                    // SiLU activation
                    up[i] = up[i] * sigmoid(up[i]);
                }
                // Down projection
                for i in 0..d_model {
                    for j in 0..ffn_dim.min(layer.w2.len() / d_model) {
                        ffn_out[t * d_model + i] += up[j] * layer.w2[j * d_model + i];
                    }
                }
            }
            // Residual connection
            for i in 0..output.len() {
                output[i] += ffn_out[i];
            }
        }
        output
    }
    /// Compute quantized output logits from the last token's hidden state.
    fn compute_output(&self, hidden: &[f32], output_proj: &[f32], shape: &FixedShape) -> Vec<i16> {
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        let seq_len = hidden.len() / d_model;
        // Take last token's hidden state. `saturating_sub` keeps the slice
        // start in range for an empty sequence (previously `seq_len - 1`
        // underflowed and panicked; the WASM simulator already guards this).
        let last_hidden = &hidden[(seq_len.saturating_sub(1)) * d_model..];
        // Compute logits
        let mut logits = vec![0.0f32; vocab];
        // Also require a full last_hidden row so the projection below cannot
        // index past the end when the sequence was empty.
        if output_proj.len() >= d_model * vocab && last_hidden.len() >= d_model {
            for v in 0..vocab {
                for d in 0..d_model {
                    logits[v] += last_hidden[d] * output_proj[d * vocab + v];
                }
            }
        } else {
            // Fallback: deterministic pseudo-logits for simulation when
            // weights (or input) are not available
            for v in 0..vocab {
                logits[v] = (v as f32 * 0.01).sin();
            }
        }
        // Apply softmax (optional LUT variant) and quantize
        if self.config.lut_softmax {
            softmax_lut(&mut logits);
        } else {
            softmax_f32(&mut logits);
        }
        // Quantize to i16
        quantize_i16(&logits)
    }
    /// Compute coherence signal for early exit decision: variance of the
    /// hidden state, scaled to Q8.8 fixed point.
    ///
    /// Returns 0 for an empty slice (avoids a 0/0 NaN; consistent with the
    /// WASM simulator's `compute_coherence`).
    fn compute_coherence_signal(&self, hidden: &[f32]) -> i16 {
        if hidden.is_empty() {
            return 0;
        }
        // Simple coherence metric: variance of hidden states
        let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
        let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden.len() as f32;
        // Scale to Q8.8 fixed point
        ((variance * 256.0).clamp(-32768.0, 32767.0)) as i16
    }
    /// Prepare model from artifact (dequantize weights for simulation).
    ///
    /// When the artifact carries fewer weight bytes than one embedding
    /// table, deterministic synthetic embeddings are generated so tests can
    /// still run; layer weights are always synthetic placeholders.
    fn prepare_model(&self, artifact: &ModelArtifact) -> Result<LoadedModel> {
        let shape = &artifact.manifest.shape;
        let quant = &artifact.manifest.quant;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        // Dequantize embeddings
        let embedding_size = vocab * d_model;
        let embeddings = if artifact.weights.len() >= embedding_size {
            dequantize_i8(&artifact.weights[..embedding_size], quant)
        } else {
            // Generate deterministic embeddings for testing
            (0..embedding_size)
                .map(|i| ((i as f32 * 0.001).sin() * 0.1))
                .collect()
        };
        // Create simplified layer weights
        let num_layers = 4; // Default for simulation
        let layers: Vec<LayerWeights> = (0..num_layers)
            .map(|_| LayerWeights {
                wq: vec![0.01; d_model * d_model],
                wk: vec![0.01; d_model * d_model],
                wv: vec![0.01; d_model * d_model],
                wo: vec![0.01; d_model * d_model],
                w1: vec![0.01; d_model * 4 * d_model],
                w2: vec![0.01; 4 * d_model * d_model],
                ln1_weight: vec![1.0; d_model],
                ln2_weight: vec![1.0; d_model],
            })
            .collect();
        // Output projection
        let output_proj = vec![0.01; d_model * vocab];
        Ok(LoadedModel {
            artifact: artifact.clone(),
            embeddings,
            layers,
            output_proj,
        })
    }
}
impl TransformerBackend for NativeSimBackend {
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
// Validate artifact
artifact.validate()?;
// Prepare model
let model = self.prepare_model(artifact)?;
let model_id = artifact.model_id();
// Check capacity (with poison handling)
let at_capacity = read_lock(&self.models, |models| {
models.len() >= self.config.max_models && !models.contains_key(&model_id)
})?;
if at_capacity {
return Err(Error::ResourceExhausted("Max models reached".into()));
}
// Store model
write_lock(&self.models, |models| {
models.insert(model_id, Arc::new(model));
})?;
// Update stats
write_lock(&self.stats, |stats| {
stats.models_loaded += 1;
})?;
Ok(model_id)
}
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
let start = Instant::now();
// Validate request
req.validate()?;
// Get model (with poison handling)
let model = read_lock(&self.models, |models| models.get(&req.model).cloned())?
.ok_or_else(|| Error::ModelNotFound(req.model))?;
// Validate shape
if model.artifact.manifest.shape != req.shape {
return Err(Error::ShapeMismatch {
expected: model.artifact.manifest.shape,
actual: req.shape,
});
}
// Validate tokens against vocabulary
validate_tokens(req.tokens, model.artifact.manifest.shape.vocab)?;
// Run inference
let (logits_q, gate_decision) =
self.run_inference(&model, req.tokens, req.attn_mask, &req.gate_hint)?;
let latency_ns = start.elapsed().as_nanos() as u32;
// Compute top-K using common utility
let topk = compute_topk(&logits_q, 16);
// Create witness
let witness = WitnessLog::new(
model.artifact.model_hash(),
model.artifact.quant_hash(),
BackendKind::NativeSim,
0, // No cycles for simulator
latency_ns,
gate_decision,
);
// Update stats (with poison handling)
write_lock(&self.stats, |stats| {
stats.total_inferences += 1;
let n = stats.total_inferences;
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
match gate_decision {
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
GateDecision::Skipped { .. } => stats.skipped += 1,
_ => {}
}
})?;
Ok(InferenceResult::new(logits_q, Some(topk), witness))
}
fn unload(&self, model: ModelId) -> Result<()> {
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
if removed {
write_lock(&self.stats, |stats| {
stats.models_loaded = stats.models_loaded.saturating_sub(1);
})?;
Ok(())
} else {
Err(Error::ModelNotFound(model))
}
}
fn is_loaded(&self, model: ModelId) -> bool {
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
}
fn kind(&self) -> BackendKind {
BackendKind::NativeSim
}
fn stats(&self) -> BackendStats {
read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
}
}
// Helper functions
/// Layer normalization in place: center `x` to zero mean, divide by the
/// (eps = 1e-5 stabilized) standard deviation, then scale element-wise by
/// `weight`. Positions past `weight.len()` use a scale of 1.0.
fn layer_norm_inplace(x: &mut [f32], weight: &[f32]) {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let std_dev = (variance + 1e-5).sqrt();
    for (v, idx) in x.iter_mut().zip(0..) {
        let scale = weight.get(idx).copied().unwrap_or(1.0);
        *v = (*v - mean) / std_dev * scale;
    }
}
/// Logistic sigmoid: 1 / (1 + e^-x).
fn sigmoid(x: f32) -> f32 {
    let e = (-x).exp();
    1.0 / (1.0 + e)
}
/// Numerically stable in-place softmax.
///
/// Subtracts the maximum before exponentiating. If the exponential sum is
/// not strictly positive (e.g. an empty slice), the raw exponentials are
/// left unnormalized, matching the original behavior.
fn softmax_f32(x: &mut [f32]) {
    let mut max = f32::NEG_INFINITY;
    for &v in x.iter() {
        max = max.max(v);
    }
    let mut total = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        total += *v;
    }
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;
    /// Build a minimal unsigned artifact with a micro shape and enough
    /// weight bytes for one embedding table
    /// (4096 * 64 — presumably vocab * d_model of `FixedShape::micro()`;
    /// confirm against the types module).
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64], // Minimal embedding weights
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }
    #[test]
    fn test_native_sim_load_unload() {
        // Load/unload round-trip: is_loaded must flip true -> false.
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();
        assert!(backend.is_loaded(model_id));
        backend.unload(model_id).unwrap();
        assert!(!backend.is_loaded(model_id));
    }
    #[test]
    fn test_native_sim_inference() {
        // End-to-end smoke test: logits, top-K and a NativeSim witness.
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();
        let tokens: Vec<u16> = (0..32).collect();
        let mask = vec![1u8; 32];
        let req = InferenceRequest::new(
            model_id,
            FixedShape::micro(),
            &tokens,
            &mask,
            GateHint::allow_all(),
        );
        let result = backend.infer(req).unwrap();
        assert!(!result.logits_q.is_empty());
        assert!(result.topk.is_some());
        assert_eq!(result.witness.backend, BackendKind::NativeSim);
    }
}

View File

@@ -0,0 +1,348 @@
//! WASM Simulator backend
//!
//! Pure Rust implementation that runs in WASM environments.
//! Uses RefCell for interior mutability since WASM is single-threaded.
#![cfg(feature = "wasm")]
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use crate::artifact::ModelArtifact;
use crate::backend::{compute_topk, validate_tokens, BackendStats, TransformerBackend};
use crate::error::{Error, Result};
use crate::gating::CoherenceGate;
use crate::quant::{dequantize_i8, quantize_i16};
use crate::types::{
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
QuantSpec, WitnessLog,
};
/// Loaded model for WASM simulation
struct WasmModel {
    /// Model artifact (kept for hashes used in the witness log)
    artifact: ModelArtifact,
    /// Prepacked embedding table (dequantized to f32 for computation),
    /// row-major vocab x d_model; also reused (tied) as the output head
    embeddings: Vec<f32>,
    /// Number of layers (derived in `prepare_model`, not from the shape)
    num_layers: usize,
    /// Shape info
    shape: FixedShape,
}
/// WASM simulator backend state (interior mutable for single-threaded WASM)
///
/// Kept in one struct so a single `RefCell` borrow covers both fields.
struct WasmState {
    /// Loaded models, keyed by content-derived model id
    models: HashMap<ModelId, WasmModel>,
    /// Statistics
    stats: BackendStats,
}
/// WASM simulator backend
///
/// Uses RefCell for interior mutability since WASM is inherently single-threaded.
/// This allows the TransformerBackend trait to be implemented with &self methods.
/// See the `unsafe impl Send/Sync` below for the threading caveat.
pub struct WasmSimBackend {
    /// Interior mutable state (models + stats)
    state: RefCell<WasmState>,
    /// Coherence gate (immutable, shared)
    gate: Rc<dyn CoherenceGate>,
}
impl WasmSimBackend {
    /// Create a new WASM simulator backend with empty state.
    pub fn new(gate: Rc<dyn CoherenceGate>) -> Self {
        Self {
            state: RefCell::new(WasmState {
                models: HashMap::new(),
                stats: BackendStats::default(),
            }),
            gate,
        }
    }
    /// Prepare model from artifact.
    ///
    /// Dequantizes the embedding table to f32; when the artifact carries
    /// fewer weight bytes than one full table, deterministic synthetic
    /// embeddings are generated instead (test/fallback path).
    fn prepare_model(&self, artifact: &ModelArtifact) -> Result<WasmModel> {
        let shape = artifact.manifest.shape;
        let quant = &artifact.manifest.quant;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        // Dequantize embeddings
        let embedding_size = vocab * d_model;
        let embeddings = if artifact.weights.len() >= embedding_size {
            dequantize_i8(&artifact.weights[..embedding_size], quant)
        } else {
            // Generate deterministic embeddings for testing
            (0..embedding_size)
                .map(|i| ((i as f32 * 0.001).sin() * 0.1))
                .collect()
        };
        // Determine number of layers from artifact or default
        // NOTE(review): layer count is inferred from the early_exit option
        // (6 vs 4) rather than carried explicitly in the manifest — confirm.
        let num_layers = if artifact.manifest.backend.options.early_exit {
            6
        } else {
            4
        };
        Ok(WasmModel {
            artifact: artifact.clone(),
            embeddings,
            num_layers,
            shape,
        })
    }
    /// Run inference for WASM.
    ///
    /// Returns quantized logits plus the gate decision taken (skipped
    /// before compute, early exit at a checkpoint, or a full run).
    fn run_inference(
        &self,
        model: &WasmModel,
        tokens: &[u16],
        gate_hint: &GateHint,
    ) -> (Vec<i16>, GateDecision) {
        let shape = &model.shape;
        // Check preflight; a blocked run returns zero logits immediately.
        let preflight = self.gate.preflight(gate_hint);
        if !preflight.did_run() {
            return (vec![0i16; shape.vocab as usize], preflight);
        }
        let vocab = shape.vocab as usize;
        let d_model = shape.d_model as usize;
        // Initialize hidden state from embeddings
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];
        // Lookup embeddings with bounds checking;
        // out-of-vocab tokens are clamped to the last vocabulary row.
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize).min(vocab.saturating_sub(1)) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }
        // Run through simplified layers with early exit support
        for layer in 0..model.num_layers {
            // Simple layer computation (for WASM we keep it lightweight)
            // Apply simple transformation
            for t in 0..seq_len {
                let start = t * d_model;
                // Simple ReLU-like activation
                // (leaky: positives keep ~full magnitude, negatives keep 1%)
                for i in 0..d_model {
                    hidden[start + i] =
                        hidden[start + i].max(0.0) * 0.99 + hidden[start + i] * 0.01;
                }
            }
            // Check for early exit.
            // Any Some(..) checkpoint decision ends the layer loop here.
            let coherence_signal = compute_coherence(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer as u8, coherence_signal) {
                let logits = self.compute_output(&hidden, model);
                return (logits, decision);
            }
        }
        // Compute output logits
        let logits = self.compute_output(&hidden, model);
        (logits, GateDecision::RanFull)
    }
    /// Compute output logits from hidden state.
    ///
    /// Weight tying: each logit is the dot product of the last token's
    /// hidden state with that vocabulary row of the embedding table, then
    /// softmaxed and quantized to i16.
    fn compute_output(&self, hidden: &[f32], model: &WasmModel) -> Vec<i16> {
        let shape = &model.shape;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        let seq_len = hidden.len() / d_model;
        // Take last token's hidden state (saturating_sub: safe when empty)
        let last_hidden = &hidden[(seq_len.saturating_sub(1)) * d_model..];
        // Compute logits via dot product with embedding matrix (transposed)
        let mut logits_f32 = vec![0.0f32; vocab];
        for v in 0..vocab.min(model.embeddings.len() / d_model) {
            let v_offset = v * d_model;
            let mut dot = 0.0f32;
            for d in 0..d_model.min(last_hidden.len()) {
                if v_offset + d < model.embeddings.len() {
                    dot += last_hidden[d] * model.embeddings[v_offset + d];
                }
            }
            logits_f32[v] = dot;
        }
        // Apply softmax and quantize
        softmax_inplace(&mut logits_f32);
        quantize_i16(&logits_f32)
    }
}
// Note: WASM is single-threaded, so these trait bounds are satisfied trivially
// by never actually being used across threads.
// SAFETY: the interior RefCell/Rc are never shared or sent between threads
// on a single-threaded wasm32 target; these impls exist only so the type
// can satisfy the trait object bounds used elsewhere in the crate.
// NOTE(review): unsound if this backend is ever compiled for a
// multi-threaded target — confirm the `wasm` feature gates this file.
unsafe impl Send for WasmSimBackend {}
unsafe impl Sync for WasmSimBackend {}
impl TransformerBackend for WasmSimBackend {
    /// Validate, prepare and register a model artifact.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact
        artifact.validate()?;
        // Prepare model
        let model = self.prepare_model(artifact)?;
        let model_id = artifact.model_id();
        // Store in state; a single mutable borrow covers insert + stats.
        let mut state = self.state.borrow_mut();
        state.models.insert(model_id, model);
        state.stats.models_loaded += 1;
        Ok(model_id)
    }
    /// Run one inference request.
    ///
    /// Borrow discipline: the immutable borrow used for compute is dropped
    /// before re-borrowing mutably for the stats update — reordering these
    /// borrows would panic the RefCell at runtime.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        // Wall clock via JS Date (f64 milliseconds since epoch).
        let start = js_sys::Date::now();
        // Validate request
        req.validate()?;
        // Get model (immutable borrow)
        let state = self.state.borrow();
        let model = state
            .models
            .get(&req.model)
            .ok_or_else(|| Error::ModelNotFound(req.model))?;
        // Validate tokens
        validate_tokens(req.tokens, model.shape.vocab)?;
        // Run inference
        let (logits, gate_decision) = self.run_inference(model, req.tokens, &req.gate_hint);
        // Convert elapsed milliseconds to nanoseconds.
        let latency_ns = ((js_sys::Date::now() - start) * 1_000_000.0) as u32;
        // Compute top-K
        let topk = compute_topk(&logits, 16);
        // Build witness
        let witness = WitnessLog::new(
            model.artifact.model_hash(),
            model.artifact.quant_hash(),
            BackendKind::WasmSim,
            0, // No cycles for WASM sim
            latency_ns,
            gate_decision,
        );
        drop(state); // Release borrow before mutable borrow
        // Update stats
        {
            let mut state = self.state.borrow_mut();
            state.stats.total_inferences += 1;
            let n = state.stats.total_inferences;
            // Incremental running average of latency.
            state.stats.avg_latency_ns =
                (state.stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
            match gate_decision {
                GateDecision::EarlyExit { .. } => state.stats.early_exits += 1,
                GateDecision::Skipped { .. } => state.stats.skipped += 1,
                _ => {}
            }
        }
        Ok(InferenceResult::new(logits, Some(topk), witness))
    }
    /// Remove a model; error if it was never loaded.
    fn unload(&self, model: ModelId) -> Result<()> {
        let mut state = self.state.borrow_mut();
        if state.models.remove(&model).is_some() {
            state.stats.models_loaded = state.stats.models_loaded.saturating_sub(1);
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }
    fn is_loaded(&self, model: ModelId) -> bool {
        self.state.borrow().models.contains_key(&model)
    }
    fn kind(&self) -> BackendKind {
        BackendKind::WasmSim
    }
    fn stats(&self) -> BackendStats {
        self.state.borrow().stats.clone()
    }
}
/// Coherence signal for early-exit gating: the variance of the hidden
/// state, scaled to Q8.8 fixed point. An empty slice yields 0.
fn compute_coherence(hidden: &[f32]) -> i16 {
    if hidden.is_empty() {
        return 0;
    }
    let n = hidden.len() as f32;
    let mean = hidden.iter().sum::<f32>() / n;
    let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    (variance * 256.0).clamp(-32768.0, 32767.0) as i16
}
/// In-place, numerically stable softmax.
///
/// Subtracts the maximum before exponentiating. Empty input is a no-op;
/// if the exponential sum is not strictly positive the raw exponentials
/// are left unnormalized.
fn softmax_inplace(x: &mut [f32]) {
    if x.is_empty() {
        return;
    }
    let mut max = f32::NEG_INFINITY;
    for &v in x.iter() {
        max = max.max(v);
    }
    let mut total = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        total += *v;
    }
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;
    /// Minimal unsigned artifact with a micro shape and enough weight
    /// bytes for one embedding table (4096 * 64 — presumably
    /// vocab * d_model of `FixedShape::micro()`; confirm in types module).
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "wasm_test".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64],
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }
    #[test]
    fn test_wasm_sim_prepare_model() {
        // prepare_model must carry the shape through and produce a
        // non-empty (dequantized or synthetic) embedding table.
        let gate = Rc::new(DefaultCoherenceGate::new());
        let backend = WasmSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model = backend.prepare_model(&artifact).unwrap();
        assert_eq!(model.shape.seq_len, 32);
        assert!(!model.embeddings.is_empty());
    }
}

View File

@@ -0,0 +1,136 @@
//! Error types for FPGA Transformer backend
use thiserror::Error;
/// Result type alias for FPGA Transformer operations
pub type Result<T> = std::result::Result<T, Error>;
/// FPGA Transformer error types
///
/// Grouped roughly as: artifact/validation errors, transport/backend
/// errors, gating and quantization errors, and generic wrappers (IO,
/// JSON). See the `is_recoverable`/`is_config_error` helpers below for
/// how callers classify them.
#[derive(Error, Debug)]
pub enum Error {
    /// Model artifact is invalid or corrupted
    #[error("Invalid artifact: {0}")]
    InvalidArtifact(String),
    /// Artifact signature verification failed
    #[error("Signature verification failed: {0}")]
    SignatureError(String),
    /// Test vectors failed validation
    /// (expected = max error allowed, actual = max error observed)
    #[error("Test vector validation failed: expected max error {expected}, got {actual}")]
    TestVectorError { expected: i32, actual: i32 },
    /// Model not found or not loaded
    #[error("Model not found: {0:?}")]
    ModelNotFound(crate::types::ModelId),
    /// Shape mismatch between request and model
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: crate::types::FixedShape,
        actual: crate::types::FixedShape,
    },
    /// Input length does not match expected sequence length
    #[error("Input length mismatch: expected {expected}, got {actual}")]
    InputLengthMismatch { expected: usize, actual: usize },
    /// Backend communication error (also used for poisoned locks)
    #[error("Backend error: {0}")]
    BackendError(String),
    /// Daemon connection failed
    #[error("Daemon connection failed: {0}")]
    DaemonConnectionError(String),
    /// PCIe communication error
    #[error("PCIe error: {0}")]
    PcieError(String),
    /// DMA operation failed
    #[error("DMA error: {0}")]
    DmaError(String),
    /// Gating decision blocked inference
    #[error("Inference blocked by gate: {reason:?}")]
    GateBlocked { reason: crate::types::SkipReason },
    /// Quantization error
    #[error("Quantization error: {0}")]
    QuantizationError(String),
    /// Overflow during fixed-point computation
    #[error("Fixed-point overflow at {location}")]
    FixedPointOverflow { location: &'static str },
    /// Invalid configuration (also used for out-of-vocab tokens)
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// IO error
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// JSON parsing error
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Checksum mismatch
    #[error("Checksum mismatch: expected {expected:08x}, got {actual:08x}")]
    ChecksumMismatch { expected: u32, actual: u32 },
    /// Protocol version mismatch
    #[error("Protocol version mismatch: expected {expected}, got {actual}")]
    ProtocolMismatch { expected: u16, actual: u16 },
    /// Timeout waiting for response
    #[error("Timeout after {ms}ms")]
    Timeout { ms: u64 },
    /// Resource exhausted (memory, slots, etc.)
    #[error("Resource exhausted: {0}")]
    ResourceExhausted(String),
    /// Feature not available in this build
    #[error("Feature not available: {0}")]
    FeatureNotAvailable(String),
}
impl Error {
/// Create a new InvalidArtifact error
pub fn invalid_artifact(msg: impl Into<String>) -> Self {
Self::InvalidArtifact(msg.into())
}
/// Create a new BackendError
pub fn backend(msg: impl Into<String>) -> Self {
Self::BackendError(msg.into())
}
/// Create a new DaemonConnectionError
pub fn daemon_connection(msg: impl Into<String>) -> Self {
Self::DaemonConnectionError(msg.into())
}
/// Check if this error is recoverable
pub fn is_recoverable(&self) -> bool {
matches!(
self,
Error::Timeout { .. }
| Error::DaemonConnectionError(_)
| Error::BackendError(_)
| Error::GateBlocked { .. }
)
}
/// Check if this error indicates a configuration problem
pub fn is_config_error(&self) -> bool {
matches!(
self,
Error::InvalidConfig(_)
| Error::ShapeMismatch { .. }
| Error::InputLengthMismatch { .. }
| Error::FeatureNotAvailable(_)
)
}
}

View File

@@ -0,0 +1,302 @@
//! C ABI bindings for FFI integration
//!
//! Provides a stable C interface for linking from other languages.
use std::ffi::{c_char, c_int, c_void, CStr};
use std::ptr;
use std::sync::Arc;
use crate::backend::native_sim::NativeSimBackend;
use crate::backend::TransformerBackend;
use crate::gating::DefaultCoherenceGate;
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
/// Opaque engine handle
///
/// Created by `fpga_engine_create` (Box::into_raw) and released by
/// `fpga_engine_destroy`; C callers only ever see it as an opaque pointer.
pub struct FpgaEngine {
    // Currently always a NativeSimBackend (see fpga_engine_create).
    backend: Box<dyn TransformerBackend>,
}
/// Result code
///
/// Stable numeric values exposed across the C ABI; do not renumber.
#[repr(C)]
pub enum FpgaResult {
    Ok = 0,
    InvalidArgument = 1,
    ModelNotFound = 2,
    InferenceFailed = 3,
    AllocationFailed = 4,
    InvalidArtifact = 5,
}
/// Inference result structure
///
/// The `logits` and `topk` buffers are allocated by `fpga_infer` and owned
/// by the caller until passed to `fpga_result_free`; either pointer may be
/// null (empty result or allocation failure).
#[repr(C)]
pub struct FpgaInferenceResult {
    /// Status code
    pub status: FpgaResult,
    /// Logits (caller must free with fpga_result_free)
    pub logits: *mut i16,
    /// Number of logits
    pub logits_len: usize,
    /// Top-K results, flattened (token_id, logit) pairs — 2 u32 per entry
    pub topk: *mut u32,
    /// Number of top-K pairs (the buffer holds 2 * topk_len u32 values)
    pub topk_len: usize,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Compute cycles
    pub cycles: u32,
    /// Gate decision (0=full, 1=early_exit, 2=skipped)
    pub gate_decision: u8,
    /// Exit layer (if early exit)
    pub exit_layer: u8,
}
/// Create a new FPGA engine with native simulator backend
///
/// Returns a handle that must be freed with `fpga_engine_destroy`
#[no_mangle]
pub extern "C" fn fpga_engine_create() -> *mut FpgaEngine {
let gate = Arc::new(DefaultCoherenceGate::new());
let backend = Box::new(NativeSimBackend::new(gate));
let engine = Box::new(FpgaEngine { backend });
Box::into_raw(engine)
}
/// Destroy an FPGA engine
///
/// Safe to call with a null pointer (no-op). Must not be called twice
/// with the same non-null handle.
#[no_mangle]
pub extern "C" fn fpga_engine_destroy(engine: *mut FpgaEngine) {
    if engine.is_null() {
        return;
    }
    // SAFETY: `engine` came from `fpga_engine_create` (Box::into_raw), so
    // reconstructing the Box drops and frees it exactly once.
    unsafe {
        drop(Box::from_raw(engine));
    }
}
/// Load a model artifact
///
/// On success writes the 32-byte model ID into `model_id_out` and returns
/// `Ok`; returns `InvalidArgument` for null pointers and `InvalidArtifact`
/// when unpacking or loading fails.
///
/// # Safety (caller contract)
///
/// `artifact_bytes` must point to `artifact_len` readable bytes and
/// `model_id_out` must point to at least 32 writable bytes.
#[no_mangle]
pub extern "C" fn fpga_load_artifact(
    engine: *mut FpgaEngine,
    artifact_bytes: *const u8,
    artifact_len: usize,
    model_id_out: *mut u8,
) -> FpgaResult {
    if engine.is_null() || artifact_bytes.is_null() || model_id_out.is_null() {
        return FpgaResult::InvalidArgument;
    }
    // SAFETY: pointers checked non-null above; caller guarantees lengths.
    let engine = unsafe { &mut *engine };
    let artifact_slice = unsafe { std::slice::from_raw_parts(artifact_bytes, artifact_len) };
    // Deserialize the artifact container before handing it to the backend.
    let artifact = match crate::artifact::unpack_artifact(artifact_slice) {
        Ok(a) => a,
        Err(_) => return FpgaResult::InvalidArtifact,
    };
    match engine.backend.load(&artifact) {
        Ok(model_id) => {
            // SAFETY: model_id_out checked non-null; caller provides 32 bytes.
            unsafe {
                ptr::copy_nonoverlapping(model_id.as_bytes().as_ptr(), model_id_out, 32);
            }
            FpgaResult::Ok
        }
        Err(_) => FpgaResult::InvalidArtifact,
    }
}
/// Run inference
///
/// Result must be freed with `fpga_result_free`. The C API always uses the
/// micro shape (see below). On any failure the returned struct has null
/// buffers and a non-Ok status.
///
/// # Safety (caller contract)
///
/// `model_id` must point to 32 readable bytes; `tokens` to `tokens_len`
/// u16 values; `mask` to `mask_len` bytes; `engine` must be a live handle
/// from `fpga_engine_create`.
#[no_mangle]
pub extern "C" fn fpga_infer(
    engine: *mut FpgaEngine,
    model_id: *const u8,
    tokens: *const u16,
    tokens_len: usize,
    mask: *const u8,
    mask_len: usize,
    coherence_score: i16,
    boundary_crossed: bool,
    max_compute_class: u8,
) -> FpgaInferenceResult {
    // Template for every failure path: null buffers, "skipped" decision.
    let error_result = || FpgaInferenceResult {
        status: FpgaResult::InvalidArgument,
        logits: ptr::null_mut(),
        logits_len: 0,
        topk: ptr::null_mut(),
        topk_len: 0,
        latency_ns: 0,
        cycles: 0,
        gate_decision: 2,
        exit_layer: 0,
    };
    if engine.is_null() || model_id.is_null() || tokens.is_null() || mask.is_null() {
        return error_result();
    }
    // SAFETY: pointers checked non-null above; caller guarantees lengths.
    let engine = unsafe { &mut *engine };
    // Parse model ID
    let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
    let mut id_bytes = [0u8; 32];
    id_bytes.copy_from_slice(id_slice);
    let model = ModelId::new(id_bytes);
    // Parse tokens and mask
    let tokens_slice = unsafe { std::slice::from_raw_parts(tokens, tokens_len) };
    let mask_slice = unsafe { std::slice::from_raw_parts(mask, mask_len) };
    // Build shape (micro for C API)
    let shape = FixedShape::micro();
    // Build gate hint; unknown compute class values fall back to Deliberative.
    let compute_class =
        ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
    let gate_hint = GateHint::new(coherence_score, boundary_crossed, compute_class);
    // Create request
    let req = InferenceRequest::new(model, shape, tokens_slice, mask_slice, gate_hint);
    // Run inference
    match engine.backend.infer(req) {
        Ok(result) => {
            // Allocate logits with checked allocation (prevents panic on overflow)
            let logits_len = result.logits_q.len();
            let logits = if logits_len > 0 {
                match std::alloc::Layout::array::<i16>(logits_len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut i16 };
                        if !ptr.is_null() {
                            // SAFETY: freshly allocated for logits_len i16s.
                            unsafe {
                                ptr::copy_nonoverlapping(result.logits_q.as_ptr(), ptr, logits_len);
                            }
                        }
                        ptr
                    }
                    _ => ptr::null_mut(), // Return null on allocation failure
                }
            } else {
                ptr::null_mut()
            };
            // Allocate top-K with checked allocation
            let (topk, topk_len) = if let Some(ref tk) = result.topk {
                let len = tk.len() * 2; // (token, logit) pairs
                match std::alloc::Layout::array::<u32>(len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
                        if !ptr.is_null() {
                            for (i, (token, logit)) in tk.iter().enumerate() {
                                // NOTE: `i16 as u32` sign-extends, so negative
                                // logits appear as two's-complement u32 values.
                                unsafe {
                                    *ptr.add(i * 2) = *token as u32;
                                    *ptr.add(i * 2 + 1) = *logit as u32;
                                }
                            }
                        }
                        (ptr, tk.len())
                    }
                    _ => (ptr::null_mut(), 0), // Return null on allocation failure
                }
            } else {
                (ptr::null_mut(), 0)
            };
            // Encode gate decision (0=full, 1=early_exit, 2=skipped)
            let (gate_decision, exit_layer) = match result.witness.gate_decision {
                crate::types::GateDecision::RanFull => (0, 0),
                crate::types::GateDecision::EarlyExit { layer } => (1, layer),
                crate::types::GateDecision::Skipped { .. } => (2, 0),
            };
            FpgaInferenceResult {
                status: FpgaResult::Ok,
                logits,
                logits_len,
                topk,
                topk_len,
                latency_ns: result.witness.latency_ns,
                cycles: result.witness.cycles,
                gate_decision,
                exit_layer,
            }
        }
        Err(_) => {
            let mut result = error_result();
            result.status = FpgaResult::InferenceFailed;
            result
        }
    }
}
/// Free inference result
///
/// Releases the `logits` and `topk` buffers allocated by `fpga_infer` and
/// nulls the pointers afterwards, so calling this twice on the same struct
/// is harmless. Passing a null `result` is a no-op.
///
/// NOTE(review): the `Layout` values here must stay in sync with the
/// allocation site in `fpga_infer` (`i16` x `logits_len` elements, `u32` x
/// `topk_len * 2` elements for the (token, logit) pairs) — confirm whenever
/// either side changes.
#[no_mangle]
pub extern "C" fn fpga_result_free(result: *mut FpgaInferenceResult) {
    if result.is_null() {
        return;
    }
    // SAFETY: null was rejected above; the caller guarantees `result` points
    // to a valid FpgaInferenceResult produced by this library.
    unsafe {
        let r = &mut *result;
        if !r.logits.is_null() && r.logits_len > 0 {
            std::alloc::dealloc(
                r.logits as *mut u8,
                std::alloc::Layout::array::<i16>(r.logits_len).unwrap(),
            );
            // Null out so repeated frees do not double-free.
            r.logits = ptr::null_mut();
        }
        if !r.topk.is_null() && r.topk_len > 0 {
            // `topk_len` counts pairs; the buffer holds 2 u32 per pair.
            std::alloc::dealloc(
                r.topk as *mut u8,
                std::alloc::Layout::array::<u32>(r.topk_len * 2).unwrap(),
            );
            r.topk = ptr::null_mut();
        }
    }
}
/// Unload a model
///
/// `model_id` must point to at least 32 readable bytes. Returns
/// `InvalidArgument` for null pointers and `ModelNotFound` when the backend
/// rejects the unload.
#[no_mangle]
pub extern "C" fn fpga_unload(engine: *mut FpgaEngine, model_id: *const u8) -> FpgaResult {
    // Reject null handles up front; the C caller owns both pointers.
    if engine.is_null() || model_id.is_null() {
        return FpgaResult::InvalidArgument;
    }
    // SAFETY: both pointers were null-checked above; the C ABI contract
    // requires `model_id` to reference 32 valid bytes.
    let eng = unsafe { &mut *engine };
    let mut raw_id = [0u8; 32];
    raw_id.copy_from_slice(unsafe { std::slice::from_raw_parts(model_id, 32) });
    if eng.backend.unload(ModelId::new(raw_id)).is_ok() {
        FpgaResult::Ok
    } else {
        FpgaResult::ModelNotFound
    }
}
/// Check if a model is loaded
///
/// Returns false for null pointers. `model_id` must point to at least 32
/// readable bytes.
#[no_mangle]
pub extern "C" fn fpga_is_loaded(engine: *const FpgaEngine, model_id: *const u8) -> bool {
    if engine.is_null() || model_id.is_null() {
        return false;
    }
    // SAFETY: pointers were null-checked; the C ABI contract requires
    // `model_id` to reference 32 valid bytes.
    let eng = unsafe { &*engine };
    let mut raw_id = [0u8; 32];
    raw_id.copy_from_slice(unsafe { std::slice::from_raw_parts(model_id, 32) });
    eng.backend.is_loaded(ModelId::new(raw_id))
}
/// Get version string
///
/// Returns a pointer to a static NUL-terminated version string. The caller
/// must not free or modify it.
#[no_mangle]
pub extern "C" fn fpga_version() -> *const c_char {
    // NUL-terminated so C callers can treat it as a plain C string.
    const VERSION_CSTR: &[u8] = b"0.1.0\0";
    VERSION_CSTR.as_ptr() as *const c_char
}

View File

@@ -0,0 +1,8 @@
//! Foreign function interfaces for FPGA Transformer
//!
//! Provides C ABI and WASM bindings.
#[cfg(feature = "wasm")]
pub mod wasm_bindgen;
pub mod c_abi;

View File

@@ -0,0 +1,282 @@
//! WASM bindings via wasm-bindgen
//!
//! Provides the same API shape for browser and Node.js environments.
#![cfg(feature = "wasm")]
use js_sys::{Array, Int16Array, Object, Reflect, Uint16Array, Uint8Array};
use wasm_bindgen::prelude::*;
use crate::artifact::{unpack_artifact, ModelArtifact};
use crate::backend::native_sim::{NativeSimBackend, NativeSimConfig};
use crate::backend::TransformerBackend;
use crate::gating::DefaultCoherenceGate;
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
use std::sync::Arc;
/// WASM Engine for transformer inference
///
/// Wraps the native simulator backend for browser/Node.js use. Tracks the
/// model IDs loaded through this engine and the witness of the most recent
/// inference (exposed via `getWitness`).
#[wasm_bindgen]
pub struct WasmEngine {
    // Simulator backend that executes inference requests.
    backend: NativeSimBackend,
    // IDs of artifacts loaded via `loadArtifact`; mirrored by `getLoadedModels`.
    loaded_models: Vec<ModelId>,
    // Witness of the most recent successful `infer`, if any.
    last_witness: Option<crate::types::WitnessLog>,
}
#[wasm_bindgen]
impl WasmEngine {
    /// Create a new WASM engine
    ///
    /// Uses the native simulator with a permissive configuration: up to 4
    /// models, tracing off, LUT softmax on, no layer cap.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        // Use permissive config for WASM
        let config = NativeSimConfig {
            max_models: 4,
            trace: false,
            lut_softmax: true,
            max_layers: 0,
        };
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::with_config(gate, config);
        Self {
            backend,
            loaded_models: Vec::new(),
            last_witness: None,
        }
    }
    /// Load a model artifact from bytes
    ///
    /// Returns the model ID as a Uint8Array on success
    ///
    /// Errors (unpack or backend load failure) are surfaced to JS as string
    /// `JsValue`s.
    #[wasm_bindgen(js_name = loadArtifact)]
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<Uint8Array, JsValue> {
        let artifact = unpack_artifact(artifact_bytes)
            .map_err(|e| JsValue::from_str(&format!("Failed to unpack artifact: {}", e)))?;
        let model_id = self
            .backend
            .load(&artifact)
            .map_err(|e| JsValue::from_str(&format!("Failed to load model: {}", e)))?;
        // Track the ID so getLoadedModels / unload can see it.
        self.loaded_models.push(model_id);
        // Return model ID as Uint8Array
        let id_array = Uint8Array::new_with_length(32);
        id_array.copy_from(model_id.as_bytes());
        Ok(id_array)
    }
    /// Run inference
    ///
    /// Returns an object with logits, topk, and witness
    ///
    /// * `model_id` - 32-byte ID returned by `loadArtifact`
    /// * `tokens` - input token IDs; length must equal the micro shape's seq_len
    /// * `mask` - attention mask bytes
    /// * `coherence_score_q` - coherence signal (Q8.8 fixed point)
    /// * `boundary_crossed` - whether a mincut boundary was crossed
    /// * `max_compute_class` - maximum allowed compute class (u8 encoding)
    #[wasm_bindgen]
    pub fn infer(
        &mut self,
        model_id: &[u8],
        tokens: &[u16],
        mask: &[u8],
        coherence_score_q: i16,
        boundary_crossed: bool,
        max_compute_class: u8,
    ) -> Result<JsValue, JsValue> {
        // Parse model ID
        if model_id.len() != 32 {
            return Err(JsValue::from_str("Model ID must be 32 bytes"));
        }
        let mut id_bytes = [0u8; 32];
        id_bytes.copy_from_slice(model_id);
        let model = ModelId::new(id_bytes);
        // Get shape from loaded model
        // For WASM, we use micro shape by default
        let shape = FixedShape::micro();
        // Validate input length
        // NOTE(review): `mask` length is not validated here — confirm the
        // backend tolerates a length mismatch.
        if tokens.len() != shape.seq_len as usize {
            return Err(JsValue::from_str(&format!(
                "Token length mismatch: expected {}, got {}",
                shape.seq_len,
                tokens.len()
            )));
        }
        // Build gate hint
        // Unknown class bytes fall back to Deliberative.
        let compute_class =
            ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
        let gate_hint = GateHint::new(coherence_score_q, boundary_crossed, compute_class);
        // Create request
        let req = InferenceRequest::new(model, shape, tokens, mask, gate_hint);
        // Run inference
        let result = self
            .backend
            .infer(req)
            .map_err(|e| JsValue::from_str(&format!("Inference failed: {}", e)))?;
        // Store witness
        self.last_witness = Some(result.witness.clone());
        // Build result object
        let obj = Object::new();
        // Add logits
        let logits = Int16Array::new_with_length(result.logits_q.len() as u32);
        logits.copy_from(&result.logits_q);
        Reflect::set(&obj, &"logits".into(), &logits)?;
        // Add top-K if available
        if let Some(topk) = &result.topk {
            let topk_array = Array::new();
            for (token, logit) in topk {
                // Each entry becomes a two-element [token, logit] JS array.
                let pair = Array::new();
                pair.push(&JsValue::from(*token));
                pair.push(&JsValue::from(*logit));
                topk_array.push(&pair);
            }
            Reflect::set(&obj, &"topk".into(), &topk_array)?;
        }
        // Add witness info
        // Enum fields are exported as their Debug strings.
        let witness = Object::new();
        Reflect::set(
            &witness,
            &"backend".into(),
            &format!("{:?}", result.witness.backend).into(),
        )?;
        Reflect::set(
            &witness,
            &"cycles".into(),
            &JsValue::from(result.witness.cycles),
        )?;
        Reflect::set(
            &witness,
            &"latency_ns".into(),
            &JsValue::from(result.witness.latency_ns),
        )?;
        Reflect::set(
            &witness,
            &"gate_decision".into(),
            &format!("{:?}", result.witness.gate_decision).into(),
        )?;
        Reflect::set(&obj, &"witness".into(), &witness)?;
        Ok(obj.into())
    }
    /// Get the last witness log as JSON
    ///
    /// Returns a JSON string, or `null` when no inference has run yet.
    #[wasm_bindgen(js_name = getWitness)]
    pub fn get_witness(&self) -> Result<JsValue, JsValue> {
        match &self.last_witness {
            Some(witness) => {
                let json = serde_json::to_string(witness)
                    .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))?;
                Ok(JsValue::from_str(&json))
            }
            None => Ok(JsValue::NULL),
        }
    }
    /// Get list of loaded model IDs
    ///
    /// Returns an Array of 32-byte Uint8Arrays, one per loaded model.
    #[wasm_bindgen(js_name = getLoadedModels)]
    pub fn get_loaded_models(&self) -> Array {
        let arr = Array::new();
        for id in &self.loaded_models {
            let id_array = Uint8Array::new_with_length(32);
            id_array.copy_from(id.as_bytes());
            arr.push(&id_array);
        }
        arr
    }
    /// Unload a model
    ///
    /// Removes it from both the backend and this engine's tracking list.
    #[wasm_bindgen]
    pub fn unload(&mut self, model_id: &[u8]) -> Result<(), JsValue> {
        if model_id.len() != 32 {
            return Err(JsValue::from_str("Model ID must be 32 bytes"));
        }
        let mut id_bytes = [0u8; 32];
        id_bytes.copy_from_slice(model_id);
        let model = ModelId::new(id_bytes);
        self.backend
            .unload(model)
            .map_err(|e| JsValue::from_str(&format!("Unload failed: {}", e)))?;
        self.loaded_models.retain(|id| *id != model);
        Ok(())
    }
    /// Get backend statistics
    ///
    /// Returns `{ models_loaded, total_inferences, avg_latency_ns,
    /// early_exits, skipped }` as a JS object.
    #[wasm_bindgen(js_name = getStats)]
    pub fn get_stats(&self) -> Result<JsValue, JsValue> {
        let stats = self.backend.stats();
        let obj = Object::new();
        Reflect::set(
            &obj,
            &"models_loaded".into(),
            &JsValue::from(stats.models_loaded as u32),
        )?;
        // Wide counters are cast to f64 so they fit a plain JS number.
        Reflect::set(
            &obj,
            &"total_inferences".into(),
            &JsValue::from(stats.total_inferences as f64),
        )?;
        Reflect::set(
            &obj,
            &"avg_latency_ns".into(),
            &JsValue::from(stats.avg_latency_ns as f64),
        )?;
        Reflect::set(
            &obj,
            &"early_exits".into(),
            &JsValue::from(stats.early_exits as f64),
        )?;
        Reflect::set(
            &obj,
            &"skipped".into(),
            &JsValue::from(stats.skipped as f64),
        )?;
        Ok(obj.into())
    }
}
impl Default for WasmEngine {
fn default() -> Self {
Self::new()
}
}
/// Utility function to create a micro shape configuration
///
/// Returns a JS object `{ seq_len, d_model, heads, d_head, vocab }`
/// describing the fixed micro shape.
#[wasm_bindgen(js_name = microShape)]
pub fn micro_shape() -> Result<JsValue, JsValue> {
    let s = FixedShape::micro();
    let obj = Object::new();
    // Mirror every field of the fixed shape onto the JS object.
    Reflect::set(&obj, &JsValue::from_str("seq_len"), &JsValue::from(s.seq_len))?;
    Reflect::set(&obj, &JsValue::from_str("d_model"), &JsValue::from(s.d_model))?;
    Reflect::set(&obj, &JsValue::from_str("heads"), &JsValue::from(s.heads))?;
    Reflect::set(&obj, &JsValue::from_str("d_head"), &JsValue::from(s.d_head))?;
    Reflect::set(&obj, &JsValue::from_str("vocab"), &JsValue::from(s.vocab))?;
    Ok(JsValue::from(obj))
}
/// Utility function to validate an artifact without loading
///
/// Unpacks and validates the artifact bytes; on success returns
/// `{ name, valid: true }`, otherwise an error string `JsValue`.
#[wasm_bindgen(js_name = validateArtifact)]
pub fn validate_artifact(artifact_bytes: &[u8]) -> Result<JsValue, JsValue> {
    let parsed = unpack_artifact(artifact_bytes)
        .map_err(|e| JsValue::from_str(&format!("Invalid artifact: {}", e)))?;
    if let Err(e) = parsed.validate() {
        return Err(JsValue::from_str(&format!("Validation failed: {}", e)));
    }
    let summary = Object::new();
    Reflect::set(&summary, &JsValue::from_str("name"), &parsed.manifest.name.into())?;
    Reflect::set(&summary, &JsValue::from_str("valid"), &JsValue::TRUE)?;
    Ok(JsValue::from(summary))
}

View File

@@ -0,0 +1,301 @@
//! Coherence-based gating for inference control
use crate::types::{ComputeClass, GateDecision, GateHint, SkipReason};
use crate::witness::WitnessLog;
/// Trait for coherence-based gating
///
/// Implementors decide, from coherence signals, whether an inference may
/// run, whether it may exit early at a layer boundary, and whether its
/// result may be persisted. `Send + Sync` so a gate can be shared across
/// threads/backends.
pub trait CoherenceGate: Send + Sync {
    /// Preflight check before inference
    ///
    /// Returns a gate decision based on coherence signals.
    fn preflight(&self, hint: &GateHint) -> GateDecision;
    /// Layer checkpoint for early exit decisions
    ///
    /// Called after each layer to determine if early exit is appropriate.
    /// Returns Some(decision) to exit early, None to continue.
    ///
    /// `signal_q` is a fixed-point coherence signal (Q8.8 in this module's
    /// configurations).
    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision>;
    /// Check if write is allowed based on witness
    ///
    /// Used to gate state changes in memory systems.
    fn allow_write(&self, witness: &WitnessLog) -> bool;
}
/// Configuration for coherence gate
///
/// All thresholds are Q8.8 fixed point (256 == 1.0). See `strict()` and
/// `permissive()` for preset profiles.
#[derive(Debug, Clone)]
pub struct CoherenceConfig {
    /// Minimum coherence score to run (Q8.8)
    pub min_coherence: i16,
    /// Coherence threshold for early exit
    pub early_exit_threshold: i16,
    /// Enable early exit
    pub early_exit_enabled: bool,
    /// Minimum layers before early exit
    pub min_layers: u8,
    /// Require stable coherence for writes
    ///
    /// When set, only full (non-early-exit) runs may write.
    pub require_stable_for_write: bool,
    /// Minimum coherence for writes (Q8.8)
    pub min_write_coherence: i16,
}
impl Default for CoherenceConfig {
fn default() -> Self {
Self {
min_coherence: -256, // -1.0 in Q8.8, very permissive
early_exit_threshold: 512, // 2.0 in Q8.8
early_exit_enabled: true,
min_layers: 2,
require_stable_for_write: true,
min_write_coherence: 0, // Require non-negative coherence
}
}
}
impl CoherenceConfig {
/// Create a strict configuration
pub fn strict() -> Self {
Self {
min_coherence: 0,
early_exit_threshold: 256,
early_exit_enabled: true,
min_layers: 4,
require_stable_for_write: true,
min_write_coherence: 128,
}
}
/// Create a permissive configuration (always allows)
pub fn permissive() -> Self {
Self {
min_coherence: i16::MIN,
early_exit_threshold: i16::MAX,
early_exit_enabled: false,
min_layers: 0,
require_stable_for_write: false,
min_write_coherence: i16::MIN,
}
}
}
/// Default coherence gate implementation
///
/// Applies [`CoherenceConfig`] thresholds directly: preflight compares the
/// hint's coherence against `min_coherence`; checkpoints compare the layer
/// signal against `early_exit_threshold`.
pub struct DefaultCoherenceGate {
    // Thresholds driving every gate decision.
    config: CoherenceConfig,
}
impl DefaultCoherenceGate {
/// Create with default config
pub fn new() -> Self {
Self::with_config(CoherenceConfig::default())
}
/// Create with custom config
pub fn with_config(config: CoherenceConfig) -> Self {
Self { config }
}
/// Check if compute class allows operation
fn check_compute_class(&self, hint: &GateHint) -> bool {
// Reflex class can always run (fast path)
// Higher classes require sufficient coherence
match hint.max_compute_class {
ComputeClass::Reflex => true,
ComputeClass::Associative => hint.coherence_score_q >= self.config.min_coherence / 2,
ComputeClass::Deliberative => hint.coherence_score_q >= self.config.min_coherence,
}
}
}
impl Default for DefaultCoherenceGate {
fn default() -> Self {
Self::new()
}
}
impl CoherenceGate for DefaultCoherenceGate {
    /// Skip on low coherence, then on compute-class budget; otherwise run.
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        let coherent_enough = hint.coherence_score_q >= self.config.min_coherence;
        if !coherent_enough {
            return GateDecision::Skipped {
                reason: SkipReason::LowCoherence,
            };
        }
        if self.check_compute_class(hint) {
            GateDecision::RanFull
        } else {
            GateDecision::Skipped {
                reason: SkipReason::BudgetExceeded,
            }
        }
    }
    /// Early exit requires the feature enabled, the minimum depth reached,
    /// and a signal at or above the configured threshold.
    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        if self.config.early_exit_enabled
            && layer >= self.config.min_layers
            && signal_q >= self.config.early_exit_threshold
        {
            Some(GateDecision::EarlyExit { layer })
        } else {
            None
        }
    }
    /// Never persist from a skipped inference; optionally require a full
    /// (non-early-exit) run.
    fn allow_write(&self, witness: &WitnessLog) -> bool {
        if !witness.gate_decision.did_run() {
            return false;
        }
        !self.config.require_stable_for_write
            || matches!(witness.gate_decision, GateDecision::RanFull)
    }
}
/// Mincut-aware coherence gate
///
/// Uses mincut signals to make more informed gating decisions.
/// Wraps [`DefaultCoherenceGate`] and additionally skips inference when a
/// boundary crossing coincides with negative coherence.
pub struct MincutCoherenceGate {
    // Base gate providing the coherence checks and write policy.
    base: DefaultCoherenceGate,
    /// Minimum lambda (mincut value) for inference
    ///
    /// NOTE(review): not consulted by the current trait implementation —
    /// confirm whether preflight should compare against it.
    pub min_lambda: i16,
    /// Lambda threshold for early exit
    pub lambda_exit_threshold: i16,
}
impl MincutCoherenceGate {
/// Create a new mincut-aware gate
pub fn new(config: CoherenceConfig, min_lambda: i16, lambda_exit_threshold: i16) -> Self {
Self {
base: DefaultCoherenceGate::with_config(config),
min_lambda,
lambda_exit_threshold,
}
}
}
impl CoherenceGate for MincutCoherenceGate {
    /// Preflight: base coherence check, plus a mincut rule that skips when
    /// a boundary was crossed while coherence is negative.
    ///
    /// NOTE(review): `min_lambda` is currently unused here — confirm
    /// whether it should also gate preflight.
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        // Use base coherence check
        let base_decision = self.base.preflight(hint);
        if !base_decision.did_run() {
            return base_decision;
        }
        // Additional mincut check
        // If boundary was crossed and coherence is low, skip
        if hint.boundary_crossed && hint.coherence_score_q < 0 {
            return GateDecision::Skipped {
                reason: SkipReason::LowCoherence,
            };
        }
        GateDecision::RanFull
    }
    /// Checkpoint with a mincut-adjusted threshold: a signal above
    /// `lambda_exit_threshold` suggests a stable state, so the exit
    /// threshold is halved.
    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        // Bug fix: honor the early-exit switch exactly like the base gate.
        // Previously a config with `early_exit_enabled = false` (e.g.
        // CoherenceConfig::permissive()) could still early-exit here via
        // the halved threshold below.
        if !self.base.config.early_exit_enabled {
            return None;
        }
        // Use base checkpoint with mincut-adjusted threshold
        let adjusted_threshold = if signal_q > self.lambda_exit_threshold {
            // High lambda suggests stable state, lower exit threshold
            self.base.config.early_exit_threshold / 2
        } else {
            self.base.config.early_exit_threshold
        };
        if layer >= self.base.config.min_layers && signal_q >= adjusted_threshold {
            return Some(GateDecision::EarlyExit { layer });
        }
        None
    }
    /// Writes are delegated to the base gate's policy.
    fn allow_write(&self, witness: &WitnessLog) -> bool {
        // Use base write check
        self.base.allow_write(witness)
    }
}
// Unit tests for the default coherence gate and its preset configurations.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_default_gate_preflight() {
        let gate = DefaultCoherenceGate::new();
        // High coherence should pass
        let hint = GateHint::new(256, false, ComputeClass::Deliberative);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
        // Low coherence should fail
        let hint = GateHint::new(-512, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }
    #[test]
    fn test_early_exit_checkpoint() {
        let gate = DefaultCoherenceGate::new();
        // Layer 0 - too early
        assert!(gate.checkpoint(0, 1000).is_none());
        // Layer 4 with high signal - should exit
        let decision = gate.checkpoint(4, 1000);
        assert!(matches!(
            decision,
            Some(GateDecision::EarlyExit { layer: 4 })
        ));
    }
    #[test]
    fn test_write_gating() {
        let gate = DefaultCoherenceGate::new();
        // Full run should allow writes
        // (relies on WitnessLog::empty() defaulting to a RanFull decision —
        // confirm if WitnessLog changes)
        let witness = crate::witness::WitnessLog::empty();
        assert!(gate.allow_write(&witness));
        // Skipped should not allow writes
        let mut skipped_witness = crate::witness::WitnessLog::empty();
        skipped_witness.gate_decision = GateDecision::Skipped {
            reason: SkipReason::LowCoherence,
        };
        assert!(!gate.allow_write(&skipped_witness));
    }
    #[test]
    fn test_strict_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::strict());
        // Strict should require positive coherence
        let hint = GateHint::new(-1, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }
    #[test]
    fn test_permissive_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::permissive());
        // Permissive should allow anything
        let hint = GateHint::new(i16::MIN, true, ComputeClass::Reflex);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
    }
}

View File

@@ -0,0 +1,95 @@
//! Gating subsystem for coherence-based inference control
//!
//! Provides preflight and postflight gates that integrate mincut signals
//! and write policies for memory safety.
pub mod coherence_gate;
pub mod policy_gate;
pub use coherence_gate::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate};
pub use policy_gate::{DefaultPolicyGate, PolicyGate, WritePolicy};
use crate::types::{GateDecision, GateHint, SkipReason};
use crate::witness::WitnessLog;
/// Combined gate that checks both coherence and policy
///
/// Preflight consults the policy gate first, then the coherence gate;
/// writes require approval from both.
pub struct CombinedGate {
    // Coherence-signal gate: preflight, per-layer checkpoints, write check.
    coherence: Box<dyn CoherenceGate>,
    // Policy gate: inference admission and write policy.
    policy: Box<dyn PolicyGate>,
}
impl CombinedGate {
/// Create a new combined gate
pub fn new(coherence: Box<dyn CoherenceGate>, policy: Box<dyn PolicyGate>) -> Self {
Self { coherence, policy }
}
/// Create with default implementations
pub fn default_gates() -> Self {
Self::new(
Box::new(DefaultCoherenceGate::new()),
Box::new(DefaultPolicyGate::new()),
)
}
/// Preflight check before inference
pub fn preflight(&self, hint: &GateHint) -> GateDecision {
// First check policy
if !self.policy.allow_inference(hint) {
return GateDecision::Skipped {
reason: SkipReason::PolicyDenied,
};
}
// Then check coherence
self.coherence.preflight(hint)
}
/// Checkpoint during inference
pub fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
self.coherence.checkpoint(layer, signal_q)
}
/// Check if write is allowed after inference
pub fn allow_write(&self, witness: &WitnessLog) -> bool {
self.coherence.allow_write(witness) && self.policy.allow_write(witness)
}
}
impl CoherenceGate for CombinedGate {
    // The trait surface forwards to the inherent methods; the fully
    // qualified paths ensure we never recurse into the trait impl itself.
    fn preflight(&self, gate_hint: &GateHint) -> GateDecision {
        CombinedGate::preflight(self, gate_hint)
    }
    fn checkpoint(&self, depth: u8, coherence_q: i16) -> Option<GateDecision> {
        CombinedGate::checkpoint(self, depth, coherence_q)
    }
    fn allow_write(&self, log: &WitnessLog) -> bool {
        CombinedGate::allow_write(self, log)
    }
}
// Unit tests for the combined gate's preflight path.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_combined_gate_preflight() {
        let gate = CombinedGate::default_gates();
        // Allow all hint should pass
        let decision = gate.preflight(&GateHint::allow_all());
        assert!(matches!(decision, GateDecision::RanFull));
    }
    #[test]
    fn test_combined_gate_low_coherence() {
        let gate = CombinedGate::default_gates();
        // Very low coherence should skip
        let hint = GateHint::new(-1000, false, crate::types::ComputeClass::Reflex);
        let decision = gate.preflight(&hint);
        assert!(matches!(decision, GateDecision::Skipped { .. }));
    }
}

View File

@@ -0,0 +1,305 @@
//! Policy-based gating for access control and resource management
use crate::types::{ComputeClass, GateHint};
use crate::witness::WitnessLog;
/// Trait for policy-based gating
///
/// Governs access control and resource use around inference: whether an
/// inference may start, whether its result may be written back, and how
/// much compute budget remains. Must be `Send + Sync`.
pub trait PolicyGate: Send + Sync {
    /// Check if inference is allowed
    fn allow_inference(&self, hint: &GateHint) -> bool;
    /// Check if write is allowed after inference
    fn allow_write(&self, witness: &WitnessLog) -> bool;
    /// Get remaining compute budget
    ///
    /// `None` means the budget is unlimited.
    fn remaining_budget(&self) -> Option<u64>;
    /// Record compute usage
    ///
    /// Adds `cycles` to the running usage counter.
    fn record_usage(&self, cycles: u32);
}
/// Write policy configuration
///
/// Decides which completed inferences are eligible to write state back.
/// See `strict()` / `permissive()` for presets.
#[derive(Debug, Clone)]
pub struct WritePolicy {
    /// Allow writes after early exit
    pub allow_early_exit_writes: bool,
    /// Maximum latency (ns) for write eligibility
    pub max_latency_ns: u32,
    /// Require specific backend
    ///
    /// `None` accepts any backend.
    pub required_backend: Option<crate::types::BackendKind>,
    /// Minimum compute class for writes
    ///
    /// NOTE(review): stored but not checked by `DefaultPolicyGate::allow_write`
    /// — confirm intended.
    pub min_compute_class: ComputeClass,
}
impl Default for WritePolicy {
fn default() -> Self {
Self {
allow_early_exit_writes: false,
max_latency_ns: u32::MAX,
required_backend: None,
min_compute_class: ComputeClass::Reflex,
}
}
}
impl WritePolicy {
/// Create a strict write policy
pub fn strict() -> Self {
Self {
allow_early_exit_writes: false,
max_latency_ns: 10_000_000, // 10ms
required_backend: None,
min_compute_class: ComputeClass::Deliberative,
}
}
/// Create a permissive write policy
pub fn permissive() -> Self {
Self {
allow_early_exit_writes: true,
max_latency_ns: u32::MAX,
required_backend: None,
min_compute_class: ComputeClass::Reflex,
}
}
/// Require FPGA backend for writes
pub fn require_fpga(mut self) -> Self {
self.required_backend = Some(crate::types::BackendKind::FpgaPcie);
self
}
}
/// Default policy gate implementation
///
/// Tracks an optional cycle budget with atomic counters; a budget of zero
/// means unlimited.
pub struct DefaultPolicyGate {
    // Policy applied by `allow_write`.
    write_policy: WritePolicy,
    /// Compute budget (total cycles allowed, 0 = unlimited)
    budget_cycles: std::sync::atomic::AtomicU64,
    /// Used cycles
    used_cycles: std::sync::atomic::AtomicU64,
}
impl DefaultPolicyGate {
/// Create with default policy
pub fn new() -> Self {
Self::with_policy(WritePolicy::default())
}
/// Create with custom write policy
pub fn with_policy(write_policy: WritePolicy) -> Self {
Self {
write_policy,
budget_cycles: std::sync::atomic::AtomicU64::new(0),
used_cycles: std::sync::atomic::AtomicU64::new(0),
}
}
/// Set compute budget
pub fn set_budget(&self, cycles: u64) {
self.budget_cycles
.store(cycles, std::sync::atomic::Ordering::SeqCst);
}
/// Reset used cycles
pub fn reset_usage(&self) {
self.used_cycles
.store(0, std::sync::atomic::Ordering::SeqCst);
}
}
impl Default for DefaultPolicyGate {
fn default() -> Self {
Self::new()
}
}
impl PolicyGate for DefaultPolicyGate {
    /// Deny once the (non-zero) cycle budget is exhausted; otherwise allow
    /// any compute class (Reflex is the lowest).
    fn allow_inference(&self, hint: &GateHint) -> bool {
        use std::sync::atomic::Ordering;
        // A zero budget means unlimited; otherwise deny once exhausted.
        let budget = self.budget_cycles.load(Ordering::SeqCst);
        if budget > 0 && self.used_cycles.load(Ordering::SeqCst) >= budget {
            return false;
        }
        hint.max_compute_class >= ComputeClass::Reflex
    }
    /// Apply the write policy: the run must have happened, early exits may
    /// be excluded, latency must be under the cap, and the backend must
    /// match when one is required.
    fn allow_write(&self, witness: &WitnessLog) -> bool {
        if !witness.gate_decision.did_run() {
            return false;
        }
        let early_exit = matches!(
            witness.gate_decision,
            crate::types::GateDecision::EarlyExit { .. }
        );
        if early_exit && !self.write_policy.allow_early_exit_writes {
            return false;
        }
        if witness.latency_ns > self.write_policy.max_latency_ns {
            return false;
        }
        match self.write_policy.required_backend {
            Some(required) => witness.backend == required,
            None => true,
        }
    }
    /// Remaining cycles, or `None` when budgeting is disabled (budget 0).
    fn remaining_budget(&self) -> Option<u64> {
        use std::sync::atomic::Ordering;
        let budget = self.budget_cycles.load(Ordering::SeqCst);
        if budget == 0 {
            None
        } else {
            Some(budget.saturating_sub(self.used_cycles.load(Ordering::SeqCst)))
        }
    }
    /// Accumulate used cycles.
    fn record_usage(&self, cycles: u32) {
        use std::sync::atomic::Ordering;
        self.used_cycles.fetch_add(u64::from(cycles), Ordering::SeqCst);
    }
}
/// Rate-limited policy gate
///
/// Wraps [`DefaultPolicyGate`] with a fixed one-second window limiting how
/// many inferences may start per second.
pub struct RateLimitedPolicyGate {
    // Underlying budget / write-policy gate.
    base: DefaultPolicyGate,
    /// Maximum inferences per second
    max_inferences_per_sec: u32,
    /// Inference count in current window
    inference_count: std::sync::atomic::AtomicU32,
    /// Window start time
    window_start: std::sync::RwLock<std::time::Instant>,
}
impl RateLimitedPolicyGate {
    /// Create with rate limit
    ///
    /// `max_inferences_per_sec` caps admissions per one-second window;
    /// `write_policy` is handed to the wrapped [`DefaultPolicyGate`].
    pub fn new(max_inferences_per_sec: u32, write_policy: WritePolicy) -> Self {
        Self {
            base: DefaultPolicyGate::with_policy(write_policy),
            max_inferences_per_sec,
            inference_count: std::sync::atomic::AtomicU32::new(0),
            window_start: std::sync::RwLock::new(std::time::Instant::now()),
        }
    }
    /// Check and update rate limit
    ///
    /// Returns true when this call fits in the current window's budget.
    /// The counter is consumed (incremented) even for denied calls.
    fn check_rate_limit(&self) -> bool {
        let now = std::time::Instant::now();
        // Check if we need to reset the window
        {
            let window_start = self.window_start.read().unwrap();
            if now.duration_since(*window_start).as_secs() >= 1 {
                drop(window_start);
                let mut window_start = self.window_start.write().unwrap();
                // Bug fix: re-check after upgrading to the write lock.
                // Without this, two racing threads could both observe an
                // expired window; the second reset would zero the count of
                // an already-fresh window, admitting more than
                // `max_inferences_per_sec` in one second.
                if now.duration_since(*window_start).as_secs() >= 1 {
                    *window_start = now;
                    self.inference_count
                        .store(0, std::sync::atomic::Ordering::SeqCst);
                }
            }
        }
        // Check count
        let count = self
            .inference_count
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        count < self.max_inferences_per_sec
    }
}
impl PolicyGate for RateLimitedPolicyGate {
    /// Rate limit is consulted first (and consumes a slot), then the base
    /// gate's budget check.
    fn allow_inference(&self, hint: &GateHint) -> bool {
        self.check_rate_limit() && self.base.allow_inference(hint)
    }
    fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.base.allow_write(witness)
    }
    fn remaining_budget(&self) -> Option<u64> {
        self.base.remaining_budget()
    }
    fn record_usage(&self, cycles: u32) {
        self.base.record_usage(cycles)
    }
}
// Unit tests for the default policy gate: admission, budgeting, and write policy.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_default_policy_allows_inference() {
        let gate = DefaultPolicyGate::new();
        let hint = GateHint::allow_all();
        assert!(gate.allow_inference(&hint));
    }
    #[test]
    fn test_budget_limiting() {
        let gate = DefaultPolicyGate::new();
        gate.set_budget(1000);
        let hint = GateHint::allow_all();
        // Should allow initially
        assert!(gate.allow_inference(&hint));
        // Record usage exceeding budget
        gate.record_usage(1500);
        // Should deny now
        assert!(!gate.allow_inference(&hint));
        // Reset and check again
        gate.reset_usage();
        assert!(gate.allow_inference(&hint));
    }
    #[test]
    fn test_write_policy_early_exit() {
        let gate = DefaultPolicyGate::with_policy(WritePolicy::default());
        let mut witness = crate::witness::WitnessLog::empty();
        witness.gate_decision = crate::types::GateDecision::EarlyExit { layer: 3 };
        // Default policy denies early exit writes
        assert!(!gate.allow_write(&witness));
        // Permissive policy allows
        let permissive = DefaultPolicyGate::with_policy(WritePolicy::permissive());
        assert!(permissive.allow_write(&witness));
    }
    #[test]
    fn test_write_policy_latency() {
        let mut policy = WritePolicy::default();
        policy.max_latency_ns = 1000;
        let gate = DefaultPolicyGate::with_policy(policy);
        let mut witness = crate::witness::WitnessLog::empty();
        // At or under the cap: allowed; over the cap: denied.
        witness.latency_ns = 500;
        assert!(gate.allow_write(&witness));
        witness.latency_ns = 2000;
        assert!(!gate.allow_write(&witness));
    }
}

View File

@@ -0,0 +1,331 @@
//! # FPGA Transformer Backend
//!
//! Ultra low latency transformer inference with FPGA acceleration,
//! coherence gating, and deterministic execution.
//!
//! ## Features
//!
//! - **Deterministic latency paths**: Fixed shape inference with bounded timing
//! - **Quantization first design**: Explicit INT4/INT8 quantization with reproducible math
//! - **Zero allocation hot path**: No heap allocations during inference
//! - **Coherence gating**: Mincut-integrated gate decisions
//! - **Multiple backends**: FPGA PCIe, FPGA Daemon, Native Sim, WASM Sim
//! - **Witness logging**: Auditable inference with ReasoningBank integration
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use ruvector_fpga_transformer::{Engine, artifact::ModelArtifact};
//! use ruvector_fpga_transformer::backend::native_sim::NativeSimBackend;
//! use ruvector_fpga_transformer::gating::DefaultCoherenceGate;
//! use ruvector_fpga_transformer::types::{InferenceRequest, GateHint, FixedShape};
//! use std::sync::Arc;
//!
//! // Create backend and gate
//! let gate = Arc::new(DefaultCoherenceGate::new());
//! let backend = NativeSimBackend::new(gate.clone());
//!
//! // Create engine
//! let mut engine = Engine::new(Box::new(backend), gate);
//!
//! // Load artifact (from file or bytes)
//! // let model_id = engine.load_artifact(&artifact_bytes)?;
//!
//! // Run inference
//! // let result = engine.infer(request)?;
//! ```
//!
//! ## Backend Selection
//!
//! The crate supports multiple backends selected at runtime:
//!
//! - `FpgaPcie`: Direct PCIe access to FPGA (requires `pcie` feature)
//! - `FpgaDaemon`: Communication via local daemon (requires `daemon` feature)
//! - `NativeSim`: Pure Rust simulator (requires `native_sim` feature)
//! - `WasmSim`: WASM-compatible simulator (requires `wasm` feature)
//!
//! ## Artifact Format
//!
//! Models are packaged as signed artifacts containing:
//! - Manifest with shape and quantization metadata
//! - Quantized weights
//! - Optional FPGA bitstream
//! - Test vectors for validation
//! - Ed25519 signature
#![warn(missing_docs)]
#![cfg_attr(feature = "wasm", allow(unused_imports))]
pub mod artifact;
pub mod backend;
pub mod error;
pub mod ffi;
pub mod gating;
pub mod quant;
pub mod types;
pub mod witness;
pub use artifact::ModelArtifact;
pub use backend::TransformerBackend;
pub use error::{Error, Result};
pub use gating::CoherenceGate;
pub use types::{
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
InferenceResult, Layout, ModelId, QuantSpec, QuantizedTensor, SkipReason, WitnessLog,
};
use std::sync::Arc;
/// Crate version
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Main engine for FPGA transformer inference
///
/// The engine combines a backend (FPGA, simulator, etc.) with a coherence gate
/// for controlled inference execution.
///
/// It also keeps a registry of loaded models (artifact plus shape/quant
/// metadata) and running inference statistics.
pub struct Engine {
    /// Backend for inference execution
    backend: Box<dyn TransformerBackend>,
    /// Coherence gate for decision making
    gate: Arc<dyn CoherenceGate>,
    /// Loaded models
    models: std::collections::HashMap<ModelId, ModelInfo>,
    /// Inference statistics
    stats: EngineStats,
}
/// Information about a loaded model
///
/// Recorded by [`Engine::load`]; the shape/quant fields are copied out of
/// the artifact's manifest for quick lookup.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model artifact
    pub artifact: ModelArtifact,
    /// Shape configuration
    pub shape: FixedShape,
    /// Quantization spec
    pub quant: QuantSpec,
}
/// Engine statistics
///
/// Counters updated by [`Engine::infer`]. `successful` includes early-exit
/// runs; `total_latency_ns` accumulates only over runs that reached the
/// backend.
#[derive(Debug, Default, Clone)]
pub struct EngineStats {
    /// Total inferences
    pub total_inferences: u64,
    /// Successful inferences
    pub successful: u64,
    /// Skipped inferences
    pub skipped: u64,
    /// Early exits
    pub early_exits: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
}
impl Engine {
    /// Create a new engine with the specified backend and gate
    pub fn new(backend: Box<dyn TransformerBackend>, gate: Arc<dyn CoherenceGate>) -> Self {
        Self {
            backend,
            gate,
            models: std::collections::HashMap::new(),
            stats: EngineStats::default(),
        }
    }
    /// Create with default native simulator backend
    ///
    /// Uses [`gating::DefaultCoherenceGate`] shared between the engine and
    /// the simulator backend.
    #[cfg(feature = "native_sim")]
    pub fn native_sim() -> Self {
        let gate = Arc::new(gating::DefaultCoherenceGate::new());
        let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
        Self::new(backend, gate)
    }
    /// Load a model artifact from bytes
    ///
    /// Unpacks the serialized artifact and delegates to [`Engine::load`].
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<ModelId> {
        let artifact = artifact::unpack_artifact(artifact_bytes)?;
        self.load(&artifact)
    }
    /// Load a model artifact
    ///
    /// Validates the artifact, loads it into the backend, and records its
    /// shape/quant metadata so [`Engine::shape`] and [`Engine::model_info`]
    /// can answer without touching the backend.
    pub fn load(&mut self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact
        artifact.validate()?;
        // Load into backend
        let model_id = self.backend.load(artifact)?;
        // Store info
        self.models.insert(
            model_id,
            ModelInfo {
                artifact: artifact.clone(),
                shape: artifact.manifest.shape,
                quant: artifact.manifest.quant,
            },
        );
        Ok(model_id)
    }
    /// Run inference
    ///
    /// Every call is counted in `total_inferences` before the preflight
    /// gate is consulted; a gate skip returns [`Error::GateBlocked`]
    /// without reaching the backend. Runs that reach the backend update
    /// latency plus the per-outcome counters.
    pub fn infer(&mut self, req: InferenceRequest) -> Result<InferenceResult> {
        self.stats.total_inferences += 1;
        // Check preflight gate
        let preflight = self.gate.preflight(&req.gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            self.stats.skipped += 1;
            return Err(Error::GateBlocked { reason });
        }
        // Run inference
        let result = self.backend.infer(req)?;
        // Update stats
        self.stats.total_latency_ns += result.witness.latency_ns as u64;
        match result.witness.gate_decision {
            GateDecision::RanFull => self.stats.successful += 1,
            GateDecision::EarlyExit { .. } => {
                self.stats.successful += 1;
                self.stats.early_exits += 1;
            }
            GateDecision::Skipped { .. } => self.stats.skipped += 1,
        }
        Ok(result)
    }
    /// Unload a model
    ///
    /// Removes it from the backend first, then from the local registry.
    pub fn unload(&mut self, model: ModelId) -> Result<()> {
        self.backend.unload(model)?;
        self.models.remove(&model);
        Ok(())
    }
    /// Get model shape
    ///
    /// Errors with [`Error::ModelNotFound`] for unknown IDs.
    pub fn shape(&self, model: ModelId) -> Result<FixedShape> {
        self.models
            .get(&model)
            .map(|info| info.shape)
            .ok_or_else(|| Error::ModelNotFound(model))
    }
    /// Get model info
    pub fn model_info(&self, model: ModelId) -> Option<&ModelInfo> {
        self.models.get(&model)
    }
    /// Check if model is loaded
    ///
    /// Consults only the engine's registry, not the backend.
    pub fn is_loaded(&self, model: ModelId) -> bool {
        self.models.contains_key(&model)
    }
    /// Get list of loaded models
    pub fn loaded_models(&self) -> Vec<ModelId> {
        self.models.keys().copied().collect()
    }
    /// Get engine statistics
    pub fn stats(&self) -> &EngineStats {
        &self.stats
    }
    /// Get backend statistics
    pub fn backend_stats(&self) -> backend::BackendStats {
        self.backend.stats()
    }
    /// Get backend kind
    pub fn backend_kind(&self) -> BackendKind {
        self.backend.kind()
    }
    /// Check if write is allowed based on witness
    pub fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.gate.allow_write(witness)
    }
    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.stats = EngineStats::default();
    }
}
impl EngineStats {
    /// Average per-inference latency in nanoseconds (0.0 when nothing
    /// succeeded yet).
    pub fn avg_latency_ns(&self) -> f64 {
        match self.successful {
            0 => 0.0,
            n => self.total_latency_ns as f64 / n as f64,
        }
    }
    /// Average per-inference latency in milliseconds.
    pub fn avg_latency_ms(&self) -> f64 {
        self.avg_latency_ns() / 1_000_000.0
    }
    /// Fraction of all inference calls that succeeded (1.0 before any call).
    pub fn success_rate(&self) -> f64 {
        match self.total_inferences {
            0 => 1.0,
            total => self.successful as f64 / total as f64,
        }
    }
    /// Fraction of successful inferences that exited early (0.0 when none
    /// succeeded).
    pub fn early_exit_rate(&self) -> f64 {
        match self.successful {
            0 => 0.0,
            n => self.early_exits as f64 / n as f64,
        }
    }
}
/// Prelude for convenient imports
pub mod prelude {
pub use crate::{
artifact::ModelArtifact,
backend::TransformerBackend,
gating::CoherenceGate,
types::{
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
InferenceResult, ModelId, QuantSpec, SkipReason, WitnessLog,
},
Engine, Error, Result,
};
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Engine construction starts with an empty model table.
    ///
    /// The gate/backend pair is only built under the `native_sim` feature,
    /// so nothing is left unused (and warning) when the feature is off.
    #[test]
    fn test_engine_creation() {
        #[cfg(feature = "native_sim")]
        {
            let gate = Arc::new(gating::DefaultCoherenceGate::new());
            let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
            let engine = Engine::new(backend, gate);
            assert!(engine.loaded_models().is_empty());
        }
    }
    /// Derived statistics: 80/100 success, 20/80 early exits,
    /// 8_000_000 ns over 80 successes = 100_000 ns average.
    #[test]
    fn test_engine_stats() {
        let stats = EngineStats {
            total_inferences: 100,
            successful: 80,
            skipped: 10,
            early_exits: 20,
            total_latency_ns: 8_000_000,
        };
        assert!((stats.success_rate() - 0.8).abs() < 0.01);
        assert!((stats.early_exit_rate() - 0.25).abs() < 0.01);
        assert!((stats.avg_latency_ns() - 100_000.0).abs() < 1.0);
    }
}

View File

@@ -0,0 +1,303 @@
//! Calibration data for quantization
use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Calibration data for a model
///
/// Aggregates per-layer activation/weight statistics gathered during a
/// calibration run; serialized as JSON via `to_bytes`/`from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationData {
    /// Layer-wise activation statistics
    pub layers: Vec<LayerCalibration>,
    /// Global input statistics
    pub input_stats: ActivationStats,
    /// Number of calibration samples used
    pub num_samples: usize,
    /// Calibration method used
    pub method: CalibrationMethod,
}
/// Per-layer calibration data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerCalibration {
    /// Layer index
    pub layer_idx: usize,
    /// Layer name
    pub name: String,
    /// Activation statistics after this layer
    pub activation_stats: ActivationStats,
    /// Weight statistics for this layer
    pub weight_stats: WeightStats,
    /// Optimal scale for activations (Q16.16 fixed point)
    pub act_scale: i32,
    /// Optimal scale for weights (Q16.16 fixed point; 65536 == 1.0)
    pub weight_scale: i32,
}
/// Activation statistics
///
/// Default state is all-zero; `update` treats (min, max) == (0.0, 0.0)
/// as "no data seen yet".
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ActivationStats {
    /// Minimum value seen
    pub min: f32,
    /// Maximum value seen
    pub max: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Histogram bins (for entropy calibration)
    #[serde(default)]
    pub histogram: Vec<u32>,
    /// Histogram bin edges (expected length: histogram.len() + 1)
    #[serde(default)]
    pub bin_edges: Vec<f32>,
}
impl ActivationStats {
    /// Create empty stats
    pub fn new() -> Self {
        Self::default()
    }
    /// Update stats with a batch of values
    ///
    /// Min/max track the true extrema across batches. Because the default
    /// state is all-zero, `(min, max) == (0.0, 0.0)` is treated as
    /// "uninitialized" and replaced by the batch extrema; afterwards the
    /// running extrema are widened normally. (The previous sentinel check
    /// `v < self.min || self.min == 0.0` re-armed whenever min was
    /// legitimately 0.0 — leaving min at the *last* value seen — and never
    /// raised `max` above 0.0 for all-negative data.)
    ///
    /// Mean/std are overwritten with the current batch statistics
    /// (not an online/running estimate), preserving prior behavior.
    pub fn update(&mut self, values: &[f32]) {
        if values.is_empty() {
            return;
        }
        // Batch extrema
        let mut batch_min = f32::INFINITY;
        let mut batch_max = f32::NEG_INFINITY;
        for &v in values {
            batch_min = batch_min.min(v);
            batch_max = batch_max.max(v);
        }
        if self.min == 0.0 && self.max == 0.0 {
            // First data seen (default state): adopt batch extrema directly.
            self.min = batch_min;
            self.max = batch_max;
        } else {
            self.min = self.min.min(batch_min);
            self.max = self.max.max(batch_max);
        }
        // Update running mean and std
        let n = values.len() as f32;
        let batch_mean = values.iter().sum::<f32>() / n;
        let batch_var = values.iter().map(|v| (v - batch_mean).powi(2)).sum::<f32>() / n;
        // Simple update (not online algorithm)
        self.mean = batch_mean;
        self.std = batch_var.sqrt();
    }
    /// Compute optimal scale for symmetric quantization to n bits
    ///
    /// Returns `max(|min|, |max|) / (2^(bits-1) - 1)`.
    pub fn optimal_scale(&self, bits: u8) -> f32 {
        let max_range = self.max.abs().max(self.min.abs());
        let qmax = (1 << (bits - 1)) as f32 - 1.0;
        max_range / qmax
    }
}
/// Weight statistics
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WeightStats {
    /// Min weight value
    pub min: f32,
    /// Max weight value
    pub max: f32,
    /// Sparsity (fraction of zeros)
    pub sparsity: f32,
}
impl WeightStats {
    /// Compute from weight tensor
    ///
    /// Weights with |w| < 1e-6 are counted as zeros for the sparsity ratio.
    /// An empty slice yields the default (all-zero) stats instead of
    /// propagating `inf`/`-inf` extrema and a NaN sparsity (0/0).
    pub fn from_weights(weights: &[f32]) -> Self {
        if weights.is_empty() {
            return Self::default();
        }
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut zeros = 0usize;
        for &w in weights {
            if w < min {
                min = w;
            }
            if w > max {
                max = w;
            }
            if w.abs() < 1e-6 {
                zeros += 1;
            }
        }
        Self {
            min,
            max,
            sparsity: zeros as f32 / weights.len() as f32,
        }
    }
}
/// Calibration method
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationMethod {
    /// Use min/max of observed values
    MinMax,
    /// Use percentile clipping (e.g., 99.9%)
    Percentile(u32), // encoded in tenths of a percent: 999 = 99.9%
    /// Entropy-based calibration (KL divergence)
    Entropy,
    /// Mean-squared error minimization
    Mse,
}
// MinMax is the conservative default: no clipping of outliers.
impl Default for CalibrationMethod {
    fn default() -> Self {
        Self::MinMax
    }
}
impl CalibrationData {
/// Create empty calibration data
pub fn new(method: CalibrationMethod) -> Self {
Self {
layers: Vec::new(),
input_stats: ActivationStats::new(),
num_samples: 0,
method,
}
}
/// Add layer calibration
pub fn add_layer(&mut self, calib: LayerCalibration) {
self.layers.push(calib);
}
/// Serialize to bytes
pub fn to_bytes(&self) -> Result<Vec<u8>> {
Ok(serde_json::to_vec(self)?)
}
/// Deserialize from bytes
pub fn from_bytes(data: &[u8]) -> Result<Self> {
Ok(serde_json::from_slice(data)?)
}
}
/// Calibrate a model by collecting activation statistics
///
/// `run_inference` must return one activation vector per layer for the
/// given token ids; layers beyond `num_layers` are ignored.
pub fn calibrate_model<F>(
    run_inference: F,
    calibration_inputs: &[Vec<u16>],
    num_layers: usize,
    method: CalibrationMethod,
) -> Result<CalibrationData>
where
    F: Fn(&[u16]) -> Result<Vec<Vec<f32>>>, // Returns activations per layer
{
    let mut calibration = CalibrationData::new(method);
    // One running accumulator per layer.
    let mut layer_stats = vec![ActivationStats::new(); num_layers];
    // Feed every calibration sample through the model and fold the
    // observed activations into the per-layer accumulators.
    for input in calibration_inputs {
        let activations = run_inference(input)?;
        for (stats, layer_act) in layer_stats.iter_mut().zip(activations.iter()) {
            stats.update(layer_act);
        }
        calibration.num_samples += 1;
    }
    // Turn each accumulator into a layer calibration record.
    for (layer_idx, stats) in layer_stats.into_iter().enumerate() {
        let base = stats.optimal_scale(8);
        // Percentile calibration applies a mild clip; the other methods
        // currently share the plain min/max scale.
        let act_scale = match method {
            CalibrationMethod::Percentile(_) => base * 0.99,
            CalibrationMethod::MinMax | CalibrationMethod::Entropy | CalibrationMethod::Mse => base,
        };
        calibration.add_layer(LayerCalibration {
            layer_idx,
            name: format!("layer_{}", layer_idx),
            activation_stats: stats,
            weight_stats: WeightStats::default(),
            act_scale: (act_scale * 65536.0) as i32,
            weight_scale: 65536, // Default 1.0 in Q16.16
        });
    }
    Ok(calibration)
}
/// Apply percentile clipping to calibration
///
/// Returns the clipped (low, high) value range covering `percentile`
/// (e.g. 0.999) of the histogram mass, trimming both tails symmetrically.
/// Falls back to the raw min/max when no histogram was recorded or the
/// bin edges do not cover every bin.
pub fn apply_percentile(stats: &ActivationStats, percentile: f32) -> (f32, f32) {
    // Need one more edge than bins so `bin_edges[high_idx + 1]` is always
    // in bounds (the previous `len() < 2` check allowed an out-of-bounds
    // read when edges were not histogram.len() + 1 long).
    if stats.histogram.is_empty() || stats.bin_edges.len() < stats.histogram.len() + 1 {
        return (stats.min, stats.max);
    }
    let total: u32 = stats.histogram.iter().sum();
    let target_low = (total as f32 * (1.0 - percentile) / 2.0) as u32;
    let target_high = (total as f32 * (1.0 + percentile) / 2.0) as u32;
    let mut cumsum = 0u32;
    let mut low_idx = 0;
    // Explicit "found" flag: the old code used `low_idx == 0` as the
    // unset sentinel, which kept advancing low_idx on every bin whenever
    // bin 0 itself already satisfied the low target.
    let mut low_found = false;
    let mut high_idx = stats.histogram.len() - 1;
    for (i, &count) in stats.histogram.iter().enumerate() {
        cumsum += count;
        if !low_found && cumsum >= target_low {
            low_idx = i;
            low_found = true;
        }
        if cumsum >= target_high {
            high_idx = i;
            break;
        }
    }
    (stats.bin_edges[low_idx], stats.bin_edges[high_idx + 1])
}
#[cfg(test)]
mod tests {
    use super::*;
    // min/max/mean must reflect a single batch exactly.
    #[test]
    fn test_activation_stats_update() {
        let mut stats = ActivationStats::new();
        stats.update(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.mean - 3.0).abs() < 0.01);
    }
    #[test]
    fn test_optimal_scale() {
        let mut stats = ActivationStats::new();
        stats.min = -1.0;
        stats.max = 1.0;
        let scale = stats.optimal_scale(8);
        // For 8-bit, qmax = 127, so scale should be 1.0/127 ≈ 0.00787
        assert!((scale - 1.0 / 127.0).abs() < 0.001);
    }
    // Two of six weights fall below the 1e-6 zero threshold → sparsity 1/3.
    #[test]
    fn test_weight_stats() {
        let weights = vec![0.0, 0.1, -0.1, 0.5, -0.5, 0.0];
        let stats = WeightStats::from_weights(&weights);
        assert_eq!(stats.min, -0.5);
        assert_eq!(stats.max, 0.5);
        assert!((stats.sparsity - 2.0 / 6.0).abs() < 0.01);
    }
    // JSON round-trip must preserve the calibration method.
    #[test]
    fn test_calibration_serialization() {
        let calib = CalibrationData::new(CalibrationMethod::MinMax);
        let bytes = calib.to_bytes().unwrap();
        let restored = CalibrationData::from_bytes(&bytes).unwrap();
        assert_eq!(calib.method, restored.method);
    }
}

View File

@@ -0,0 +1,336 @@
//! Lookup table implementations for fixed-point operations
//!
//! Provides LUT-based exp, log, and softmax for deterministic computation.
/// LUT-based exponential function
/// Input: Q8.8 fixed point (only the integer part selects a table entry)
/// Output: Q0.16 fixed point [0, 1)
const EXP_LUT_SIZE: usize = 256;
const EXP_LUT_SHIFT: i32 = 8; // Q8.8 input
/// Precomputed exp LUT; entry i corresponds to x = (i - 128) / 32, i.e.
/// the table spans [-4, 4) — see `generate_exp_lut` (the values are
/// normalized, so treat them as relative rather than absolute exp()).
static EXP_LUT: [u16; EXP_LUT_SIZE] = generate_exp_lut();
/// Generate exp LUT at compile time
///
/// Index i maps to x = (i - 128) / 32, i.e. x in [-4, 4); each entry is
/// exp(x) normalized by (1 + exp(4)) and stored as Q0.16.
///
/// NOTE(review): the 5-term Taylor `const_exp` loses accuracy toward the
/// interval edges and the (1 + exp(4)) normalizer looks ad hoc — confirm
/// absolute values against the reference model; monotonicity is what the
/// in-file tests actually rely on.
const fn generate_exp_lut() -> [u16; EXP_LUT_SIZE] {
    let mut lut = [0u16; EXP_LUT_SIZE];
    let mut i = 0;
    // while-loop because `for` is not allowed in const fns.
    while i < EXP_LUT_SIZE {
        // Signed offset of this entry from the table midpoint.
        let x_q = (i as i32) - 128;
        // Scale to get reasonable exp range
        let x_f = (x_q as f64) / 32.0; // x in [-4, 4)
        // Compute exp and scale to Q0.16
        let exp_val = const_exp(x_f);
        let scaled = exp_val / (1.0 + const_exp(4.0)); // Normalize
        // Convert to u16, saturating at both ends.
        lut[i] = if scaled > 1.0 {
            65535
        } else if scaled < 0.0 {
            0
        } else {
            (scaled * 65535.0) as u16
        };
        i += 1;
    }
    lut
}
/// Const-compatible exp approximation via a 5th-order Taylor polynomial
/// around 0: exp(x) ≈ 1 + x + x²/2 + x³/6 + x⁴/24 + x⁵/120.
///
/// Accuracy degrades quickly once |x| grows past ~2; callers only feed it
/// LUT-generation inputs.
const fn const_exp(x: f64) -> f64 {
    let p2 = x * x;
    let p3 = p2 * x;
    let p4 = p3 * x;
    let p5 = p4 * x;
    1.0 + x + p2 / 2.0 + p3 / 6.0 + p4 / 24.0 + p5 / 120.0
}
/// LUT-based exponential
/// Input: i16 in Q8.8 format
/// Output: u16 in Q0.16 format
///
/// Only the integer part of the Q8.8 input selects a table entry, so the
/// effective resolution is one LUT bucket per unit step.
#[inline]
pub fn exp_lut(x: i16) -> u16 {
    // Restrict to the table's representable Q8.8 span, then map the
    // integer part onto bucket indices 0..=255.
    let bounded = x.clamp(-128 * 256, 127 * 256);
    let bucket = ((bounded >> EXP_LUT_SHIFT) + 128) as usize;
    EXP_LUT[bucket.min(EXP_LUT_SIZE - 1)]
}
/// Log LUT for Q0.16 input
///
/// Entry i holds ln(i / 256) in Q8.8; entry 0 stands in for ln(0) = -inf.
static LOG_LUT: [i16; 256] = generate_log_lut();
const fn generate_log_lut() -> [i16; 256] {
    let mut lut = [0i16; 256];
    let mut i = 1; // start at 1: entry 0 is patched below
    // while-loop because `for` is not allowed in const fns.
    while i < 256 {
        // Input is scaled by 256, so x = i/256 in [0.004, 1)
        let x = (i as f64) / 256.0;
        // log(x) in Q8.8 format
        let log_val = const_ln(x);
        lut[i] = (log_val * 256.0) as i16;
        i += 1;
    }
    lut[0] = i16::MIN; // log(0) = -inf, use min value
    lut
}
/// Const-compatible natural log approximation
///
/// Uses the atanh series around x = 1: ln(x) = 2·Σ t^(2n+1)/(2n+1) with
/// t = (x-1)/(x+1), truncated after the t⁹ term. Non-positive inputs
/// map to -inf.
const fn const_ln(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    let t = (x - 1.0) / (x + 1.0);
    let t2 = t * t;
    let t3 = t2 * t;
    let t5 = t3 * t2;
    let t7 = t5 * t2;
    let t9 = t7 * t2;
    2.0 * (t + t3 / 3.0 + t5 / 5.0 + t7 / 7.0 + t9 / 9.0)
}
/// LUT-based natural log
/// Input: u16 in Q0.16 format (0 to 65535 = 0.0 to ~1.0)
/// Output: i16 in Q8.8 format
#[inline]
pub fn log_lut(x: u16) -> i16 {
    // log(0) is undefined; the table encodes it as the most negative value.
    match x {
        0 => i16::MIN,
        nz => LOG_LUT[((nz >> 8) as usize).min(255)],
    }
}
/// Softmax using LUT-based exp
/// Operates in-place on Q8.8 logits, outputs Q0.16 probabilities
///
/// NOTE(review): the final store is `prob as i16`, i.e. the Q0.16
/// magnitude's raw bit pattern — probabilities >= 0.5 (>= 0x8000) read
/// as negative i16. Confirm that callers reinterpret the slot as u16.
pub fn softmax_lut_q(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }
    // Find max for numerical stability
    let max = *logits.iter().max().unwrap_or(&0);
    // Compute exp(x - max) using LUT
    let mut sum: u32 = 0;
    let mut exp_values: Vec<u16> = Vec::with_capacity(logits.len());
    for &logit in logits.iter() {
        // saturating_sub keeps the shifted logit (always <= 0) inside i16.
        let shifted = logit.saturating_sub(max);
        let exp_val = exp_lut(shifted);
        exp_values.push(exp_val);
        sum += exp_val as u32;
    }
    // Normalize; guard against an all-zero exponent sum.
    if sum == 0 {
        sum = 1;
    }
    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = ((exp_values[i] as u64 * 65535) / sum as u64) as u16;
        *logit = prob as i16;
    }
}
/// Softmax on f32 values using LUT (for compatibility)
///
/// Standard max-subtracted softmax computed in floating point; overwrites
/// `logits` with probabilities summing to 1 (left as raw exponentials if
/// the sum is not positive).
pub fn softmax_lut(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }
    // Subtract the max before exponentiating for numerical stability.
    let peak = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    for value in logits.iter_mut() {
        *value = (*value - peak).exp();
    }
    let total: f32 = logits.iter().sum();
    if total > 0.0 {
        for value in logits.iter_mut() {
            *value /= total;
        }
    }
}
/// Piecewise linear softmax approximation
/// More accurate than LUT but still deterministic
///
/// Overwrites the Q8.8 `logits` in place. Each output slot carries the
/// Q0.16 probability bit pattern in an i16 (values >= 0x8000 read as
/// negative i16), matching `softmax_lut_q`.
pub fn softmax_pwl(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }
    let max = *logits.iter().max().unwrap_or(&0);
    // Piecewise linear exp approximation
    // exp(x) ≈ 1 + x for x near 0
    // exp(x) ≈ 2^(x/ln2) for larger x
    let mut sum: i64 = 0;
    let mut exp_values: Vec<i32> = Vec::with_capacity(logits.len());
    for &logit in logits.iter() {
        // Widen before subtracting: `logit - max` in i16 overflows
        // (panics in debug builds) when the logit spread exceeds i16::MAX.
        let x = logit as i32 - max as i32; // x <= 0
        // Piecewise approximation (in Q8.8)
        let exp_val = if x >= -256 {
            // x in [-1, 0] -> linear: 1 + x
            (256 + x).max(0)
        } else if x >= -2048 {
            // x in [-8, -1] -> quadratic decay toward zero
            let shifted = (x + 2048) >> 3; // Scale to 0-256 range
            (shifted * shifted / 256).max(1)
        } else {
            // x < -8 -> essentially zero
            1
        };
        exp_values.push(exp_val);
        sum += exp_val as i64;
    }
    // Normalize to Q0.16; guard against a zero exponent sum.
    if sum == 0 {
        sum = 1;
    }
    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = (exp_values[i] as i64 * 65535 / sum) as i16;
        *logit = prob;
    }
}
/// GELU approximation using LUT
/// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
///
/// Implemented via the sigmoid form GELU(x) ≈ x * sigmoid(1.702 * x),
/// with Q8.8 input and output.
pub fn gelu_lut(x: i16) -> i16 {
    // 1.702 ≈ 435/256 in Q8.8; widen and saturate so large |x| cannot
    // wrap the i16 cast and flip the sigmoid argument's sign.
    let scaled = ((x as i32 * 435) >> 8).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    let sigmoid_val = sigmoid_lut(scaled);
    // Q8.8 * Q0.16 = Q8.24; shift by 16 to return to Q8.8. (The previous
    // >> 15 doubled the output relative to silu_lut's identical product,
    // making GELU(x) tend to 2x instead of x for large x.)
    ((x as i32 * sigmoid_val as i32) >> 16) as i16
}
/// Sigmoid LUT: entry i holds sigmoid((i - 128) / 16) in Q0.16,
/// covering x in [-8, 8).
///
/// NOTE(review): `const_exp` is a 5-term Taylor series that diverges for
/// |x| beyond a few units, so entries near the table edges are saturated
/// approximations rather than true sigmoid values (the float→int `as`
/// cast saturates). Confirm edge accuracy is acceptable before widening
/// use of this table.
static SIGMOID_LUT: [u16; 256] = generate_sigmoid_lut();
const fn generate_sigmoid_lut() -> [u16; 256] {
    let mut lut = [0u16; 256];
    let mut i = 0;
    // while-loop because `for` is not allowed in const fns.
    while i < 256 {
        // Map index to x in [-8, 8)
        let x = ((i as i32) - 128) as f64 / 16.0;
        // sigmoid(x) = 1 / (1 + exp(-x))
        let sig = 1.0 / (1.0 + const_exp(-x));
        lut[i] = (sig * 65535.0) as u16;
        i += 1;
    }
    lut
}
/// LUT-based sigmoid
/// Input: i16 in Q8.8 format
/// Output: u16 in Q0.16 format
#[inline]
pub fn sigmoid_lut(x: i16) -> u16 {
    // Map the input onto table indices 0..=255. Clamp in i32 *before*
    // casting: the old `((x >> 5) + 128) as usize` wrapped to a huge
    // value for x < -4096, so strongly negative inputs landed on the top
    // of the table (via .min(255)) and returned ~1.0 instead of ~0.0.
    let idx = (((x >> 5) as i32) + 128).clamp(0, 255) as usize;
    SIGMOID_LUT[idx]
}
/// SiLU (Swish) using sigmoid LUT
/// SiLU(x) = x * sigmoid(x)
///
/// Q8.8 in/out: the Q8.8 × Q0.16 product is Q8.24, so >> 16 restores Q8.8.
#[inline]
pub fn silu_lut(x: i16) -> i16 {
    let sigmoid_val = sigmoid_lut(x);
    ((x as i32 * sigmoid_val as i32) >> 16) as i16
}
#[cfg(test)]
mod tests {
    use super::*;
    // These tests check sign/monotonicity invariants rather than exact
    // values, since the LUTs are coarse Taylor-based approximations.
    #[test]
    fn test_exp_lut() {
        // exp(0) should return a non-zero value
        let result = exp_lut(0);
        assert!(result > 0, "exp(0) should be positive");
        // exp is monotonically increasing
        let result_neg = exp_lut(-256); // -1.0 in Q8.8
        let result_zero = exp_lut(0);
        let result_pos = exp_lut(256); // 1.0 in Q8.8
        assert!(
            result_neg <= result_zero,
            "exp should be monotonically increasing"
        );
        assert!(
            result_zero <= result_pos,
            "exp should be monotonically increasing"
        );
    }
    #[test]
    fn test_sigmoid_lut() {
        // sigmoid(0) = 0.5
        let result = sigmoid_lut(0);
        let expected = 32768u16; // 0.5 in Q0.16
        assert!(
            (result as i32 - expected as i32).abs() < 5000,
            "sigmoid(0) ≈ 0.5"
        );
        // sigmoid is monotonically increasing
        let result_neg = sigmoid_lut(-1024);
        let result_zero = sigmoid_lut(0);
        let result_pos = sigmoid_lut(1024);
        assert!(
            result_neg < result_zero,
            "sigmoid should be monotonically increasing"
        );
        assert!(
            result_zero < result_pos,
            "sigmoid should be monotonically increasing"
        );
    }
    // Float softmax must produce a normalized, order-preserving result.
    #[test]
    fn test_softmax_lut() {
        let mut logits = vec![1.0f32, 2.0, 3.0, 4.0];
        softmax_lut(&mut logits);
        // Sum should be 1.0
        let sum: f32 = logits.iter().sum();
        assert!((sum - 1.0).abs() < 0.01);
        // Should be increasing
        for i in 1..logits.len() {
            assert!(logits[i] > logits[i - 1]);
        }
    }
    #[test]
    fn test_gelu_lut() {
        // GELU(0) should be approximately 0
        assert!(gelu_lut(0).abs() < 100);
        // GELU maintains sign
        let neg_result = gelu_lut(-256);
        assert!(neg_result <= 0, "GELU of negative should be non-positive");
        // GELU of positive values should be positive
        let pos_result = gelu_lut(256);
        assert!(pos_result > 0, "GELU of positive should be positive");
    }
}

View File

@@ -0,0 +1,238 @@
//! Quantization subsystem
//!
//! Explicit, reproducible quantization for weights and activations.
pub mod calib;
pub mod lut;
pub mod qformat;
pub use calib::{calibrate_model, CalibrationData};
pub use lut::{exp_lut, log_lut, softmax_lut};
pub use qformat::{dequantize_i16, dequantize_i8, quantize_i16, quantize_i8};
use crate::types::QuantSpec;
/// Fixed-point Q15 format (1.15)
/// Range: [-1.0, 1.0 - 2^-15]
/// Resolution: 2^-15 ≈ 3.05e-5
pub type Q15 = i16;
/// Fixed-point Q16.16 format
/// Range: [-32768.0, 32767.999...]
/// Resolution: 2^-16 ≈ 1.53e-5
pub type Q16_16 = i32;
/// Convert f32 to Q15 (input clamped to the representable [-1, 1) range)
#[inline]
pub fn f32_to_q15(x: f32) -> Q15 {
    (x.clamp(-1.0, 1.0 - f32::EPSILON) * 32768.0) as Q15
}
/// Convert Q15 to f32
#[inline]
pub fn q15_to_f32(x: Q15) -> f32 {
    x as f32 / 32768.0
}
/// Convert f32 to Q16.16 (out-of-range values saturate via the `as` cast)
#[inline]
pub fn f32_to_q16_16(x: f32) -> Q16_16 {
    (x * 65536.0) as Q16_16
}
/// Convert Q16.16 to f32
#[inline]
pub fn q16_16_to_f32(x: Q16_16) -> f32 {
    x as f32 / 65536.0
}
/// Fixed-point multiplication Q15 * Q15 -> Q15
#[inline]
pub fn q15_mul(a: Q15, b: Q15) -> Q15 {
    // Multiply with round-to-nearest (+0.5 ulp before the shift)
    let product = (a as i32 * b as i32 + 0x4000) >> 15;
    product.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}
/// Fixed-point dot product with accumulator
/// Note: For very large vectors (>65536 elements), use q15_dot_saturating
#[inline]
pub fn q15_dot(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum()
}
/// Saturating fixed-point dot product (prevents overflow for large vectors)
#[inline]
pub fn q15_dot_saturating(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter().zip(b.iter()).fold(0i32, |acc, (&x, &y)| {
        acc.saturating_add((x as i32).saturating_mul(y as i32))
    })
}
/// Fixed-point dot product normalized to Q15
///
/// Applies round-to-nearest before the right shift. `shift == 0` is a
/// plain pass-through: the previous unconditional `1 << (shift - 1)`
/// rounding term underflowed the u8 shift amount and panicked in debug
/// builds for shift = 0.
#[inline]
pub fn q15_dot_normalized(a: &[Q15], b: &[Q15], shift: u8) -> Q15 {
    let sum = q15_dot(a, b);
    let shifted = if shift == 0 {
        sum
    } else {
        (sum + (1 << (shift - 1))) >> shift
    };
    shifted.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}
/// Quantization context for a layer
///
/// Bundles the weight spec with the per-layer activation scales and the
/// accumulator/normalization parameters used by the GEMM path.
#[derive(Debug, Clone)]
pub struct QuantContext {
    /// Weight quantization spec
    pub weight_spec: QuantSpec,
    /// Input scale (Q16.16 fixed point)
    pub input_scale: Q16_16,
    /// Output scale (Q16.16 fixed point)
    pub output_scale: Q16_16,
    /// Accumulator bit width
    pub acc_bits: u8,
    /// Right shift for normalization
    pub norm_shift: u8,
}
impl QuantContext {
/// Create from QuantSpec
pub fn from_spec(spec: &QuantSpec) -> Self {
Self {
weight_spec: *spec,
input_scale: spec.scale_q,
output_scale: spec.scale_q,
acc_bits: 32,
norm_shift: 15,
}
}
/// Compute the required accumulator bits to avoid overflow
pub fn required_acc_bits(input_bits: u8, weight_bits: u8, vector_len: usize) -> u8 {
// Each multiply produces input_bits + weight_bits
// Sum of vector_len terms adds log2(vector_len) bits
let product_bits = input_bits + weight_bits;
let sum_bits = (vector_len as f64).log2().ceil() as u8;
product_bits + sum_bits + 1 // +1 for sign
}
}
/// Packing utilities for sub-byte quantization
pub mod packing {
    /// Pack two 4-bit values into one byte (a = low nibble, b = high nibble)
    #[inline]
    pub fn pack_int4(a: i8, b: i8) -> u8 {
        ((a & 0x0F) as u8) | (((b & 0x0F) as u8) << 4)
    }
    /// Unpack byte into two signed 4-bit values
    #[inline]
    pub fn unpack_int4(packed: u8) -> (i8, i8) {
        // Shift each nibble into the top of an i8, then arithmetic-shift
        // back down to sign-extend it.
        let low = ((packed << 4) as i8) >> 4;
        let high = (packed as i8) >> 4;
        (low, high)
    }
    /// Pack four 2-bit values into one byte (a in bits 0-1 ... d in bits 6-7)
    #[inline]
    pub fn pack_int2(a: i8, b: i8, c: i8, d: i8) -> u8 {
        let mut packed = 0u8;
        let mut pos = 0;
        for lane in [a, b, c, d] {
            packed |= ((lane & 0x03) as u8) << pos;
            pos += 2;
        }
        packed
    }
    /// Unpack byte into four signed 2-bit values
    #[inline]
    pub fn unpack_int2(packed: u8) -> (i8, i8, i8, i8) {
        // Move the target 2-bit lane to the top of an i8 and arithmetic-
        // shift back to sign-extend.
        let lane = |pos: u8| (((packed >> pos) << 6) as i8) >> 6;
        (lane(0), lane(2), lane(4), lane(6))
    }
    /// Pack eight booleans into one byte, with bits[i] stored in bit i
    #[inline]
    pub fn pack_binary(bits: &[bool; 8]) -> u8 {
        let mut byte = 0u8;
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                byte |= 1 << i;
            }
        }
        byte
    }
    /// Unpack a byte into eight booleans, index i taken from bit i
    #[inline]
    pub fn unpack_binary(packed: u8) -> [bool; 8] {
        let mut bits = [false; 8];
        for (i, slot) in bits.iter_mut().enumerate() {
            *slot = packed & (1 << i) != 0;
        }
        bits
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip through Q15 should be exact for representable values and
    // within one quantization step otherwise.
    #[test]
    fn test_q15_conversion() {
        assert_eq!(f32_to_q15(0.0), 0);
        assert_eq!(f32_to_q15(0.5), 16384);
        assert_eq!(f32_to_q15(-0.5), -16384);
        let x = 0.123f32;
        let q = f32_to_q15(x);
        let back = q15_to_f32(q);
        assert!((x - back).abs() < 0.001);
    }
    // 0.5 * 0.5 must come back as ~0.25 after the rounding shift.
    #[test]
    fn test_q15_mul() {
        let a = f32_to_q15(0.5);
        let b = f32_to_q15(0.5);
        let c = q15_mul(a, b);
        let result = q15_to_f32(c);
        assert!((result - 0.25).abs() < 0.01);
    }
    #[test]
    fn test_packing_int4() {
        let (a, b) = (5i8, -3i8);
        let packed = packing::pack_int4(a, b);
        let (ua, ub) = packing::unpack_int4(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
    }
    #[test]
    fn test_packing_int2() {
        let (a, b, c, d) = (1i8, -1i8, 0i8, -2i8);
        let packed = packing::pack_int2(a, b, c, d);
        let (ua, ub, uc, ud) = packing::unpack_int2(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
        assert_eq!(c, uc);
        // -2 in 2-bit is 10 binary, which unpacks to -2 (sign extended)
        assert_eq!(-2i8, ud);
    }
    #[test]
    fn test_packing_binary() {
        let bits = [true, false, true, true, false, false, true, false];
        let packed = packing::pack_binary(&bits);
        let unpacked = packing::unpack_binary(packed);
        assert_eq!(bits, unpacked);
    }
}

View File

@@ -0,0 +1,237 @@
//! Quantization format operations
use crate::types::QuantSpec;
/// Quantize f32 values to i8 using the spec's Q16.16 scale and zero point
pub fn quantize_i8(values: &[f32], spec: &QuantSpec) -> Vec<i8> {
    // Decode the Q16.16 parameters into plain floats once.
    let scale = spec.scale_q as f32 / 65536.0;
    let zero = spec.zero_q as f32 / 65536.0;
    let quantize_one = |v: f32| ((v - zero) / scale).round().clamp(-128.0, 127.0) as i8;
    values.iter().map(|&v| quantize_one(v)).collect()
}
/// Quantize f32 values to i16 over the slice's own min/max range
///
/// A constant input maps to all zeros; an empty slice yields an empty Vec.
pub fn quantize_i16(values: &[f32]) -> Vec<i16> {
    let lo = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let hi = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Degenerate range: nothing to spread over the i16 domain.
    if (hi - lo).abs() < f32::EPSILON {
        return vec![0i16; values.len()];
    }
    let scale = (hi - lo) / 65535.0;
    values
        .iter()
        .map(|&v| ((v - lo) / scale - 32768.0).round().clamp(-32768.0, 32767.0) as i16)
        .collect()
}
/// Dequantize i8 values (stored as raw bytes) to f32
///
/// Each byte is reinterpreted as i8, then mapped through the spec's
/// Q16.16 scale and zero point.
pub fn dequantize_i8(values: &[u8], spec: &QuantSpec) -> Vec<f32> {
    let scale = spec.scale_q as f32 / 65536.0;
    let zero = spec.zero_q as f32 / 65536.0;
    values
        .iter()
        .map(|&raw| (raw as i8) as f32 * scale + zero)
        .collect()
}
/// Dequantize i16 values to f32 with an explicit scale and zero offset
pub fn dequantize_i16(values: &[i16], scale: f32, zero: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(values.len());
    for &v in values {
        out.push(v as f32 * scale + zero);
    }
    out
}
/// Symmetric quantization (zero point = 0)
///
/// Returns the quantized values and a scale such that
/// `value ≈ q as f32 * scale`. An all-zero (or empty) input yields scale 1.0.
pub fn quantize_symmetric_i8(values: &[f32]) -> (Vec<i8>, f32) {
    let abs_max = values.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    if abs_max < f32::EPSILON {
        return (vec![0i8; values.len()], 1.0);
    }
    let scale = abs_max / 127.0;
    let quantize_one = |v: f32| (v / scale).round().clamp(-127.0, 127.0) as i8;
    (values.iter().map(|&v| quantize_one(v)).collect(), scale)
}
/// Asymmetric quantization (uses the full unsigned 0..=255 range)
///
/// Returns (quantized, scale, zero_point) with
/// `value ≈ (q as i32 - zero_point) as f32 * scale`. A constant input
/// yields (zeros, 1.0, 0).
pub fn quantize_asymmetric_i8(values: &[f32]) -> (Vec<u8>, f32, i32) {
    let lo = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let hi = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    if (hi - lo).abs() < f32::EPSILON {
        return (vec![0u8; values.len()], 1.0, 0);
    }
    let scale = (hi - lo) / 255.0;
    let zero_point = (-lo / scale).round() as i32;
    let encoded = values
        .iter()
        .map(|&v| ((v / scale).round() as i32 + zero_point).clamp(0, 255) as u8)
        .collect();
    (encoded, scale, zero_point)
}
/// Per-channel quantization for weights
///
/// Treats `weights` as `out_channels` contiguous rows of
/// `weights.len() / out_channels` elements each, quantizing every row
/// symmetrically with its own scale.
pub fn quantize_per_channel_i8(weights: &[f32], out_channels: usize) -> (Vec<i8>, Vec<f32>) {
    let in_features = weights.len() / out_channels;
    let mut quantized = Vec::with_capacity(weights.len());
    let mut scales = Vec::with_capacity(out_channels);
    for channel in 0..out_channels {
        let row = &weights[channel * in_features..(channel + 1) * in_features];
        let (q, scale) = quantize_symmetric_i8(row);
        quantized.extend(q);
        scales.push(scale);
    }
    (quantized, scales)
}
/// Blocked quantization for hardware efficiency
///
/// Splits `values` into `block_size`-element groups (the last may be
/// short) and quantizes each symmetrically. Returns (quantized values,
/// per-block scales, per-block zero points — always 0 for this scheme).
pub fn quantize_blocked_i8(values: &[f32], block_size: usize) -> (Vec<i8>, Vec<f32>, Vec<i8>) {
    // Ceiling division for the block count (used only for preallocation).
    let num_blocks = (values.len() + block_size - 1) / block_size;
    let mut quantized = Vec::with_capacity(values.len());
    let mut scales = Vec::with_capacity(num_blocks);
    for block in values.chunks(block_size) {
        let (q, scale) = quantize_symmetric_i8(block);
        quantized.extend(q);
        scales.push(scale);
    }
    let zeros = vec![0i8; scales.len()];
    (quantized, scales, zeros)
}
/// Matrix quantization for GEMM
///
/// Row-major i8 storage with one (scale, zero) pair per row.
#[derive(Debug, Clone)]
pub struct QuantizedMatrix {
    /// Quantized values
    pub data: Vec<i8>,
    /// Rows
    pub rows: usize,
    /// Columns
    pub cols: usize,
    /// Per-row scales (for per-channel quantization)
    pub scales: Vec<f32>,
    /// Per-row zero points
    pub zeros: Vec<i8>,
}
impl QuantizedMatrix {
/// Quantize a matrix with per-row scaling
pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> Self {
assert_eq!(data.len(), rows * cols);
let (quantized, scales) = quantize_per_channel_i8(data, rows);
Self {
data: quantized,
rows,
cols,
scales,
zeros: vec![0i8; rows],
}
}
/// Get a row
pub fn row(&self, idx: usize) -> &[i8] {
let start = idx * self.cols;
&self.data[start..start + self.cols]
}
/// Dequantize to f32
pub fn to_f32(&self) -> Vec<f32> {
let mut result = Vec::with_capacity(self.rows * self.cols);
for r in 0..self.rows {
let scale = self.scales[r];
let zero = self.zeros[r] as f32;
for &v in self.row(r) {
result.push((v as f32 - zero) * scale);
}
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::QuantSpec;
    // Symmetric scheme: dequantizing with the returned scale must land
    // within one quantization step of the original.
    #[test]
    fn test_quantize_symmetric() {
        let values = vec![1.0, -1.0, 0.5, -0.5, 0.0];
        let (quantized, scale) = quantize_symmetric_i8(&values);
        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = q as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }
    // Asymmetric scheme: subtracting the zero point recovers the value.
    #[test]
    fn test_quantize_asymmetric() {
        let values = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let (quantized, scale, zero) = quantize_asymmetric_i8(&values);
        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = (q as i32 - zero) as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }
    // Per-row matrix round-trip stays within roughly one step per row.
    #[test]
    fn test_quantized_matrix() {
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1 - 3.2).collect();
        let matrix = QuantizedMatrix::from_f32(&data, 8, 8);
        assert_eq!(matrix.rows, 8);
        assert_eq!(matrix.cols, 8);
        assert_eq!(matrix.scales.len(), 8);
        let dequantized = matrix.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.2);
        }
    }
}

View File

@@ -0,0 +1,624 @@
//! Core types for FPGA Transformer backend
//!
//! All types are designed for deterministic, allocation-free inference
//! with explicit quantization metadata.
use serde::{Deserialize, Serialize};
/// Unique identifier for a loaded model (SHA-256 hash)
///
/// Stored as the raw 32-byte digest; hex rendering/parsing lives on the impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelId(pub [u8; 32]);
impl ModelId {
    /// Create a new ModelId from bytes
    pub const fn new(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
    /// Create a zero ModelId (for testing)
    pub const fn zero() -> Self {
        Self([0u8; 32])
    }
    /// Get the bytes of the ModelId
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }
    /// Render as a 64-character lowercase hex string.
    pub fn to_hex(&self) -> String {
        let mut hex = String::with_capacity(64);
        for byte in &self.0 {
            hex.push_str(&format!("{:02x}", byte));
        }
        hex
    }
    /// Parse from a 64-character hex string; returns None on any
    /// malformed input (wrong length or non-hex digit pairs).
    pub fn from_hex(s: &str) -> Option<Self> {
        if s.len() != 64 {
            return None;
        }
        let mut bytes = [0u8; 32];
        for (slot, pair) in bytes.iter_mut().zip(s.as_bytes().chunks(2)) {
            let digits = std::str::from_utf8(pair).ok()?;
            *slot = u8::from_str_radix(digits, 16).ok()?;
        }
        Some(Self(bytes))
    }
}
// Display renders the same lowercase hex as `to_hex`.
impl std::fmt::Display for ModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.to_hex())
    }
}
/// Fixed shape specification for transformer inference
///
/// All dimensions are compile-time or model-time constants.
/// This enables zero-allocation inference paths.
///
/// Invariant (checked by `validate`): d_model == heads * d_head.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct FixedShape {
    /// Maximum sequence length
    pub seq_len: u16,
    /// Model/hidden dimension
    pub d_model: u16,
    /// Number of attention heads
    pub heads: u8,
    /// Dimension per head
    pub d_head: u16,
    /// Vocabulary size
    pub vocab: u32,
}
impl FixedShape {
    /// Create a new FixedShape
    ///
    /// No checking happens here; call [`Self::validate`] afterwards.
    pub const fn new(seq_len: u16, d_model: u16, heads: u8, d_head: u16, vocab: u32) -> Self {
        Self {
            seq_len,
            d_model,
            heads,
            d_head,
            vocab,
        }
    }
    /// Micro configuration for edge/WASM deployment
    pub const fn micro() -> Self {
        Self {
            seq_len: 32,
            d_model: 64,
            heads: 4,
            d_head: 16,
            vocab: 4096,
        }
    }
    /// Small configuration for embedded
    pub const fn small() -> Self {
        Self {
            seq_len: 64,
            d_model: 128,
            heads: 4,
            d_head: 32,
            vocab: 8192,
        }
    }
    /// Baseline configuration
    pub const fn baseline() -> Self {
        Self {
            seq_len: 128,
            d_model: 256,
            heads: 8,
            d_head: 32,
            vocab: 32000,
        }
    }
    /// Calculate total parameters for embedding layer
    pub const fn embedding_params(&self) -> usize {
        self.vocab as usize * self.d_model as usize
    }
    /// Calculate parameters per attention layer
    ///
    /// Counts the Q, K, V and output projection weights only (no biases).
    pub const fn attention_params(&self) -> usize {
        // Q, K, V projections + output projection
        4 * (self.d_model as usize * self.d_model as usize)
    }
    /// Calculate parameters per FFN layer (assuming 4x expansion)
    ///
    /// Two weight matrices: d_model x 4*d_model up, 4*d_model x d_model down.
    pub const fn ffn_params(&self) -> usize {
        2 * (self.d_model as usize * 4 * self.d_model as usize)
    }
    /// Validate shape consistency
    ///
    /// Checks the heads * d_head == d_model invariant and rejects
    /// zero-sized sequence length / vocabulary.
    pub fn validate(&self) -> Result<(), String> {
        if self.d_model as usize != self.heads as usize * self.d_head as usize {
            return Err(format!(
                "d_model ({}) must equal heads ({}) * d_head ({})",
                self.d_model, self.heads, self.d_head
            ));
        }
        if self.seq_len == 0 {
            return Err("seq_len must be > 0".into());
        }
        if self.vocab == 0 {
            return Err("vocab must be > 0".into());
        }
        Ok(())
    }
}
// The baseline configuration is the default shape.
impl Default for FixedShape {
    fn default() -> Self {
        Self::baseline()
    }
}
/// Quantization specification
///
/// Explicit quantization metadata ensuring reproducible inference.
/// `scale_q` and `zero_q` are Q16.16 fixed point (`1 << 16` == 1.0).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantSpec {
    /// Weight bit width (1, 2, 4, 8)
    pub w_bits: u8,
    /// Activation bit width (4, 8, 16)
    pub a_bits: u8,
    /// Scale factor (Q16.16 fixed point)
    pub scale_q: i32,
    /// Zero point (Q16.16 fixed point)
    pub zero_q: i32,
    /// Memory layout
    pub layout: Layout,
}
impl QuantSpec {
    /// Build a QuantSpec from explicit fields.
    pub const fn new(w_bits: u8, a_bits: u8, scale_q: i32, zero_q: i32, layout: Layout) -> Self {
        Self {
            w_bits,
            a_bits,
            scale_q,
            zero_q,
            layout,
        }
    }
    /// Preset: INT4 weights with INT8 activations (common edge configuration).
    pub const fn int4_int8() -> Self {
        Self {
            w_bits: 4,
            a_bits: 8,
            // 1.0 expressed in Q16.16 fixed point.
            scale_q: 1 << 16,
            zero_q: 0,
            layout: Layout::Blocked { block: 32 },
        }
    }
    /// Preset: INT8 weights and activations, row-major.
    pub const fn int8() -> Self {
        Self {
            w_bits: 8,
            a_bits: 8,
            scale_q: 1 << 16,
            zero_q: 0,
            layout: Layout::RowMajor,
        }
    }
    /// Storage bytes occupied by a single weight element.
    ///
    /// Sub-byte widths are packed, but one element alone still occupies
    /// at least one byte.
    pub const fn bytes_per_weight(&self) -> usize {
        match self.w_bits {
            16 => 2,
            1 | 2 | 4 | 8 => 1,
            _ => 4,
        }
    }
    /// How many weights pack into a single byte.
    pub const fn weights_per_byte(&self) -> usize {
        // Only 1-, 2- and 4-bit weights pack evenly; everything else is
        // one (or fewer) elements per byte.
        match self.w_bits {
            4 => 2,
            2 => 4,
            1 => 8,
            _ => 1,
        }
    }
}
impl Default for QuantSpec {
fn default() -> Self {
Self::int8()
}
}
/// Memory layout for quantized tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Layout {
    /// Standard row-major layout
    RowMajor,
    /// Blocked layout for SIMD/hardware efficiency
    /// (`block` = elements per block, e.g. 32 in the INT4 preset)
    Blocked { block: u16 },
    /// Heads interleaved for attention computation
    InterleavedHeads,
}
impl Default for Layout {
fn default() -> Self {
Self::RowMajor
}
}
/// Hint for gating decisions
///
/// Carried on each [`InferenceRequest`] to bound how much compute the
/// gate may spend on that input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct GateHint {
    /// Coherence score (Q8.8 fixed point, higher = more coherent)
    pub coherence_score_q: i16,
    /// Whether a boundary was crossed in the input
    pub boundary_crossed: bool,
    /// Maximum compute class allowed
    pub max_compute_class: ComputeClass,
}
impl GateHint {
    /// Construct a hint from explicit fields.
    pub const fn new(
        coherence_score_q: i16,
        boundary_crossed: bool,
        max_compute_class: ComputeClass,
    ) -> Self {
        Self {
            coherence_score_q,
            boundary_crossed,
            max_compute_class,
        }
    }
    /// Permissive hint: maximum coherence, full deliberative compute allowed.
    pub const fn allow_all() -> Self {
        Self::new(i16::MAX, false, ComputeClass::Deliberative)
    }
    /// Restrictive hint that caps computation at the reflex tier.
    pub const fn reflex_only() -> Self {
        Self::new(0, false, ComputeClass::Reflex)
    }
}
impl Default for GateHint {
fn default() -> Self {
Self::allow_all()
}
}
/// Compute class for tiered inference
///
/// Derives `Ord`, ordered cheapest-to-most-expensive, so a class can be
/// compared directly against [`GateHint::max_compute_class`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum ComputeClass {
    /// Fastest path, minimal computation (1-2 layers)
    Reflex = 0,
    /// Medium path, associative memory (4-6 layers)
    Associative = 1,
    /// Full deliberative computation (all layers)
    Deliberative = 2,
}
impl ComputeClass {
/// Convert from u8
pub const fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Reflex),
1 => Some(Self::Associative),
2 => Some(Self::Deliberative),
_ => None,
}
}
}
impl Default for ComputeClass {
fn default() -> Self {
Self::Deliberative
}
}
/// Inference request
///
/// Borrows its token and mask buffers, so building a request allocates
/// nothing; see [`InferenceRequest::validate`] for the length contract.
#[derive(Debug, Clone)]
pub struct InferenceRequest<'a> {
    /// Model to use
    pub model: ModelId,
    /// Expected shape
    pub shape: FixedShape,
    /// Input token IDs (length = seq_len)
    pub tokens: &'a [u16],
    /// Attention mask (length = seq_len or seq_len^2)
    pub attn_mask: &'a [u8],
    /// Gating hint for coherence control
    pub gate_hint: GateHint,
}
impl<'a> InferenceRequest<'a> {
    /// Bundle a request from its parts; no validation happens here.
    pub fn new(
        model: ModelId,
        shape: FixedShape,
        tokens: &'a [u16],
        attn_mask: &'a [u8],
        gate_hint: GateHint,
    ) -> Self {
        Self {
            model,
            shape,
            tokens,
            attn_mask,
            gate_hint,
        }
    }
    /// Check buffer lengths against the declared shape.
    ///
    /// `tokens` must be exactly `seq_len` long; `attn_mask` may be either
    /// `seq_len` (per-token mask) or `seq_len * seq_len` (full matrix).
    pub fn validate(&self) -> crate::error::Result<()> {
        let seq = self.shape.seq_len as usize;
        if self.tokens.len() != seq {
            return Err(crate::error::Error::InputLengthMismatch {
                expected: seq,
                actual: self.tokens.len(),
            });
        }
        let mask_len = self.attn_mask.len();
        if mask_len != seq && mask_len != seq * seq {
            return Err(crate::error::Error::InputLengthMismatch {
                expected: seq,
                actual: mask_len,
            });
        }
        Ok(())
    }
}
/// Inference result
#[derive(Debug, Clone)]
pub struct InferenceResult {
    /// Full logits (quantized, length = vocab) or empty if topk_only
    pub logits_q: Vec<i16>,
    /// Top-K predictions (token_id, logit_q)
    pub topk: Option<Vec<(u16, i16)>>,
    /// Witness log for audit trail
    pub witness: WitnessLog,
}
impl InferenceResult {
/// Create a new InferenceResult
pub fn new(logits_q: Vec<i16>, topk: Option<Vec<(u16, i16)>>, witness: WitnessLog) -> Self {
Self {
logits_q,
topk,
witness,
}
}
/// Get the argmax token
pub fn argmax(&self) -> Option<u16> {
if let Some(ref topk) = self.topk {
topk.first().map(|(token, _)| *token)
} else if !self.logits_q.is_empty() {
self.logits_q
.iter()
.enumerate()
.max_by_key(|(_, &v)| v)
.map(|(i, _)| i as u16)
} else {
None
}
}
}
/// Witness log for audit trail and ReasoningBank integration
///
/// Produced once per inference; hashable and serializable to a compact
/// fixed-size byte record (see the `witness` modules).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WitnessLog {
    /// Hash of the model used
    pub model_hash: [u8; 32],
    /// Hash of quantization parameters used
    pub quant_hash: [u8; 32],
    /// Backend that executed the inference
    pub backend: BackendKind,
    /// Compute cycles used (FPGA) or 0 (sim)
    pub cycles: u32,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Gate decision made
    pub gate_decision: GateDecision,
}
impl WitnessLog {
    /// Assemble a witness from its component fields.
    pub fn new(
        model_hash: [u8; 32],
        quant_hash: [u8; 32],
        backend: BackendKind,
        cycles: u32,
        latency_ns: u32,
        gate_decision: GateDecision,
    ) -> Self {
        Self {
            model_hash,
            quant_hash,
            backend,
            cycles,
            latency_ns,
            gate_decision,
        }
    }
    /// All-zero witness on the native simulator backend (testing only).
    pub fn empty() -> Self {
        Self::new(
            [0u8; 32],
            [0u8; 32],
            BackendKind::NativeSim,
            0,
            0,
            GateDecision::RanFull,
        )
    }
}
/// Backend types
///
/// Declaration order fixes the `as u8` discriminants used by the
/// witness byte encoding (FpgaPcie = 0 … NativeSim = 3).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BackendKind {
    /// PCIe-connected FPGA
    FpgaPcie,
    /// FPGA via local daemon
    FpgaDaemon,
    /// WASM simulator
    WasmSim,
    /// Native Rust simulator
    NativeSim,
}
impl std::fmt::Display for BackendKind {
    /// Render the backend as a stable snake_case identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::FpgaPcie => "fpga_pcie",
            Self::FpgaDaemon => "fpga_daemon",
            Self::WasmSim => "wasm_sim",
            Self::NativeSim => "native_sim",
        };
        f.write_str(name)
    }
}
/// Gate decision outcome
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GateDecision {
    /// Full inference completed
    RanFull,
    /// Early exit at specified layer
    EarlyExit { layer: u8 },
    /// Inference was skipped
    Skipped { reason: SkipReason },
}
impl GateDecision {
/// Check if inference actually ran
pub const fn did_run(&self) -> bool {
!matches!(self, Self::Skipped { .. })
}
/// Get the exit layer (full = max layers)
pub const fn exit_layer(&self, max_layers: u8) -> u8 {
match self {
Self::RanFull => max_layers,
Self::EarlyExit { layer } => *layer,
Self::Skipped { .. } => 0,
}
}
}
/// Reason for skipping inference
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SkipReason {
    /// Coherence score too low
    LowCoherence,
    /// Policy denied the inference
    PolicyDenied,
    /// Compute budget exceeded
    BudgetExceeded,
}
impl std::fmt::Display for SkipReason {
    /// Render the skip reason as a stable snake_case identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::LowCoherence => "low_coherence",
            Self::PolicyDenied => "policy_denied",
            Self::BudgetExceeded => "budget_exceeded",
        };
        f.write_str(name)
    }
}
/// Quantized tensor wrapper
///
/// `data` holds elements packed according to `spec` (see
/// [`QuantSpec::weights_per_byte`]).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Raw quantized data
    pub data: Vec<u8>,
    /// Quantization specification
    pub spec: QuantSpec,
    /// Tensor shape (row-major)
    pub shape: Vec<usize>,
}
impl QuantizedTensor {
    /// Wrap raw quantized bytes together with their spec and shape.
    pub fn new(data: Vec<u8>, spec: QuantSpec, shape: Vec<usize>) -> Self {
        Self { data, spec, shape }
    }
    /// Total number of elements (product of all dimensions).
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }
    /// Bytes the packed data should occupy, rounded up to whole bytes.
    pub fn expected_bytes(&self) -> usize {
        let per_byte = self.spec.weights_per_byte();
        (self.numel() + per_byte - 1) / per_byte
    }
    /// Check that `data` holds exactly `expected_bytes` bytes.
    pub fn validate(&self) -> crate::error::Result<()> {
        let expected = self.expected_bytes();
        let actual = self.data.len();
        if actual != expected {
            return Err(crate::error::Error::QuantizationError(format!(
                "Data size mismatch: expected {} bytes, got {}",
                expected, actual
            )));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Hex encode/decode of a ModelId must be lossless.
    #[test]
    fn test_model_id_hex_roundtrip() {
        let bytes = [0x12u8; 32];
        let id = ModelId::new(bytes);
        let hex = id.to_hex();
        let parsed = ModelId::from_hex(&hex).unwrap();
        assert_eq!(id, parsed);
    }
    // validate() enforces d_model == heads * d_head.
    #[test]
    fn test_fixed_shape_validate() {
        let valid = FixedShape::new(64, 256, 8, 32, 32000);
        assert!(valid.validate().is_ok());
        let invalid = FixedShape::new(64, 256, 8, 16, 32000); // 8*16 != 256
        assert!(invalid.validate().is_err());
    }
    // Packing density follows the weight bit width.
    #[test]
    fn test_quant_spec_bytes() {
        assert_eq!(QuantSpec::int8().weights_per_byte(), 1);
        assert_eq!(QuantSpec::int4_int8().weights_per_byte(), 2);
    }
    // Only Skipped counts as "did not run".
    #[test]
    fn test_gate_decision() {
        assert!(GateDecision::RanFull.did_run());
        assert!(GateDecision::EarlyExit { layer: 3 }.did_run());
        assert!(!GateDecision::Skipped {
            reason: SkipReason::LowCoherence
        }
        .did_run());
    }
}

View File

@@ -0,0 +1,215 @@
//! Witness hashing for integrity verification
use crate::types::WitnessLog;
use sha2::{Digest, Sha256};
/// Compute a hash of the witness log for integrity verification
pub fn compute_witness_hash(witness: &WitnessLog) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(&witness.model_hash);
hasher.update(&witness.quant_hash);
hasher.update(&[witness.backend as u8]);
hasher.update(&witness.cycles.to_le_bytes());
hasher.update(&witness.latency_ns.to_le_bytes());
// Hash gate decision
match witness.gate_decision {
crate::types::GateDecision::RanFull => {
hasher.update(&[0u8]);
}
crate::types::GateDecision::EarlyExit { layer } => {
hasher.update(&[1u8, layer]);
}
crate::types::GateDecision::Skipped { reason } => {
hasher.update(&[2u8, reason as u8]);
}
}
hasher.finalize().into()
}
/// Verify a witness hash
///
/// True when the freshly computed digest equals `expected`.
pub fn verify_witness_hash(witness: &WitnessLog, expected: &[u8; 32]) -> bool {
    compute_witness_hash(witness) == *expected
}
/// Compute a combined hash for a sequence of witnesses
/// Useful for verifying an entire inference chain
pub fn compute_chain_hash(witnesses: &[WitnessLog]) -> [u8; 32] {
    // Fold each per-witness digest into one running SHA-256 state.
    witnesses
        .iter()
        .fold(Sha256::new(), |mut hasher, witness| {
            hasher.update(&compute_witness_hash(witness));
            hasher
        })
        .finalize()
        .into()
}
/// Witness proof for verification
///
/// Binds a witness digest to a creation timestamp; an ed25519 signature
/// over both can be attached when the `sign` feature is enabled.
#[derive(Debug, Clone)]
pub struct WitnessProof {
    /// Hash of the witness
    pub hash: [u8; 32],
    /// Timestamp when proof was created (ns since Unix epoch, 0 on clock error)
    pub timestamp_ns: u64,
    /// Optional signature
    pub signature: Option<[u8; 64]>,
}
impl WitnessProof {
    /// Nanoseconds since the Unix epoch, or 0 if the clock reads earlier
    /// than the epoch.
    fn unix_now_ns() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    }
    /// Build the 40-byte message (hash || timestamp_le) that is signed and
    /// verified; shared by `signed` and `verify_signature` so the two can
    /// never drift apart.
    #[cfg(feature = "sign")]
    fn message(hash: &[u8; 32], timestamp_ns: u64) -> [u8; 40] {
        let mut message = [0u8; 40];
        message[..32].copy_from_slice(hash);
        message[32..40].copy_from_slice(&timestamp_ns.to_le_bytes());
        message
    }
    /// Create a new (unsigned) proof from a witness.
    pub fn new(witness: &WitnessLog) -> Self {
        Self {
            hash: compute_witness_hash(witness),
            timestamp_ns: Self::unix_now_ns(),
            signature: None,
        }
    }
    /// Create a proof carrying an ed25519 signature over hash + timestamp.
    #[cfg(feature = "sign")]
    pub fn signed(witness: &WitnessLog, secret_key: &[u8; 32]) -> Self {
        use ed25519_dalek::{Signer, SigningKey};
        let hash = compute_witness_hash(witness);
        let timestamp_ns = Self::unix_now_ns();
        let signing_key = SigningKey::from_bytes(secret_key);
        let signature = signing_key.sign(&Self::message(&hash, timestamp_ns));
        Self {
            hash,
            timestamp_ns,
            signature: Some(signature.to_bytes()),
        }
    }
    /// Verify the proof's hash against a witness.
    ///
    /// Checks the digest only; use `verify_signature` to check the signature.
    pub fn verify(&self, witness: &WitnessLog) -> bool {
        verify_witness_hash(witness, &self.hash)
    }
    /// Verify the attached signature with the given public key.
    ///
    /// Returns false when no signature is present or the key is invalid.
    #[cfg(feature = "sign")]
    pub fn verify_signature(&self, pubkey: &[u8; 32]) -> bool {
        use ed25519_dalek::{Signature, Verifier, VerifyingKey};
        let Some(sig_bytes) = self.signature else {
            return false;
        };
        let Ok(verifying_key) = VerifyingKey::from_bytes(pubkey) else {
            return false;
        };
        let signature = Signature::from_bytes(&sig_bytes);
        let message = Self::message(&self.hash, self.timestamp_ns);
        verifying_key.verify(&message, &signature).is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BackendKind, GateDecision};
    // The same witness must always hash to the same digest.
    #[test]
    fn test_witness_hash_deterministic() {
        let witness = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1000,
            50000,
            GateDecision::RanFull,
        );
        let hash1 = compute_witness_hash(&witness);
        let hash2 = compute_witness_hash(&witness);
        assert_eq!(hash1, hash2);
    }
    // Changing any field (here: cycles) must change the digest.
    #[test]
    fn test_witness_hash_changes() {
        let witness1 = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1000,
            50000,
            GateDecision::RanFull,
        );
        let witness2 = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1001, // Different cycles
            50000,
            GateDecision::RanFull,
        );
        let hash1 = compute_witness_hash(&witness1);
        let hash2 = compute_witness_hash(&witness2);
        assert_ne!(hash1, hash2);
    }
    // verify accepts the matching digest and rejects a wrong one.
    #[test]
    fn test_verify_witness_hash() {
        let witness = WitnessLog::empty();
        let hash = compute_witness_hash(&witness);
        assert!(verify_witness_hash(&witness, &hash));
        assert!(!verify_witness_hash(&witness, &[0u8; 32]));
    }
    // Chain hashing over the same sequence is deterministic.
    #[test]
    fn test_chain_hash() {
        let witnesses: Vec<WitnessLog> = (0..5)
            .map(|i| {
                WitnessLog::new(
                    [i as u8; 32],
                    [0u8; 32],
                    BackendKind::NativeSim,
                    i * 100,
                    i * 1000,
                    GateDecision::RanFull,
                )
            })
            .collect();
        let chain_hash1 = compute_chain_hash(&witnesses);
        let chain_hash2 = compute_chain_hash(&witnesses);
        assert_eq!(chain_hash1, chain_hash2);
    }
    // A fresh proof verifies its own witness and carries a timestamp.
    #[test]
    fn test_witness_proof() {
        let witness = WitnessLog::empty();
        let proof = WitnessProof::new(&witness);
        assert!(proof.verify(&witness));
        assert!(proof.timestamp_ns > 0);
    }
}

View File

@@ -0,0 +1,311 @@
//! Witness log builder and utilities
use crate::types::{BackendKind, GateDecision, WitnessLog};
use std::time::Instant;
/// Builder for creating witness logs
///
/// Latency is measured automatically from construction (`new`) until
/// `build` is called.
pub struct WitnessBuilder {
    model_hash: [u8; 32],
    quant_hash: [u8; 32],
    backend: BackendKind,
    // Captured at `new`; `build` records elapsed time as latency_ns.
    start_time: Instant,
    cycles: u32,
    gate_decision: GateDecision,
}
impl WitnessBuilder {
    /// Start building a witness; the latency clock starts now.
    pub fn new(backend: BackendKind) -> Self {
        Self {
            backend,
            model_hash: [0u8; 32],
            quant_hash: [0u8; 32],
            start_time: Instant::now(),
            cycles: 0,
            gate_decision: GateDecision::RanFull,
        }
    }
    /// Record the model hash.
    pub fn model_hash(mut self, hash: [u8; 32]) -> Self {
        self.model_hash = hash;
        self
    }
    /// Record the quantization hash.
    pub fn quant_hash(mut self, hash: [u8; 32]) -> Self {
        self.quant_hash = hash;
        self
    }
    /// Record the compute cycles consumed.
    pub fn cycles(mut self, cycles: u32) -> Self {
        self.cycles = cycles;
        self
    }
    /// Record the gate decision taken.
    pub fn gate_decision(mut self, decision: GateDecision) -> Self {
        self.gate_decision = decision;
        self
    }
    /// Finish: latency is the time elapsed since `new`.
    pub fn build(self) -> WitnessLog {
        WitnessLog::new(
            self.model_hash,
            self.quant_hash,
            self.backend,
            self.cycles,
            self.start_time.elapsed().as_nanos() as u32,
            self.gate_decision,
        )
    }
}
impl WitnessLog {
    /// Convert to compact bytes for storage
    ///
    /// Fixed 75-byte layout, little-endian integers:
    /// `0..32` model_hash · `32..64` quant_hash · `64` backend ·
    /// `65..69` cycles · `69..73` latency_ns · `73` gate tag · `74` gate payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        // Capacity 80 leaves a little slack over the 75 bytes written.
        let mut bytes = Vec::with_capacity(80);
        bytes.extend_from_slice(&self.model_hash);
        bytes.extend_from_slice(&self.quant_hash);
        bytes.push(self.backend as u8);
        bytes.extend_from_slice(&self.cycles.to_le_bytes());
        bytes.extend_from_slice(&self.latency_ns.to_le_bytes());
        // Encode gate decision as a tag byte (0/1/2) plus one payload byte
        // (zero for RanFull, which carries no data).
        match self.gate_decision {
            GateDecision::RanFull => {
                bytes.push(0);
                bytes.push(0);
            }
            GateDecision::EarlyExit { layer } => {
                bytes.push(1);
                bytes.push(layer);
            }
            GateDecision::Skipped { reason } => {
                bytes.push(2);
                bytes.push(reason as u8);
            }
        }
        bytes
    }
    /// Parse from bytes
    ///
    /// Accepts any buffer of at least 75 bytes (trailing bytes are ignored).
    /// Decoding is lenient: unknown backend or gate-tag bytes fall back to
    /// `NativeSim` / `RanFull` instead of failing; only a short buffer
    /// yields `None`.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < 75 {
            return None;
        }
        let model_hash: [u8; 32] = bytes[0..32].try_into().ok()?;
        let quant_hash: [u8; 32] = bytes[32..64].try_into().ok()?;
        // Discriminants follow BackendKind's declaration order.
        let backend = match bytes[64] {
            0 => BackendKind::FpgaPcie,
            1 => BackendKind::FpgaDaemon,
            2 => BackendKind::WasmSim,
            3 => BackendKind::NativeSim,
            _ => BackendKind::NativeSim,
        };
        let cycles = u32::from_le_bytes(bytes[65..69].try_into().ok()?);
        let latency_ns = u32::from_le_bytes(bytes[69..73].try_into().ok()?);
        // Byte 73 is the tag, byte 74 the payload written by to_bytes.
        let gate_decision = match bytes[73] {
            0 => GateDecision::RanFull,
            1 => GateDecision::EarlyExit { layer: bytes[74] },
            2 => GateDecision::Skipped {
                reason: match bytes[74] {
                    0 => crate::types::SkipReason::LowCoherence,
                    1 => crate::types::SkipReason::PolicyDenied,
                    _ => crate::types::SkipReason::BudgetExceeded,
                },
            },
            _ => GateDecision::RanFull,
        };
        Some(Self {
            model_hash,
            quant_hash,
            backend,
            cycles,
            latency_ns,
            gate_decision,
        })
    }
    /// Get latency in microseconds
    pub fn latency_us(&self) -> f64 {
        self.latency_ns as f64 / 1000.0
    }
    /// Get latency in milliseconds
    pub fn latency_ms(&self) -> f64 {
        self.latency_ns as f64 / 1_000_000.0
    }
    /// Check if this was a successful full inference
    pub fn is_full_inference(&self) -> bool {
        matches!(self.gate_decision, GateDecision::RanFull)
    }
    /// Check if this was an early exit
    pub fn is_early_exit(&self) -> bool {
        matches!(self.gate_decision, GateDecision::EarlyExit { .. })
    }
    /// Check if inference was skipped
    pub fn is_skipped(&self) -> bool {
        matches!(self.gate_decision, GateDecision::Skipped { .. })
    }
}
/// Witness log aggregator for collecting statistics
///
/// Accumulates counters from `add`; derived statistics (means, rates,
/// standard deviation) are computed on demand.
#[derive(Debug, Default)]
pub struct WitnessAggregator {
    /// Total inferences
    pub count: u64,
    /// Total cycles
    pub total_cycles: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
    /// Full inference count
    pub full_count: u64,
    /// Early exit count
    pub early_exit_count: u64,
    /// Skipped count
    pub skipped_count: u64,
    /// Sum of squares for variance calculation
    latency_sq_sum: u128,
}
impl WitnessAggregator {
    /// Create a new, empty aggregator.
    pub fn new() -> Self {
        Self::default()
    }
    /// Fold one witness into the running statistics.
    pub fn add(&mut self, witness: &WitnessLog) {
        self.count += 1;
        self.total_cycles += witness.cycles as u64;
        self.total_latency_ns += witness.latency_ns as u64;
        // u128 accumulator: a squared u32 latency already approaches u64's
        // range, so summing many of them needs the wider type.
        self.latency_sq_sum += (witness.latency_ns as u128).pow(2);
        match witness.gate_decision {
            GateDecision::RanFull => self.full_count += 1,
            GateDecision::EarlyExit { .. } => self.early_exit_count += 1,
            GateDecision::Skipped { .. } => self.skipped_count += 1,
        }
    }
    /// Mean latency in nanoseconds (0.0 when empty).
    pub fn avg_latency_ns(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_latency_ns as f64 / self.count as f64
        }
    }
    /// Mean cycle count (0.0 when empty).
    pub fn avg_cycles(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_cycles as f64 / self.count as f64
        }
    }
    /// Population standard deviation of latency in nanoseconds
    /// (0.0 with fewer than two samples).
    pub fn latency_std_ns(&self) -> f64 {
        if self.count <= 1 {
            return 0.0;
        }
        let mean = self.avg_latency_ns();
        // E[x^2] - mean^2 can dip slightly below zero from floating-point
        // rounding; clamp so sqrt() never returns NaN.
        let variance = (self.latency_sq_sum as f64 / self.count as f64) - (mean * mean);
        variance.max(0.0).sqrt()
    }
    /// Fraction of inferences that exited early (0.0 when empty).
    pub fn early_exit_rate(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.early_exit_count as f64 / self.count as f64
        }
    }
    /// Fraction of inferences that were skipped (0.0 when empty).
    pub fn skip_rate(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.skipped_count as f64 / self.count as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builder carries every explicitly-set field into the finished log.
    #[test]
    fn test_witness_builder() {
        let witness = WitnessBuilder::new(BackendKind::NativeSim)
            .model_hash([1u8; 32])
            .quant_hash([2u8; 32])
            .cycles(1000)
            .gate_decision(GateDecision::RanFull)
            .build();
        assert_eq!(witness.model_hash, [1u8; 32]);
        assert_eq!(witness.backend, BackendKind::NativeSim);
        assert_eq!(witness.cycles, 1000);
    }
    // to_bytes/from_bytes must round-trip every field.
    #[test]
    fn test_witness_bytes_roundtrip() {
        let witness = WitnessLog::new(
            [0x42u8; 32],
            [0x24u8; 32],
            BackendKind::FpgaDaemon,
            5000,
            100_000,
            GateDecision::EarlyExit { layer: 4 },
        );
        let bytes = witness.to_bytes();
        let parsed = WitnessLog::from_bytes(&bytes).unwrap();
        assert_eq!(witness.model_hash, parsed.model_hash);
        assert_eq!(witness.quant_hash, parsed.quant_hash);
        assert_eq!(witness.backend, parsed.backend);
        assert_eq!(witness.cycles, parsed.cycles);
        assert_eq!(witness.latency_ns, parsed.latency_ns);
    }
    // Counters and early-exit rate over a small mixed batch.
    #[test]
    fn test_witness_aggregator() {
        let mut agg = WitnessAggregator::new();
        for i in 0..10 {
            let mut witness = WitnessLog::empty();
            witness.latency_ns = 1000 * (i + 1);
            witness.cycles = 100 * (i + 1);
            if i < 3 {
                witness.gate_decision = GateDecision::EarlyExit { layer: 2 };
            }
            agg.add(&witness);
        }
        assert_eq!(agg.count, 10);
        assert_eq!(agg.early_exit_count, 3);
        assert!((agg.early_exit_rate() - 0.3).abs() < 0.01);
    }
}

View File

@@ -0,0 +1,12 @@
//! Witness logging for audit trails and ReasoningBank integration
//!
//! Every inference produces a small witness bundle that records
//! what happened and enables verification and replay.
pub mod hash;
pub mod log;
// Re-export WitnessLog from types as the canonical location
// (`log` adds inherent byte-serialization helpers to it; `hash` digests it).
pub use crate::types::WitnessLog;
pub use hash::{compute_witness_hash, verify_witness_hash};
pub use log::{WitnessAggregator, WitnessBuilder};