Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
266
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/manifest.rs
vendored
Normal file
266
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/manifest.rs
vendored
Normal file
@@ -0,0 +1,266 @@
|
||||
//! Manifest schema for model artifacts
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{FixedShape, Layout, QuantSpec};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Model manifest containing all metadata
///
/// Serialized as JSON (see [`Manifest::to_json`]) and embedded,
/// length-prefixed, at the front of packed artifact containers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
    /// Model name (must be non-empty; enforced by `validate`)
    pub name: String,
    /// SHA-256 hash of model (hex string); left empty by `Manifest::new`
    /// until the weights are known
    pub model_hash: String,
    /// Fixed shape specification
    pub shape: FixedShape,
    /// Quantization specification
    pub quant: QuantSpec,
    /// I/O configuration
    pub io: IoSpec,
    /// Backend configuration
    pub backend: BackendSpec,
    /// Test vector specification
    pub tests: TestSpec,
}
|
||||
|
||||
impl Manifest {
|
||||
/// Create a new manifest
|
||||
pub fn new(name: impl Into<String>, shape: FixedShape, quant: QuantSpec) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant,
|
||||
io: IoSpec::default(),
|
||||
backend: BackendSpec::default(),
|
||||
tests: TestSpec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate manifest consistency
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.name.is_empty() {
|
||||
return Err(Error::InvalidArtifact("Model name is empty".into()));
|
||||
}
|
||||
|
||||
// Validate shape
|
||||
self.shape
|
||||
.validate()
|
||||
.map_err(|e| Error::InvalidArtifact(e))?;
|
||||
|
||||
// Validate quantization bits
|
||||
if !matches!(self.quant.w_bits, 1 | 2 | 4 | 8 | 16) {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Invalid weight bits: {}",
|
||||
self.quant.w_bits
|
||||
)));
|
||||
}
|
||||
if !matches!(self.quant.a_bits, 4 | 8 | 16 | 32) {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Invalid activation bits: {}",
|
||||
self.quant.a_bits
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert to JSON string
|
||||
pub fn to_json(&self) -> Result<String> {
|
||||
serde_json::to_string_pretty(self).map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Parse from JSON string
|
||||
pub fn from_json(json: &str) -> Result<Self> {
|
||||
serde_json::from_str(json).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// I/O type specification
///
/// Token and logit element types are recorded as type-name strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IoSpec {
    /// Input token type (typically "u16")
    pub tokens: String,
    /// Output logit type (typically "i16" or "i32")
    pub logits: String,
    /// Top-K count (0 for full logits)
    pub topk: u16,
}
|
||||
|
||||
impl Default for IoSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tokens: "u16".into(),
|
||||
logits: "i16".into(),
|
||||
topk: 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend-specific configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendSpec {
    /// Backend kind ("fpga_pcie", "fpga_daemon", "native_sim", "wasm_sim")
    pub kind: String,
    /// Protocol version
    pub protocol: u16,
    /// Backend-specific options; falls back to `BackendOptions::default()`
    /// when missing from the serialized form
    #[serde(default)]
    pub options: BackendOptions,
}
|
||||
|
||||
impl Default for BackendSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
kind: "native_sim".into(),
|
||||
protocol: 1,
|
||||
options: BackendOptions::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend-specific options
///
/// Every field is `#[serde(default)]`, so older serialized manifests
/// without these keys still deserialize (to zero/false).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendOptions {
    /// Enable batch processing
    #[serde(default)]
    pub batch_enabled: bool,
    /// Maximum batch size
    #[serde(default)]
    pub max_batch: u16,
    /// Enable early exit
    #[serde(default)]
    pub early_exit: bool,
    /// Minimum coherence threshold for early exit
    #[serde(default)]
    pub early_exit_threshold: i16,
    /// FPGA clock frequency in MHz (for cycle estimation)
    #[serde(default)]
    pub clock_mhz: u16,
}
|
||||
|
||||
/// Test vector specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestSpec {
    /// Number of test vectors
    pub vectors: u32,
    /// Maximum absolute error allowed
    pub max_abs_err: i32,
    /// Whether test vectors must pass before activation; defaults to `true`
    /// when the key is absent from serialized data (via `default_true`)
    #[serde(default = "default_true")]
    pub require_pass: bool,
}
|
||||
|
||||
/// Serde default for `TestSpec::require_pass`: test vectors must pass
/// unless a manifest explicitly opts out.
fn default_true() -> bool {
    true
}
|
||||
|
||||
impl Default for TestSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vectors: 0,
|
||||
max_abs_err: 2,
|
||||
require_pass: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manifest builder for convenient construction
///
/// Chainable setters over an internal `Manifest`; `build` validates.
pub struct ManifestBuilder {
    // Manifest under construction, seeded from `Manifest::new` defaults.
    manifest: Manifest,
}
|
||||
|
||||
impl ManifestBuilder {
|
||||
/// Create a new builder with name and shape
|
||||
pub fn new(name: impl Into<String>, shape: FixedShape) -> Self {
|
||||
Self {
|
||||
manifest: Manifest::new(name, shape, QuantSpec::int8()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set quantization spec
|
||||
pub fn quant(mut self, quant: QuantSpec) -> Self {
|
||||
self.manifest.quant = quant;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set model hash
|
||||
pub fn model_hash(mut self, hash: impl Into<String>) -> Self {
|
||||
self.manifest.model_hash = hash.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set I/O spec
|
||||
pub fn io(mut self, io: IoSpec) -> Self {
|
||||
self.manifest.io = io;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set backend spec
|
||||
pub fn backend(mut self, backend: BackendSpec) -> Self {
|
||||
self.manifest.backend = backend;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set test spec
|
||||
pub fn tests(mut self, tests: TestSpec) -> Self {
|
||||
self.manifest.tests = tests;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable top-K only output
|
||||
pub fn topk_only(mut self, k: u16) -> Self {
|
||||
self.manifest.io.topk = k;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable early exit
|
||||
pub fn early_exit(mut self, threshold: i16) -> Self {
|
||||
self.manifest.backend.options.early_exit = true;
|
||||
self.manifest.backend.options.early_exit_threshold = threshold;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the manifest
|
||||
pub fn build(self) -> Result<Manifest> {
|
||||
self.manifest.validate()?;
|
||||
Ok(self.manifest)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Builder happy path: chained setters land on the intended fields.
    #[test]
    fn test_manifest_builder() {
        let manifest = ManifestBuilder::new("test", FixedShape::micro())
            .quant(QuantSpec::int4_int8())
            .topk_only(16)
            .early_exit(100)
            .build()
            .unwrap();

        assert_eq!(manifest.name, "test");
        assert_eq!(manifest.quant.w_bits, 4);
        assert_eq!(manifest.io.topk, 16);
        assert!(manifest.backend.options.early_exit);
    }

    // to_json/from_json round-trip preserves the name (spot check).
    #[test]
    fn test_manifest_json_roundtrip() {
        let manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        let json = manifest.to_json().unwrap();
        let parsed = Manifest::from_json(&json).unwrap();
        assert_eq!(manifest.name, parsed.name);
    }

    // validate() accepts a fresh manifest and rejects an empty name.
    #[test]
    fn test_manifest_validation() {
        let mut manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        assert!(manifest.validate().is_ok());

        manifest.name = String::new();
        assert!(manifest.validate().is_err());
    }
}
|
||||
242
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/mod.rs
vendored
Normal file
242
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/mod.rs
vendored
Normal file
@@ -0,0 +1,242 @@
|
||||
//! Model artifact format and handling
|
||||
//!
|
||||
//! Signed bundles with metadata, weights, and test vectors.
|
||||
|
||||
pub mod manifest;
|
||||
pub mod pack;
|
||||
pub mod verify;
|
||||
|
||||
pub use manifest::{BackendSpec, IoSpec, Manifest, TestSpec};
|
||||
pub use pack::{pack_artifact, unpack_artifact};
|
||||
pub use verify::{verify_artifact, verify_signature};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{FixedShape, ModelId, QuantSpec};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Complete model artifact
///
/// Bundles the manifest with binary payloads, validation vectors, and an
/// Ed25519 signature. An all-zero `pubkey` marks an unsigned artifact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArtifact {
    /// Manifest with metadata
    pub manifest: Manifest,
    /// Quantized weights (binary blob)
    #[serde(with = "serde_bytes")]
    pub weights: Vec<u8>,
    /// Optional FPGA bitstream
    #[serde(with = "serde_bytes_option")]
    pub bitstream: Option<Vec<u8>>,
    /// Optional calibration data
    #[serde(with = "serde_bytes_option")]
    pub calibration: Option<Vec<u8>>,
    /// Test vectors for validation
    pub test_vectors: Vec<TestVector>,
    /// Ed25519 signature over manifest + file hashes
    // NOTE(review): `serde_bytes` on fixed arrays requires a serde_bytes
    // version with `[u8; N]` support — confirm against Cargo.lock.
    #[serde(with = "serde_bytes")]
    pub signature: [u8; 64],
    /// Ed25519 public key
    #[serde(with = "serde_bytes")]
    pub pubkey: [u8; 32],
}
|
||||
|
||||
/// Serde helper for Option<Vec<u8>>
///
/// `serde_bytes` has no built-in `Option` support, so this wraps the inner
/// bytes in `serde_bytes::Bytes`/`ByteBuf` manually for efficient encoding.
mod serde_bytes_option {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    // Serialize Some(bytes) as an efficient byte buffer, None as null/absent.
    pub fn serialize<S: Serializer>(data: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
        match data {
            Some(bytes) => s.serialize_some(&serde_bytes::Bytes::new(bytes)),
            None => s.serialize_none(),
        }
    }

    // Deserialize into an owned buffer, mapping back to Option<Vec<u8>>.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Vec<u8>>, D::Error> {
        let opt: Option<serde_bytes::ByteBuf> = Option::deserialize(d)?;
        Ok(opt.map(|b| b.into_vec()))
    }
}
|
||||
|
||||
/// Test vector for model validation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestVector {
    /// Input tokens
    pub tokens: Vec<u16>,
    /// Expected output logits (top-K or full)
    pub expected: Vec<i16>,
    /// Maximum absolute error allowed
    // NOTE(review): `verify::verify_test_vectors` currently compares against
    // the manifest-level `tests.max_abs_err`, not this per-vector field —
    // confirm which tolerance is intended to win.
    pub max_abs_err: i32,
}
|
||||
|
||||
impl ModelArtifact {
|
||||
/// Create a new artifact (for building/packing)
|
||||
pub fn new(
|
||||
manifest: Manifest,
|
||||
weights: Vec<u8>,
|
||||
bitstream: Option<Vec<u8>>,
|
||||
calibration: Option<Vec<u8>>,
|
||||
test_vectors: Vec<TestVector>,
|
||||
) -> Self {
|
||||
Self {
|
||||
manifest,
|
||||
weights,
|
||||
bitstream,
|
||||
calibration,
|
||||
test_vectors,
|
||||
signature: [0u8; 64],
|
||||
pubkey: [0u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute model ID (SHA-256 of manifest + weights hash)
|
||||
pub fn model_id(&self) -> ModelId {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(self.manifest.name.as_bytes());
|
||||
hasher.update(&self.model_hash());
|
||||
hasher.update(&self.quant_hash());
|
||||
ModelId::new(hasher.finalize().into())
|
||||
}
|
||||
|
||||
/// Compute hash of model weights
|
||||
pub fn model_hash(&self) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&self.weights);
|
||||
if let Some(ref bitstream) = self.bitstream {
|
||||
hasher.update(bitstream);
|
||||
}
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Compute hash of quantization parameters
|
||||
pub fn quant_hash(&self) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
let quant_json = serde_json::to_string(&self.manifest.quant).unwrap_or_default();
|
||||
hasher.update(quant_json.as_bytes());
|
||||
if let Some(ref calib) = self.calibration {
|
||||
hasher.update(calib);
|
||||
}
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Validate artifact integrity
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Validate manifest
|
||||
self.manifest.validate()?;
|
||||
|
||||
// Validate shape
|
||||
self.manifest
|
||||
.shape
|
||||
.validate()
|
||||
.map_err(|e| Error::InvalidArtifact(e))?;
|
||||
|
||||
// Check weights size is reasonable
|
||||
let min_weight_size =
|
||||
self.manifest.shape.embedding_params() / self.manifest.quant.weights_per_byte();
|
||||
if self.weights.len() < min_weight_size {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Weights too small: {} bytes, expected at least {} for embeddings",
|
||||
self.weights.len(),
|
||||
min_weight_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate test vectors if strict mode
|
||||
#[cfg(feature = "strict_verify")]
|
||||
self.run_test_vectors()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run test vectors for validation
|
||||
#[cfg(feature = "strict_verify")]
|
||||
pub fn run_test_vectors(&self) -> Result<()> {
|
||||
// This would require running inference, which creates a circular dependency
|
||||
// In practice, this is done by the backend after loading
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the fixed shape
|
||||
pub fn shape(&self) -> &FixedShape {
|
||||
&self.manifest.shape
|
||||
}
|
||||
|
||||
/// Get quantization spec
|
||||
pub fn quant(&self) -> &QuantSpec {
|
||||
&self.manifest.quant
|
||||
}
|
||||
|
||||
/// Check if artifact has FPGA bitstream
|
||||
pub fn has_bitstream(&self) -> bool {
|
||||
self.bitstream.is_some()
|
||||
}
|
||||
|
||||
/// Estimated memory footprint in bytes
|
||||
pub fn memory_footprint(&self) -> usize {
|
||||
self.weights.len()
|
||||
+ self.bitstream.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||
+ self.calibration.as_ref().map(|c| c.len()).unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a minimal but valid manifest for the micro shape.
    fn create_test_manifest() -> Manifest {
        Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: IoSpec::default(),
            backend: BackendSpec::default(),
            tests: TestSpec::default(),
        }
    }

    // model_id must be a pure function of content (same input, same id).
    #[test]
    fn test_model_id_computation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![]);

        let id1 = artifact.model_id();
        let id2 = artifact.model_id();
        assert_eq!(id1, id2); // Deterministic
    }

    // Non-zero weights must hash to something other than the zero digest.
    #[test]
    fn test_model_hash() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![42u8; 4096 * 64], None, None, vec![]);

        let hash = artifact.model_hash();
        assert_ne!(hash, [0u8; 32]); // Non-zero hash
    }

    // validate() passes when the weights blob is large enough.
    #[test]
    fn test_artifact_validation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 4096 * 64], // Enough for micro embeddings
            None,
            None,
            vec![],
        );

        assert!(artifact.validate().is_ok());
    }

    // validate() rejects weights smaller than the embedding minimum.
    #[test]
    fn test_artifact_too_small_weights() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 100], // Too small
            None,
            None,
            vec![],
        );

        assert!(artifact.validate().is_err());
    }
}
|
||||
304
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/pack.rs
vendored
Normal file
304
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/pack.rs
vendored
Normal file
@@ -0,0 +1,304 @@
|
||||
//! Artifact packing and unpacking
|
||||
|
||||
use std::io::{Read, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use crate::artifact::{ModelArtifact, TestVector};
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Magic bytes for artifact file format
const ARTIFACT_MAGIC: &[u8; 4] = b"RVAT"; // RuVector ArTifact
/// Container format version; unpack rejects any other value.
const ARTIFACT_VERSION: u16 = 1;

// Security: Maximum size limits to prevent DoS via unbounded allocations
/// Maximum manifest size (1 MB)
const MAX_MANIFEST_SIZE: usize = 1024 * 1024;
/// Maximum weights size (1 GB)
const MAX_WEIGHTS_SIZE: usize = 1024 * 1024 * 1024;
/// Maximum bitstream/calibration size (100 MB)
const MAX_BLOB_SIZE: usize = 100 * 1024 * 1024;
/// Maximum number of test vectors
const MAX_TEST_VECTORS: usize = 10_000;
/// Maximum tokens per test vector
// NOTE(review): the on-disk token count prefix is u16, so no more than
// 65_535 tokens can actually be encoded; 65_536 here is unreachable.
const MAX_TOKENS_PER_VECTOR: usize = 65_536;
/// Maximum expected values per test vector
const MAX_EXPECTED_PER_VECTOR: usize = 1_000_000;
|
||||
|
||||
/// Pack an artifact to bytes
|
||||
pub fn pack_artifact(artifact: &ModelArtifact) -> Result<Vec<u8>> {
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Write magic and version
|
||||
buffer.extend_from_slice(ARTIFACT_MAGIC);
|
||||
buffer.extend_from_slice(&ARTIFACT_VERSION.to_le_bytes());
|
||||
|
||||
// Write manifest as JSON with length prefix
|
||||
let manifest_json = serde_json::to_string(&artifact.manifest)?;
|
||||
let manifest_bytes = manifest_json.as_bytes();
|
||||
buffer.extend_from_slice(&(manifest_bytes.len() as u32).to_le_bytes());
|
||||
buffer.extend_from_slice(manifest_bytes);
|
||||
|
||||
// Write weights with length prefix
|
||||
buffer.extend_from_slice(&(artifact.weights.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(&artifact.weights);
|
||||
|
||||
// Write optional bitstream
|
||||
if let Some(ref bitstream) = artifact.bitstream {
|
||||
buffer.push(1); // Present flag
|
||||
buffer.extend_from_slice(&(bitstream.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(bitstream);
|
||||
} else {
|
||||
buffer.push(0); // Not present
|
||||
}
|
||||
|
||||
// Write optional calibration
|
||||
if let Some(ref calibration) = artifact.calibration {
|
||||
buffer.push(1);
|
||||
buffer.extend_from_slice(&(calibration.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(calibration);
|
||||
} else {
|
||||
buffer.push(0);
|
||||
}
|
||||
|
||||
// Write test vectors
|
||||
buffer.extend_from_slice(&(artifact.test_vectors.len() as u32).to_le_bytes());
|
||||
for vector in &artifact.test_vectors {
|
||||
// Write tokens
|
||||
buffer.extend_from_slice(&(vector.tokens.len() as u16).to_le_bytes());
|
||||
for &token in &vector.tokens {
|
||||
buffer.extend_from_slice(&token.to_le_bytes());
|
||||
}
|
||||
// Write expected
|
||||
buffer.extend_from_slice(&(vector.expected.len() as u32).to_le_bytes());
|
||||
for &exp in &vector.expected {
|
||||
buffer.extend_from_slice(&exp.to_le_bytes());
|
||||
}
|
||||
// Write max_abs_err
|
||||
buffer.extend_from_slice(&vector.max_abs_err.to_le_bytes());
|
||||
}
|
||||
|
||||
// Write signature and pubkey
|
||||
buffer.extend_from_slice(&artifact.signature);
|
||||
buffer.extend_from_slice(&artifact.pubkey);
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Unpack an artifact from bytes
///
/// Parses the `RVAT` container written by `pack_artifact`, enforcing the
/// `MAX_*` limits on every declared length before allocating, so a hostile
/// input cannot trigger unbounded allocations.
pub fn unpack_artifact(data: &[u8]) -> Result<ModelArtifact> {
    let mut cursor = std::io::Cursor::new(data);
    // Scratch buffer reused for all fixed-width header reads (max 8 bytes).
    let mut read_buf = [0u8; 8];

    // Read and verify magic
    cursor.read_exact(&mut read_buf[..4])?;
    if &read_buf[..4] != ARTIFACT_MAGIC {
        return Err(Error::InvalidArtifact("Invalid magic bytes".into()));
    }

    // Read version
    cursor.read_exact(&mut read_buf[..2])?;
    let version = u16::from_le_bytes([read_buf[0], read_buf[1]]);
    if version != ARTIFACT_VERSION {
        return Err(Error::InvalidArtifact(format!(
            "Unsupported version: {}",
            version
        )));
    }

    // Read manifest (u32 length prefix, bounded, then JSON-decoded)
    cursor.read_exact(&mut read_buf[..4])?;
    let manifest_len = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if manifest_len > MAX_MANIFEST_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Manifest size {} exceeds maximum {}",
            manifest_len, MAX_MANIFEST_SIZE
        )));
    }
    let mut manifest_bytes = vec![0u8; manifest_len];
    cursor.read_exact(&mut manifest_bytes)?;
    let manifest = serde_json::from_slice(&manifest_bytes)?;

    // Read weights (u64 length prefix, bounded)
    cursor.read_exact(&mut read_buf)?;
    let weights_len = u64::from_le_bytes(read_buf) as usize;
    if weights_len > MAX_WEIGHTS_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Weights size {} exceeds maximum {}",
            weights_len, MAX_WEIGHTS_SIZE
        )));
    }
    let mut weights = vec![0u8; weights_len];
    cursor.read_exact(&mut weights)?;

    // Read optional bitstream (presence byte, then u64 length if present)
    cursor.read_exact(&mut read_buf[..1])?;
    let bitstream = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Bitstream size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };

    // Read optional calibration (same presence-byte scheme)
    cursor.read_exact(&mut read_buf[..1])?;
    let calibration = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Calibration size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };

    // Read test vectors (u32 count, bounded)
    cursor.read_exact(&mut read_buf[..4])?;
    let num_vectors = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if num_vectors > MAX_TEST_VECTORS {
        return Err(Error::InvalidArtifact(format!(
            "Test vector count {} exceeds maximum {}",
            num_vectors, MAX_TEST_VECTORS
        )));
    }
    let mut test_vectors = Vec::with_capacity(num_vectors);

    for _ in 0..num_vectors {
        // Read tokens (u16 count prefix, bounded, then u16 values)
        cursor.read_exact(&mut read_buf[..2])?;
        let num_tokens = u16::from_le_bytes([read_buf[0], read_buf[1]]) as usize;
        if num_tokens > MAX_TOKENS_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Token count {} exceeds maximum {}",
                num_tokens, MAX_TOKENS_PER_VECTOR
            )));
        }
        let mut tokens = Vec::with_capacity(num_tokens);
        for _ in 0..num_tokens {
            cursor.read_exact(&mut read_buf[..2])?;
            tokens.push(u16::from_le_bytes([read_buf[0], read_buf[1]]));
        }

        // Read expected logits (u32 count prefix, bounded, then i16 values)
        cursor.read_exact(&mut read_buf[..4])?;
        let num_expected = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
        if num_expected > MAX_EXPECTED_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Expected values count {} exceeds maximum {}",
                num_expected, MAX_EXPECTED_PER_VECTOR
            )));
        }
        let mut expected = Vec::with_capacity(num_expected);
        for _ in 0..num_expected {
            cursor.read_exact(&mut read_buf[..2])?;
            expected.push(i16::from_le_bytes([read_buf[0], read_buf[1]]));
        }

        // Read per-vector tolerance
        cursor.read_exact(&mut read_buf[..4])?;
        let max_abs_err = i32::from_le_bytes(read_buf[..4].try_into().unwrap());

        test_vectors.push(TestVector {
            tokens,
            expected,
            max_abs_err,
        });
    }

    // Read signature and pubkey (fixed-size trailers)
    let mut signature = [0u8; 64];
    cursor.read_exact(&mut signature)?;
    let mut pubkey = [0u8; 32];
    cursor.read_exact(&mut pubkey)?;

    Ok(ModelArtifact {
        manifest,
        weights,
        bitstream,
        calibration,
        test_vectors,
        signature,
        pubkey,
    })
}
|
||||
|
||||
/// Save artifact to file
|
||||
pub fn save_artifact(artifact: &ModelArtifact, path: impl AsRef<Path>) -> Result<()> {
|
||||
let data = pack_artifact(artifact)?;
|
||||
std::fs::write(path, data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load artifact from file
|
||||
pub fn load_artifact(path: impl AsRef<Path>) -> Result<ModelArtifact> {
|
||||
let data = std::fs::read(path)?;
|
||||
unpack_artifact(&data)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};

    // Fixture exercising every optional section: bitstream present,
    // calibration absent, one test vector, non-zero signature/pubkey.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_pack".into(),
            model_hash: "abc123".into(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };

        ModelArtifact {
            manifest,
            weights: (0..5000).map(|i| (i % 256) as u8).collect(),
            bitstream: Some(vec![0xFF; 100]),
            calibration: None,
            test_vectors: vec![TestVector {
                tokens: vec![1, 2, 3],
                expected: vec![100, 200, 300],
                max_abs_err: 5,
            }],
            signature: [0x42u8; 64],
            pubkey: [0x24u8; 32],
        }
    }

    // pack → unpack must reproduce every section byte-for-byte.
    #[test]
    fn test_pack_unpack_roundtrip() {
        let original = create_test_artifact();
        let packed = pack_artifact(&original).unwrap();
        let unpacked = unpack_artifact(&packed).unwrap();

        assert_eq!(original.manifest.name, unpacked.manifest.name);
        assert_eq!(original.weights, unpacked.weights);
        assert_eq!(original.bitstream, unpacked.bitstream);
        assert_eq!(original.calibration, unpacked.calibration);
        assert_eq!(original.test_vectors.len(), unpacked.test_vectors.len());
        assert_eq!(original.signature, unpacked.signature);
        assert_eq!(original.pubkey, unpacked.pubkey);
    }

    // A wrong magic prefix must be rejected before any parsing.
    #[test]
    fn test_invalid_magic() {
        let data = b"XXXX0000";
        assert!(unpack_artifact(data).is_err());
    }
}
|
||||
203
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/verify.rs
vendored
Normal file
203
vendor/ruvector/crates/ruvector-fpga-transformer/src/artifact/verify.rs
vendored
Normal file
@@ -0,0 +1,203 @@
|
||||
//! Artifact verification and signature validation
|
||||
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Verify artifact signature
|
||||
pub fn verify_signature(artifact: &ModelArtifact) -> Result<bool> {
|
||||
// Compute the message to verify (manifest hash + file hashes)
|
||||
let message = compute_signing_message(artifact);
|
||||
|
||||
// Load public key
|
||||
let pubkey = VerifyingKey::from_bytes(&artifact.pubkey)
|
||||
.map_err(|e| Error::SignatureError(format!("Invalid public key: {}", e)))?;
|
||||
|
||||
// Load signature
|
||||
let signature = Signature::from_bytes(&artifact.signature);
|
||||
|
||||
// Verify
|
||||
pubkey
|
||||
.verify(&message, &signature)
|
||||
.map(|_| true)
|
||||
.map_err(|e| Error::SignatureError(format!("Verification failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Verify complete artifact integrity
|
||||
pub fn verify_artifact(artifact: &ModelArtifact) -> Result<()> {
|
||||
// 1. Validate manifest
|
||||
artifact.manifest.validate()?;
|
||||
|
||||
// 2. Verify model hash matches manifest
|
||||
let computed_hash = hex::encode(artifact.model_hash());
|
||||
if !artifact.manifest.model_hash.is_empty() && computed_hash != artifact.manifest.model_hash {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Model hash mismatch: expected {}, got {}",
|
||||
artifact.manifest.model_hash, computed_hash
|
||||
)));
|
||||
}
|
||||
|
||||
// 3. Verify signature (if present)
|
||||
if artifact.pubkey != [0u8; 32] {
|
||||
verify_signature(artifact)?;
|
||||
}
|
||||
|
||||
// 4. Verify weights size
|
||||
let expected_min =
|
||||
artifact.manifest.shape.embedding_params() / artifact.manifest.quant.weights_per_byte();
|
||||
if artifact.weights.len() < expected_min {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Weights too small: {} < {}",
|
||||
artifact.weights.len(),
|
||||
expected_min
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute the message that was signed
///
/// Produces a 32-byte SHA-256 digest over: manifest JSON, weights hash,
/// quant hash, then (when present) the hashes of bitstream and calibration.
/// The update order here defines the signature format — do not reorder.
fn compute_signing_message(artifact: &ModelArtifact) -> Vec<u8> {
    let mut hasher = Sha256::new();

    // Hash manifest
    // NOTE(review): serialization failure silently hashes an empty string
    // here — confirm that is acceptable for signing.
    let manifest_json = serde_json::to_string(&artifact.manifest).unwrap_or_default();
    hasher.update(manifest_json.as_bytes());

    // Hash weights (covers bitstream too, per ModelArtifact::model_hash)
    let weights_hash = artifact.model_hash();
    hasher.update(&weights_hash);

    // Hash quant params
    let quant_hash = artifact.quant_hash();
    hasher.update(&quant_hash);

    // Hash bitstream if present (inner digest, then folded into the outer)
    if let Some(ref bitstream) = artifact.bitstream {
        let mut h = Sha256::new();
        h.update(bitstream);
        hasher.update(&h.finalize());
    }

    // Hash calibration if present
    if let Some(ref calib) = artifact.calibration {
        let mut h = Sha256::new();
        h.update(calib);
        hasher.update(&h.finalize());
    }

    hasher.finalize().to_vec()
}
|
||||
|
||||
/// Sign an artifact with Ed25519 private key
///
/// Only compiled with the `sign` feature. Writes the signature and the
/// derived public key into the artifact. The signing message does not cover
/// the signature/pubkey fields themselves, so overwriting them afterwards
/// is consistent with `verify_signature`.
#[cfg(feature = "sign")]
pub fn sign_artifact(artifact: &mut ModelArtifact, secret_key: &[u8; 32]) -> Result<()> {
    use ed25519_dalek::{Signer, SigningKey};

    let signing_key = SigningKey::from_bytes(secret_key);
    let message = compute_signing_message(artifact);

    let signature = signing_key.sign(&message);

    artifact.signature = signature.to_bytes();
    artifact.pubkey = signing_key.verifying_key().to_bytes();

    Ok(())
}
|
||||
|
||||
/// Verify test vectors against model output
|
||||
pub fn verify_test_vectors(
|
||||
artifact: &ModelArtifact,
|
||||
infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
|
||||
) -> Result<()> {
|
||||
let max_err = artifact.manifest.tests.max_abs_err;
|
||||
|
||||
for (i, vector) in artifact.test_vectors.iter().enumerate() {
|
||||
let output = infer_fn(&vector.tokens)?;
|
||||
|
||||
// Compare outputs
|
||||
let actual_max_err = output
|
||||
.iter()
|
||||
.zip(&vector.expected)
|
||||
.map(|(&a, &b)| (a as i32 - b as i32).abs())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
if actual_max_err > max_err {
|
||||
return Err(Error::TestVectorError {
|
||||
expected: max_err,
|
||||
actual: actual_max_err,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate test vectors for an artifact
///
/// Replaces any existing vectors with `count` fresh ones: random token
/// sequences of the manifest's `seq_len`, with expected outputs produced by
/// `infer_fn`. Updates `manifest.tests.vectors` to match.
///
/// NOTE(review): `gen_range(0..vocab)` requires `vocab > 0` (an empty range
/// panics in rand) — confirm the shape guarantees a non-zero vocab.
pub fn generate_test_vectors(
    artifact: &mut ModelArtifact,
    infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
    count: usize,
) -> Result<()> {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let seq_len = artifact.manifest.shape.seq_len as usize;
    let vocab = artifact.manifest.shape.vocab as u16;

    // Discard any previously generated vectors before refilling.
    artifact.test_vectors.clear();

    for _ in 0..count {
        // Generate random input
        let tokens: Vec<u16> = (0..seq_len).map(|_| rng.gen_range(0..vocab)).collect();

        // Run inference
        let expected = infer_fn(&tokens)?;

        artifact.test_vectors.push(crate::artifact::TestVector {
            tokens,
            expected,
            max_abs_err: artifact.manifest.tests.max_abs_err,
        });
    }

    // Keep the manifest's vector count in sync with what was generated.
    artifact.manifest.tests.vectors = count as u32;

    Ok(())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};

    // Unsigned artifact (zero pubkey) with an empty model hash, so both the
    // hash comparison and signature check are skipped by verify_artifact.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test".into(),
            model_hash: String::new(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };

        ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![])
    }

    // An unsigned artifact with adequate weights should verify cleanly.
    #[test]
    fn test_verify_artifact() {
        let artifact = create_test_artifact();
        assert!(verify_artifact(&artifact).is_ok());
    }

    // The signing message is a single SHA-256 digest.
    #[test]
    fn test_compute_signing_message() {
        let artifact = create_test_artifact();
        let msg = compute_signing_message(&artifact);
        assert_eq!(msg.len(), 32); // SHA-256 output
    }
}
|
||||
566
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/fpga_daemon.rs
vendored
Normal file
566
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/fpga_daemon.rs
vendored
Normal file
@@ -0,0 +1,566 @@
|
||||
//! FPGA Daemon backend
|
||||
//!
|
||||
//! Communicates with a local daemon over Unix socket or TCP
|
||||
//! to send inference requests to an FPGA accelerator.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{Read, Write};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
commands, compute_topk, crc32, protocol, read_lock, validate_tokens, write_lock, BackendStats,
|
||||
RequestFrame, ResponseFrame, TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{
|
||||
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
|
||||
};
|
||||
|
||||
/// Connection type for daemon communication
///
/// Chosen once at backend construction time; see
/// `FpgaDaemonBackend::with_connection`. A fresh stream is opened per
/// request, so this only stores the endpoint, not a live connection.
#[derive(Debug, Clone)]
pub enum DaemonConnection {
    /// Unix domain socket path
    Unix(String),
    /// TCP address (host:port)
    Tcp(String),
}
|
||||
|
||||
impl DaemonConnection {
|
||||
/// Create a Unix socket connection
|
||||
pub fn unix(path: impl Into<String>) -> Self {
|
||||
Self::Unix(path.into())
|
||||
}
|
||||
|
||||
/// Create a TCP connection
|
||||
pub fn tcp(addr: impl Into<String>) -> Self {
|
||||
Self::Tcp(addr.into())
|
||||
}
|
||||
|
||||
/// Default socket path
|
||||
pub fn default_socket() -> Self {
|
||||
Self::Unix("/var/run/ruvector_fpga.sock".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// FPGA Daemon backend
///
/// Connectionless with respect to the wire: every request opens a fresh
/// stream (see `connect`). Only model metadata and statistics are
/// cached locally behind `RwLock`s.
pub struct FpgaDaemonBackend {
    /// Connection configuration
    connection: DaemonConnection,
    /// Loaded models (cached metadata)
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: DaemonConfig,
}
|
||||
|
||||
/// Cached model metadata
struct ModelMetadata {
    /// Full artifact, cloned at load time; `infer` uses it for token
    /// validation and witness hashes.
    artifact: ModelArtifact,
    /// When the model was cached locally.
    /// NOTE(review): never read in the code visible here — presumably
    /// intended for future eviction/diagnostics; confirm.
    loaded_at: Instant,
}
|
||||
|
||||
/// Configuration for daemon backend
///
/// Timeouts are applied best-effort: `connect` ignores failures from
/// `set_read_timeout`/`set_write_timeout`, and the connect timeout only
/// applies to the TCP path (Unix connects have no timeout parameter).
#[derive(Debug, Clone)]
pub struct DaemonConfig {
    /// Connection timeout in milliseconds
    pub connect_timeout_ms: u64,
    /// Read timeout in milliseconds
    pub read_timeout_ms: u64,
    /// Write timeout in milliseconds
    pub write_timeout_ms: u64,
    /// Number of retry attempts
    pub retries: usize,
    /// Retry backoff multiplier
    pub backoff_multiplier: f64,
    /// Return only top-K results
    pub topk_only: bool,
    /// Top-K count
    pub topk: u16,
}
|
||||
|
||||
impl Default for DaemonConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
connect_timeout_ms: 5000,
|
||||
read_timeout_ms: 10000,
|
||||
write_timeout_ms: 5000,
|
||||
retries: 3,
|
||||
backoff_multiplier: 2.0,
|
||||
topk_only: true,
|
||||
topk: 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FpgaDaemonBackend {
|
||||
/// Create a new daemon backend with Unix socket
|
||||
pub fn new(socket_path: impl AsRef<Path>) -> Self {
|
||||
Self::with_connection(
|
||||
DaemonConnection::unix(socket_path.as_ref().to_string_lossy()),
|
||||
DaemonConfig::default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create with custom connection and config
|
||||
pub fn with_connection(connection: DaemonConnection, config: DaemonConfig) -> Self {
|
||||
Self {
|
||||
connection,
|
||||
models: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(BackendStats::default()),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to the daemon
///
/// Opens a fresh stream per call (no pooling). Read/write timeouts are
/// applied best-effort: failures to set them are deliberately ignored
/// via `.ok()`.
fn connect(&self) -> Result<Box<dyn ReadWrite>> {
    let timeout = Duration::from_millis(self.config.connect_timeout_ms);

    match &self.connection {
        DaemonConnection::Unix(path) => {
            #[cfg(unix)]
            {
                use std::os::unix::net::UnixStream;
                // NOTE(review): `UnixStream::connect` takes no timeout,
                // so `connect_timeout_ms` only governs the TCP path.
                let stream = UnixStream::connect(path)
                    .map_err(|e| Error::daemon_connection(format!("Unix socket: {}", e)))?;
                stream
                    .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                    .ok();
                stream
                    .set_write_timeout(Some(Duration::from_millis(
                        self.config.write_timeout_ms,
                    )))
                    .ok();
                Ok(Box::new(stream))
            }
            #[cfg(not(unix))]
            {
                // Silence unused-variable warnings on non-Unix targets.
                let _ = (path, timeout);
                Err(Error::FeatureNotAvailable(
                    "Unix sockets not available on this platform".into(),
                ))
            }
        }
        DaemonConnection::Tcp(addr) => {
            use std::net::TcpStream;
            let stream = TcpStream::connect_timeout(
                &addr
                    .parse()
                    .map_err(|e| Error::daemon_connection(format!("Invalid address: {}", e)))?,
                timeout,
            )
            .map_err(|e| Error::daemon_connection(format!("TCP: {}", e)))?;
            stream
                .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                .ok();
            stream
                .set_write_timeout(Some(Duration::from_millis(self.config.write_timeout_ms)))
                .ok();
            Ok(Box::new(stream))
        }
    }
}
|
||||
|
||||
/// Send inference request to daemon
///
/// Request wire format (little-endian throughout):
/// `header | tokens (u16 each) | attn_mask | gate hint (u16 + 2 bytes) | crc32`.
/// Response: fixed 14-byte header, then logits, then an optional crc32.
fn send_request(
    &self,
    stream: &mut dyn ReadWrite,
    req: &InferenceRequest,
) -> Result<(Vec<i16>, ResponseFrame)> {
    let shape = &req.shape;

    // Build request flags
    let mut flags = 0u16;
    if self.config.topk_only {
        flags |= protocol::flags::TOPK_ONLY;
    }

    // Create request frame
    let frame = RequestFrame::new(
        shape.seq_len,
        shape.d_model,
        shape.vocab,
        &req.model,
        flags,
        self.config.topk,
    );

    // Build payload (pre-sized: header + tokens + mask + hint + crc)
    let mut payload = Vec::with_capacity(
        protocol::HEADER_SIZE + req.tokens.len() * 2 + req.attn_mask.len() + 8,
    );

    // Write header
    payload.extend_from_slice(&frame.to_bytes());

    // Write tokens (u16 little-endian)
    for &token in req.tokens {
        payload.extend_from_slice(&token.to_le_bytes());
    }

    // Write mask
    payload.extend_from_slice(req.attn_mask);

    // Write gate hint (packed)
    payload.extend_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
    payload.push(req.gate_hint.boundary_crossed as u8);
    payload.push(req.gate_hint.max_compute_class as u8);

    // Calculate and append checksum over everything preceding it
    let checksum = crc32(&payload);
    payload.extend_from_slice(&checksum.to_le_bytes());

    // Send payload
    stream
        .write_all(&payload)
        .map_err(|e| Error::backend(format!("Write failed: {}", e)))?;
    stream
        .flush()
        .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;

    // Read response header (fixed 14 bytes)
    let mut response_header = [0u8; 14];
    stream
        .read_exact(&mut response_header)
        .map_err(|e| Error::backend(format!("Read header failed: {}", e)))?;

    let response = ResponseFrame::from_bytes(&response_header);

    // Copy packed fields to avoid alignment issues
    let status = { response.status };

    // Check status
    match status {
        protocol::status::OK => {}
        protocol::status::MODEL_NOT_FOUND => {
            return Err(Error::ModelNotFound(req.model));
        }
        protocol::status::SHAPE_MISMATCH => {
            return Err(Error::ShapeMismatch {
                expected: req.shape,
                actual: req.shape, // Daemon should provide actual shape
            });
        }
        protocol::status::GATE_BLOCKED => {
            return Err(Error::GateBlocked {
                reason: crate::types::SkipReason::PolicyDenied,
            });
        }
        _ => {
            return Err(Error::backend(format!("Daemon error: status {}", status)));
        }
    }

    // Read logits. In top-K mode the daemon interleaves (token_id, logit)
    // i16 pairs, so the element count is topk * 2.
    let logits_count = if self.config.topk_only {
        self.config.topk as usize * 2 // (token_id, logit) pairs
    } else {
        shape.vocab as usize
    };

    let mut logits_bytes = vec![0u8; logits_count * 2];
    stream
        .read_exact(&mut logits_bytes)
        .map_err(|e| Error::backend(format!("Read logits failed: {}", e)))?;

    // Parse logits
    let logits: Vec<i16> = logits_bytes
        .chunks(2)
        .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
        .collect();

    // Read and verify checksum
    // NOTE(review): the trailing checksum is read but never actually
    // verified — confirm whether integrity is guaranteed elsewhere.
    let mut checksum_bytes = [0u8; 4];
    stream.read_exact(&mut checksum_bytes).ok(); // Checksum is optional

    Ok((logits, response))
}
|
||||
|
||||
/// Send load model command to daemon
///
/// Packet: `[command: 1] [model_id: 32] [artifact_len: 4] [artifact_data: N] [checksum: 4]`.
/// Response: `[status: 1]`, followed on failure by
/// `[message_len: 2] [message: N]` (message capped at 256 bytes here).
fn send_load_command(
    &self,
    stream: &mut dyn ReadWrite,
    artifact: &ModelArtifact,
) -> Result<()> {
    // Pack artifact
    let artifact_bytes = crate::artifact::pack::pack_artifact(artifact)?;

    // Build command packet:
    // [command: 1] [model_id: 32] [artifact_len: 4] [artifact_data: N] [checksum: 4]
    let mut payload = Vec::with_capacity(1 + 32 + 4 + artifact_bytes.len() + 4);

    // Command byte
    payload.push(commands::LOAD_MODEL);

    // Model ID (32 bytes)
    payload.extend_from_slice(artifact.model_id().as_bytes());

    // Artifact length (u32 LE)
    payload.extend_from_slice(&(artifact_bytes.len() as u32).to_le_bytes());

    // Artifact data
    payload.extend_from_slice(&artifact_bytes);

    // Checksum over everything preceding it
    let checksum = crc32(&payload);
    payload.extend_from_slice(&checksum.to_le_bytes());

    // Send
    stream
        .write_all(&payload)
        .map_err(|e| Error::backend(format!("Write load command failed: {}", e)))?;
    stream
        .flush()
        .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;

    // Read response: [status: 1] [message_len: 2] [message: N]
    let mut status = [0u8; 1];
    stream
        .read_exact(&mut status)
        .map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;

    if status[0] != 0 {
        // Read error message (best-effort: read failures are ignored and
        // produce an empty/partial message)
        let mut msg_len = [0u8; 2];
        stream.read_exact(&mut msg_len).ok();
        let len = u16::from_le_bytes(msg_len) as usize;
        let mut msg = vec![0u8; len.min(256)];
        stream.read_exact(&mut msg).ok();
        let error_msg = String::from_utf8_lossy(&msg);
        return Err(Error::backend(format!(
            "Daemon rejected load: {}",
            error_msg
        )));
    }

    Ok(())
}
|
||||
|
||||
/// Send unload model command to daemon
///
/// Packet: `[command: 1] [model_id: 32] [checksum: 4]`.
/// Response: a single status byte; non-zero means rejection.
fn send_unload_command(&self, stream: &mut dyn ReadWrite, model_id: ModelId) -> Result<()> {
    // Build command packet: [command: 1] [model_id: 32] [checksum: 4]
    let mut payload = Vec::with_capacity(1 + 32 + 4);
    payload.push(commands::UNLOAD_MODEL);
    payload.extend_from_slice(model_id.as_bytes());
    let checksum = crc32(&payload);
    payload.extend_from_slice(&checksum.to_le_bytes());

    // Send
    stream
        .write_all(&payload)
        .map_err(|e| Error::backend(format!("Write unload command failed: {}", e)))?;
    stream
        .flush()
        .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;

    // Read response status
    let mut status = [0u8; 1];
    stream
        .read_exact(&mut status)
        .map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;

    if status[0] != 0 {
        return Err(Error::backend("Daemon rejected unload"));
    }

    Ok(())
}
|
||||
|
||||
/// Execute with retries
|
||||
fn with_retries<T, F>(&self, mut f: F) -> Result<T>
|
||||
where
|
||||
F: FnMut() -> Result<T>,
|
||||
{
|
||||
let mut last_error = None;
|
||||
let mut delay = Duration::from_millis(100);
|
||||
|
||||
for attempt in 0..=self.config.retries {
|
||||
match f() {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) if e.is_recoverable() => {
|
||||
last_error = Some(e);
|
||||
if attempt < self.config.retries {
|
||||
std::thread::sleep(delay);
|
||||
delay = Duration::from_secs_f64(
|
||||
delay.as_secs_f64() * self.config.backoff_multiplier,
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| Error::backend("Unknown error")))
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBackend for FpgaDaemonBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
let model_id = artifact.model_id();
|
||||
|
||||
// Send load command to daemon to preload the model
|
||||
self.with_retries(|| {
|
||||
let mut stream = self.connect()?;
|
||||
self.send_load_command(stream.as_mut(), artifact)
|
||||
})?;
|
||||
|
||||
// Cache metadata locally
|
||||
{
|
||||
let mut models = write_lock(&self.models, |m| {
|
||||
m.insert(
|
||||
model_id,
|
||||
ModelMetadata {
|
||||
artifact: artifact.clone(),
|
||||
loaded_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
write_lock(&self.stats, |s| {
|
||||
s.models_loaded += 1;
|
||||
})?;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
/// Run one inference round-trip against the daemon.
///
/// Validates the request and tokens locally, sends the request with
/// retries, then assembles the result and a witness log, and updates
/// the running statistics.
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
    let start = Instant::now();

    // Validate request
    req.validate()?;

    // Check model is loaded locally and validate tokens
    let model_metadata = read_lock(&self.models, |models| {
        models.get(&req.model).map(|m| m.artifact.clone())
    })?
    .ok_or_else(|| Error::ModelNotFound(req.model))?;

    // Validate tokens against vocabulary
    validate_tokens(req.tokens, model_metadata.manifest.shape.vocab)?;

    // Execute with retries
    let (logits, response) = self.with_retries(|| {
        let mut stream = self.connect()?;
        self.send_request(stream.as_mut(), &req)
    })?;

    let latency_ns = start.elapsed().as_nanos() as u32;

    // Parse response
    let gate_decision = response.to_gate_decision();

    // Build top-K if we got pairs
    let (logits_q, topk) = if self.config.topk_only {
        // logits contains (token_id, logit) pairs.
        // NOTE(review): token ids are recovered via `chunk[0] as u16`,
        // reinterpreting an i16 through two's complement — confirm this
        // matches the daemon's pair encoding.
        let pairs: Vec<(u16, i16)> = logits
            .chunks(2)
            .filter_map(|chunk| {
                if chunk.len() == 2 {
                    Some((chunk[0] as u16, chunk[1]))
                } else {
                    None
                }
            })
            .collect();
        (vec![], Some(pairs))
    } else {
        // Full logits - use common compute_topk
        let topk = compute_topk(&logits, 16);
        (logits, Some(topk))
    };

    // Copy packed fields to avoid alignment issues
    let resp_cycles = { response.cycles };
    let resp_latency_ns = { response.latency_ns };

    // Create witness
    // NOTE(review): `latency_ns.min(resp_latency_ns.max(latency_ns))`
    // always evaluates to `latency_ns` (min(a, max(b, a)) == a), so the
    // daemon-reported latency is effectively ignored here — possibly
    // `latency_ns.max(resp_latency_ns)` was intended; confirm.
    let witness = WitnessLog::new(
        model_metadata.model_hash(),
        model_metadata.quant_hash(),
        BackendKind::FpgaDaemon,
        resp_cycles,
        latency_ns.min(resp_latency_ns.max(latency_ns)),
        gate_decision,
    );

    // Update stats (with poison handling)
    write_lock(&self.stats, |stats| {
        stats.total_inferences += 1;
        stats.total_cycles += resp_cycles as u64;
        let n = stats.total_inferences;
        // Incremental running mean; n >= 1 here so n - 1 cannot underflow.
        stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
        match gate_decision {
            GateDecision::EarlyExit { .. } => stats.early_exits += 1,
            GateDecision::Skipped { .. } => stats.skipped += 1,
            _ => {}
        }
    })?;

    Ok(InferenceResult::new(logits_q, topk, witness))
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
// Send unload command to daemon
|
||||
self.with_retries(|| {
|
||||
let mut stream = self.connect()?;
|
||||
self.send_unload_command(stream.as_mut(), model)
|
||||
})?;
|
||||
|
||||
// Remove from local cache
|
||||
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
|
||||
|
||||
if removed {
|
||||
write_lock(&self.stats, |s| {
|
||||
s.models_loaded = s.models_loaded.saturating_sub(1);
|
||||
})?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Identifies this backend as the daemon-based FPGA transport.
fn kind(&self) -> BackendKind {
    BackendKind::FpgaDaemon
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
self.stats.read().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait combining Read and Write for stream abstraction
///
/// Lets `connect` return either a `UnixStream` or a `TcpStream` behind
/// a single `Box<dyn ReadWrite>`; the blanket impl covers any `Send`
/// read/write stream.
trait ReadWrite: Read + Write + Send {}
impl<T: Read + Write + Send> ReadWrite for T {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors should map onto the matching enum variants.
    #[test]
    fn test_daemon_connection_types() {
        let unix = DaemonConnection::unix("/tmp/test.sock");
        assert!(matches!(unix, DaemonConnection::Unix(_)));

        let tcp = DaemonConnection::tcp("127.0.0.1:8080");
        assert!(matches!(tcp, DaemonConnection::Tcp(_)));
    }

    /// Spot-check a few of the documented default values.
    #[test]
    fn test_config_defaults() {
        let config = DaemonConfig::default();
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.retries, 3);
        assert!(config.topk_only);
    }
}
|
||||
645
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/fpga_pcie.rs
vendored
Normal file
645
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/fpga_pcie.rs
vendored
Normal file
@@ -0,0 +1,645 @@
|
||||
//! FPGA PCIe backend
|
||||
//!
|
||||
//! Direct memory-mapped access to FPGA accelerator via PCIe.
|
||||
//! Uses DMA ring buffers for zero-copy, lock-free operation.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
use memmap2::{MmapMut, MmapOptions};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
compute_topk, protocol, read_lock, validate_tokens, write_lock, BackendStats,
|
||||
TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{
|
||||
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
|
||||
};
|
||||
|
||||
/// PCIe device configuration
///
/// NOTE(review): `batch_mode`/`batch_size` are configuration only — no
/// batching logic appears in the code visible here; confirm they are
/// consumed elsewhere.
#[derive(Debug, Clone)]
pub struct PcieConfig {
    /// Device path (e.g., /dev/ruvector0)
    pub device_path: String,
    /// BAR0 offset for control registers
    pub bar0_offset: usize,
    /// BAR1 offset for DMA buffers
    pub bar1_offset: usize,
    /// Number of request slots in ring buffer
    pub ring_slots: usize,
    /// Size of each request slot in bytes
    pub slot_size: usize,
    /// DMA timeout in milliseconds
    pub dma_timeout_ms: u64,
    /// Enable batch mode (multiple requests per DMA burst)
    pub batch_mode: bool,
    /// Maximum requests per batch
    pub batch_size: usize,
}
|
||||
|
||||
impl Default for PcieConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_path: "/dev/ruvector0".into(),
|
||||
bar0_offset: 0,
|
||||
bar1_offset: 0x10000,
|
||||
ring_slots: 16,
|
||||
slot_size: 64 * 1024, // 64KB per slot
|
||||
dma_timeout_ms: 100,
|
||||
batch_mode: false,
|
||||
batch_size: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ring buffer slot state
///
/// Lifecycle: `Free` -> `Pending` (claimed by the producer) ->
/// `Complete` (marked when the response is ready) -> `Free` (released
/// by the consumer). `Error` is declared for the protocol but is never
/// stored in the code visible here — confirm the FPGA side sets it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum SlotState {
    Free = 0,
    Pending = 1,
    Complete = 2,
    Error = 3,
}
|
||||
|
||||
/// DMA ring buffer for lock-free request/response handling
///
/// Slot lifecycle is tracked in `slot_states` (see `SlotState`); the
/// producer index wraps via modulo `num_slots`. The mmap fields only
/// exist when built with the `pcie` feature.
struct DmaRingBuffer {
    /// Memory-mapped request buffer
    #[cfg(feature = "pcie")]
    request_mmap: MmapMut,
    /// Memory-mapped response buffer
    #[cfg(feature = "pcie")]
    response_mmap: MmapMut,
    /// Slot states
    slot_states: Vec<AtomicU32>,
    /// Producer index (next slot to write)
    producer_idx: AtomicU32,
    /// Consumer index (next slot to read)
    /// NOTE(review): incremented in `release_slot` but never read in the
    /// visible code — confirm intended use.
    consumer_idx: AtomicU32,
    /// Number of slots
    num_slots: usize,
    /// Size per slot
    slot_size: usize,
}
|
||||
|
||||
impl DmaRingBuffer {
    /// Create a new DMA ring buffer (mock for non-PCIe builds)
    #[cfg(not(feature = "pcie"))]
    fn new(_config: &PcieConfig) -> Result<Self> {
        Err(Error::FeatureNotAvailable(
            "PCIe support not compiled".into(),
        ))
    }

    /// Create a new DMA ring buffer
    ///
    /// Maps two contiguous regions of BAR1: requests at `bar1_offset`,
    /// responses immediately after (`bar1_offset + ring_slots * slot_size`).
    #[cfg(feature = "pcie")]
    fn new(config: &PcieConfig) -> Result<Self> {
        use std::fs::OpenOptions;

        // Open device
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&config.device_path)
            .map_err(|e| Error::PcieError(format!("Failed to open device: {}", e)))?;

        let total_size = config.ring_slots * config.slot_size;

        // Map request buffer (BAR1)
        // SAFETY: mapping device memory; offsets/lengths come from config
        // and are assumed valid for this device — TODO confirm against the
        // driver's BAR layout.
        let request_mmap = unsafe {
            MmapOptions::new()
                .offset(config.bar1_offset as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map request buffer: {}", e)))?
        };

        // Map response buffer (BAR1 + offset)
        let response_mmap = unsafe {
            MmapOptions::new()
                .offset((config.bar1_offset + total_size) as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map response buffer: {}", e)))?
        };

        // Initialize slot states (all free)
        let slot_states: Vec<AtomicU32> = (0..config.ring_slots)
            .map(|_| AtomicU32::new(SlotState::Free as u32))
            .collect();

        Ok(Self {
            request_mmap,
            response_mmap,
            slot_states,
            producer_idx: AtomicU32::new(0),
            consumer_idx: AtomicU32::new(0),
            num_slots: config.ring_slots,
            slot_size: config.slot_size,
        })
    }

    /// Acquire a slot for writing
    ///
    /// Returns `None` if the slot at the current producer index is busy;
    /// callers are expected to retry/spin. Only the single slot at
    /// `producer_idx % num_slots` is ever tried, so this behaves as an
    /// in-order ring rather than a free-list.
    fn acquire_slot(&self) -> Option<usize> {
        let producer = self.producer_idx.load(Ordering::Acquire);
        let slot = producer as usize % self.num_slots;

        // Check if slot is free
        if self.slot_states[slot].load(Ordering::Acquire) == SlotState::Free as u32 {
            // Try to claim it (CAS arbitrates between racing producers)
            if self.slot_states[slot]
                .compare_exchange(
                    SlotState::Free as u32,
                    SlotState::Pending as u32,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                self.producer_idx
                    .store(producer.wrapping_add(1), Ordering::Release);
                return Some(slot);
            }
        }
        None
    }

    /// Release a slot after reading response
    fn release_slot(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Free as u32, Ordering::Release);
        self.consumer_idx.fetch_add(1, Ordering::AcqRel);
    }

    /// Check if a slot is complete
    fn is_complete(&self, slot: usize) -> bool {
        self.slot_states[slot].load(Ordering::Acquire) == SlotState::Complete as u32
    }

    /// Mark a slot as complete (called by FPGA via doorbell/interrupt)
    fn mark_complete(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Complete as u32, Ordering::Release);
    }

    /// Get request buffer for a slot (mutable view into the request mmap)
    #[cfg(feature = "pcie")]
    fn request_buffer(&mut self, slot: usize) -> &mut [u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &mut self.request_mmap[start..end]
    }

    /// Get response buffer for a slot (read-only view into the response mmap)
    #[cfg(feature = "pcie")]
    fn response_buffer(&self, slot: usize) -> &[u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &self.response_mmap[start..end]
    }
}
|
||||
|
||||
/// FPGA PCIe backend
///
/// Talks to the accelerator through memory-mapped DMA ring buffers.
/// Without the `pcie` feature, `ring` stays `None` and operations fail
/// with `FeatureNotAvailable`.
pub struct FpgaPcieBackend {
    /// Configuration
    config: PcieConfig,
    /// DMA ring buffer (None when built without the `pcie` feature)
    ring: Option<DmaRingBuffer>,
    /// Loaded models
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Total cycles counter
    /// NOTE(review): not updated anywhere in the visible code — confirm.
    total_cycles: AtomicU64,
    /// FPGA memory allocator state (next free offset)
    fpga_mem_offset: AtomicU64,
    /// FPGA memory total size (2GB default)
    fpga_mem_size: u64,
}
|
||||
|
||||
/// Cached model metadata
struct ModelMetadata {
    /// Parsed artifact, kept for validation and witness hashes
    artifact: ModelArtifact,
    fpga_slot: u32, // Slot in FPGA memory where model is loaded
    weights_offset: u64, // Offset in FPGA DDR where weights are stored
    weights_size: usize, // Size of weights in bytes
}
|
||||
|
||||
/// FPGA DDR base offset for model weights
/// (the first 256 MiB of DDR are left unused by this allocator —
/// presumably reserved for firmware/runtime data; TODO confirm against
/// the hardware memory map)
const FPGA_DDR_MODEL_BASE: u64 = 0x1000_0000; // 256MB offset
|
||||
|
||||
impl FpgaPcieBackend {
|
||||
/// Create a new PCIe backend
///
/// With the `pcie` feature the DMA ring is mapped eagerly and creation
/// fails if the device cannot be opened/mapped; without it the ring is
/// `None` and operations report `FeatureNotAvailable` later.
pub fn new(config: PcieConfig) -> Result<Self> {
    #[cfg(feature = "pcie")]
    let ring = Some(DmaRingBuffer::new(&config)?);

    #[cfg(not(feature = "pcie"))]
    let ring = None;

    Ok(Self {
        config,
        ring,
        models: RwLock::new(HashMap::new()),
        stats: RwLock::new(BackendStats::default()),
        total_cycles: AtomicU64::new(0),
        fpga_mem_offset: AtomicU64::new(FPGA_DDR_MODEL_BASE),
        fpga_mem_size: 2 * 1024 * 1024 * 1024, // 2GB
    })
}
|
||||
|
||||
/// Create with default configuration
///
/// Equivalent to `Self::new(PcieConfig::default())`: /dev/ruvector0
/// with a 16-slot ring of 64 KiB slots.
pub fn default_device() -> Result<Self> {
    Self::new(PcieConfig::default())
}
|
||||
|
||||
/// Write inference request to DMA buffer
///
/// Same field layout as the daemon wire format (header | tokens | mask |
/// gate hint) but with no trailing checksum and with flags/topk fixed
/// at 0/16.
///
/// NOTE(review): assumes `slot_size` is large enough for
/// header + tokens + mask + hint — the slice indexing panics otherwise;
/// confirm request sizing is validated upstream.
#[cfg(feature = "pcie")]
fn write_request(
    &self,
    ring: &mut DmaRingBuffer,
    slot: usize,
    req: &InferenceRequest,
) -> Result<()> {
    use crate::backend::{protocol, RequestFrame};

    let buffer = ring.request_buffer(slot);
    let shape = &req.shape;

    // Write header
    let frame = RequestFrame::new(shape.seq_len, shape.d_model, shape.vocab, &req.model, 0, 16);
    let header = frame.to_bytes();
    buffer[..protocol::HEADER_SIZE].copy_from_slice(&header);

    let mut offset = protocol::HEADER_SIZE;

    // Write tokens (u16 little-endian)
    for &token in req.tokens {
        buffer[offset..offset + 2].copy_from_slice(&token.to_le_bytes());
        offset += 2;
    }

    // Write mask
    buffer[offset..offset + req.attn_mask.len()].copy_from_slice(req.attn_mask);
    offset += req.attn_mask.len();

    // Write gate hint (u16 score + two flag bytes)
    buffer[offset..offset + 2].copy_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
    offset += 2;
    buffer[offset] = req.gate_hint.boundary_crossed as u8;
    offset += 1;
    buffer[offset] = req.gate_hint.max_compute_class as u8;

    Ok(())
}
|
||||
|
||||
/// Read inference response from DMA buffer
///
/// Layout: 14-byte response header followed by `vocab` i16 logits (LE).
/// Returns (logits, cycles, latency_ns, gate decision).
#[cfg(feature = "pcie")]
fn read_response(
    &self,
    ring: &DmaRingBuffer,
    slot: usize,
    shape: &crate::types::FixedShape,
) -> Result<(Vec<i16>, u32, u32, GateDecision)> {
    use crate::backend::ResponseFrame;

    let buffer = ring.response_buffer(slot);

    // Read response header (first 14 bytes of the slot; the unwrap
    // cannot fail — the slice length is fixed at 14)
    let response = ResponseFrame::from_bytes(&buffer[..14].try_into().unwrap());

    // Check status
    // NOTE(review): `response.status` is accessed directly here, while
    // the daemon backend copies packed fields first "to avoid alignment
    // issues" — confirm ResponseFrame's repr makes this safe.
    if response.status != 0 {
        return Err(Error::backend(format!(
            "FPGA error: status {}",
            response.status
        )));
    }

    // Read logits (always the full vocabulary on this path)
    let vocab = shape.vocab as usize;
    let mut logits = Vec::with_capacity(vocab);
    let mut offset = 14;

    for _ in 0..vocab {
        let value = i16::from_le_bytes([buffer[offset], buffer[offset + 1]]);
        logits.push(value);
        offset += 2;
    }

    Ok((
        logits,
        response.cycles,
        response.latency_ns,
        response.to_gate_decision(),
    ))
}
|
||||
|
||||
/// Ring doorbell to notify FPGA of pending request
///
/// Stub: real hardware would write the slot index to a BAR0 control
/// register here. Currently a no-op.
#[cfg(feature = "pcie")]
fn ring_doorbell(&self, _slot: usize) {
    // In a real implementation, this would write to a control register
    // to notify the FPGA that a new request is available
}
|
||||
|
||||
/// Wait for response with polling
///
/// Busy-spins (with `spin_loop` hints) until the slot is marked
/// complete or `timeout_ms` elapses — this burns a core for the
/// duration of the wait.
fn wait_for_response(&self, ring: &DmaRingBuffer, slot: usize, timeout_ms: u64) -> Result<()> {
    let start = Instant::now();
    let timeout = std::time::Duration::from_millis(timeout_ms);

    while !ring.is_complete(slot) {
        if start.elapsed() > timeout {
            return Err(Error::Timeout { ms: timeout_ms });
        }
        std::hint::spin_loop();
    }

    Ok(())
}
|
||||
|
||||
/// Allocate FPGA DDR memory for model weights
|
||||
fn allocate_fpga_memory(&self, size: usize) -> Result<u64> {
|
||||
// Align to 4KB boundary for DMA efficiency
|
||||
let aligned_size = (size + 0xFFF) & !0xFFF;
|
||||
|
||||
// Atomic allocation (simple bump allocator)
|
||||
let offset = self
|
||||
.fpga_mem_offset
|
||||
.fetch_add(aligned_size as u64, Ordering::SeqCst);
|
||||
|
||||
// Check for overflow
|
||||
if offset + aligned_size as u64 > self.fpga_mem_size {
|
||||
// Roll back allocation
|
||||
self.fpga_mem_offset
|
||||
.fetch_sub(aligned_size as u64, Ordering::SeqCst);
|
||||
return Err(Error::ResourceExhausted("FPGA DDR memory full".into()));
|
||||
}
|
||||
|
||||
Ok(offset)
|
||||
}
|
||||
|
||||
/// Upload model weights to FPGA DDR via DMA
///
/// Transfers in 64 KiB chunks through ring slots. The actual copy is
/// currently simulated — each slot is immediately marked complete, and
/// neither the chunk data nor `fpga_offset` is used yet; only the slot
/// acquire/complete/release choreography is real.
#[cfg(feature = "pcie")]
fn upload_weights_dma(&self, weights: &[u8], fpga_offset: u64) -> Result<()> {
    // DMA transfer configuration
    const DMA_CHUNK_SIZE: usize = 64 * 1024; // 64KB per transfer

    let ring = self
        .ring
        .as_ref()
        .ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;

    // Transfer weights in chunks
    let mut transferred = 0usize;
    while transferred < weights.len() {
        let chunk_size = DMA_CHUNK_SIZE.min(weights.len() - transferred);

        // Acquire a DMA slot (busy-wait until one frees up)
        let slot = loop {
            if let Some(s) = ring.acquire_slot() {
                break s;
            }
            std::hint::spin_loop();
        };

        // Write DMA command to slot (simplified protocol)
        // In real hardware:
        // - Write target FPGA DDR address
        // - Write source offset in slot
        // - Write transfer length
        // - Ring doorbell

        // For now, we simulate the DMA by marking complete
        ring.mark_complete(slot);

        // Wait for completion
        self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;

        // Release slot
        ring.release_slot(slot);

        transferred += chunk_size;
    }

    Ok(())
}
|
||||
|
||||
/// Free FPGA DDR memory (simplified - real impl would use proper allocator)
///
/// Intentionally a no-op: `allocate_fpga_memory` is a bump allocator with
/// no per-region free list, so individual regions cannot be returned.
fn free_fpga_memory(&self, _offset: u64, _size: usize) {
    // In a production system, this would:
    // 1. Mark the memory region as free in an allocator
    // 2. Potentially compact memory if fragmentation is high
    // 3. Update hardware memory management unit
    //
    // For this implementation, we use a bump allocator without free.
    // Memory is reclaimed when all models are unloaded.
}
|
||||
|
||||
/// Check if all models are unloaded and reset memory allocator
|
||||
fn maybe_reset_allocator(&self) {
|
||||
let models_empty = read_lock(&self.models, |m| m.is_empty()).unwrap_or(false);
|
||||
if models_empty {
|
||||
self.fpga_mem_offset
|
||||
.store(FPGA_DDR_MODEL_BASE, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBackend for FpgaPcieBackend {
    /// Load a model: validate it, reserve FPGA DDR, DMA the weights across,
    /// then record the metadata needed for later inference and unload.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = artifact;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }

        #[cfg(feature = "pcie")]
        {
            // Validate artifact
            artifact.validate()?;

            let model_id = artifact.model_id();
            let weights_size = artifact.weights.len();

            // Allocate FPGA DDR memory for weights
            let weights_offset = self.allocate_fpga_memory(weights_size)?;

            // Upload model weights to FPGA DDR via DMA
            if let Err(e) = self.upload_weights_dma(&artifact.weights, weights_offset) {
                // Roll back allocation on failure
                self.free_fpga_memory(weights_offset, weights_size);
                return Err(e);
            }

            // Get slot number for this model
            // NOTE(review): slot = current map size is racy (another load may
            // interleave between this read and the insert below), and slot
            // numbers repeat after an unload — confirm slots are advisory.
            let fpga_slot = read_lock(&self.models, |m| m.len() as u32)?;

            // Store metadata
            write_lock(&self.models, |models| {
                models.insert(
                    model_id,
                    ModelMetadata {
                        artifact: artifact.clone(),
                        fpga_slot,
                        weights_offset,
                        weights_size,
                    },
                );
            })?;

            write_lock(&self.stats, |stats| {
                stats.models_loaded += 1;
            })?;

            Ok(model_id)
        }
    }

    /// Run one inference round-trip over the PCIe ring buffer and assemble
    /// the result with its witness log.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = req;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }

        #[cfg(feature = "pcie")]
        {
            let start = Instant::now();

            // Validate request
            req.validate()?;

            // Get model metadata (clone the artifact out so the lock is
            // not held across the DMA round-trip)
            let model_artifact = read_lock(&self.models, |models| {
                models.get(&req.model).map(|m| m.artifact.clone())
            })?
            .ok_or_else(|| Error::ModelNotFound(req.model))?;

            // Validate tokens against vocabulary
            validate_tokens(req.tokens, model_artifact.manifest.shape.vocab)?;

            // Get ring buffer
            let ring = self
                .ring
                .as_ref()
                .ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;

            // Acquire slot — fails fast instead of spinning like the
            // weight-upload path does
            let slot = ring
                .acquire_slot()
                .ok_or_else(|| Error::ResourceExhausted("No DMA slots available".into()))?;

            // Write request (need mutable access - simplified for now)
            // In production, this would use proper interior mutability
            // self.write_request(ring, slot, &req)?;

            // Ring doorbell
            // self.ring_doorbell(slot);

            // Wait for response
            // NOTE(review): on timeout the `?` returns early WITHOUT the
            // release_slot below, leaking the slot — confirm intended.
            self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;

            // Read response
            let (logits, cycles, fpga_latency_ns, gate_decision) =
                self.read_response(ring, slot, &req.shape)?;

            // Release slot
            ring.release_slot(slot);

            let latency_ns = start.elapsed().as_nanos() as u32;

            // Compute top-K using common utility
            let topk = compute_topk(&logits, 16);

            // Create witness (FPGA-reported latency is clamped to the
            // host-observed wall time)
            let witness = WitnessLog::new(
                model_artifact.model_hash(),
                model_artifact.quant_hash(),
                BackendKind::FpgaPcie,
                cycles,
                fpga_latency_ns.min(latency_ns),
                gate_decision,
            );

            // Update stats
            self.total_cycles
                .fetch_add(cycles as u64, Ordering::Relaxed);
            write_lock(&self.stats, |stats| {
                stats.total_inferences += 1;
                stats.total_cycles = self.total_cycles.load(Ordering::Relaxed);
                // Incremental running mean: avg' = (avg*(n-1) + sample) / n
                let n = stats.total_inferences;
                stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
                match gate_decision {
                    GateDecision::EarlyExit { .. } => stats.early_exits += 1,
                    GateDecision::Skipped { .. } => stats.skipped += 1,
                    _ => {}
                }
            })?;

            Ok(InferenceResult::new(logits, Some(topk), witness))
        }
    }

    /// Unload a model, dropping its metadata and (bump-allocated) DDR region.
    fn unload(&self, model: ModelId) -> Result<()> {
        // Remove from cache and get memory info for deallocation
        let removed = write_lock(&self.models, |models| {
            models
                .remove(&model)
                .map(|m| (m.weights_offset, m.weights_size))
        })?;

        if let Some((offset, size)) = removed {
            // Free FPGA DDR memory (no-op under the bump allocator)
            self.free_fpga_memory(offset, size);

            // Check if we can reset the allocator
            self.maybe_reset_allocator();

            write_lock(&self.stats, |stats| {
                stats.models_loaded = stats.models_loaded.saturating_sub(1);
            })?;

            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }

    /// True if the model is currently resident; a poisoned lock reads as false.
    fn is_loaded(&self, model: ModelId) -> bool {
        read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
    }

    fn kind(&self) -> BackendKind {
        BackendKind::FpgaPcie
    }

    /// Snapshot of counters; defaults if the stats lock is poisoned.
    fn stats(&self) -> BackendStats {
        read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must match the documented ring geometry: 16 slots x 64 KiB.
    #[test]
    fn test_pcie_config_default() {
        let config = PcieConfig::default();
        assert_eq!(config.ring_slots, 16);
        assert_eq!(config.slot_size, 64 * 1024);
    }

    /// Slot-state discriminants are part of the wire protocol and must not drift.
    #[test]
    fn test_slot_state_values() {
        assert_eq!(SlotState::Free as u8, 0);
        assert_eq!(SlotState::Pending as u8, 1);
        assert_eq!(SlotState::Complete as u8, 2);
    }
}
|
||||
428
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/mod.rs
vendored
Normal file
428
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/mod.rs
vendored
Normal file
@@ -0,0 +1,428 @@
|
||||
//! Backend implementations for FPGA Transformer
|
||||
//!
|
||||
//! All backends implement the `TransformerBackend` trait for uniform API.
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::error::Result;
|
||||
use crate::types::{InferenceRequest, InferenceResult, ModelId};
|
||||
|
||||
#[cfg(feature = "native_sim")]
|
||||
pub mod native_sim;
|
||||
|
||||
#[cfg(feature = "daemon")]
|
||||
pub mod fpga_daemon;
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
pub mod fpga_pcie;
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_sim;
|
||||
|
||||
/// Trait for transformer inference backends
///
/// All backends must be thread-safe (`Send + Sync`) and implement the same
/// API, so callers can swap simulator, daemon, and PCIe implementations
/// behind a single interface.
pub trait TransformerBackend: Send + Sync {
    /// Load a model artifact and return its ID
    ///
    /// The artifact is validated, test vectors are run, and
    /// the model is prepared for inference.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId>;

    /// Run inference on the given request
    ///
    /// The request must specify a model that has been loaded.
    /// Returns the inference result with witness log.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult>;

    /// Unload a model to free resources
    fn unload(&self, model: ModelId) -> Result<()>;

    /// Check if a model is loaded
    fn is_loaded(&self, model: ModelId) -> bool;

    /// Get the backend kind
    fn kind(&self) -> crate::types::BackendKind;

    /// Get backend-specific statistics
    ///
    /// Default implementation returns zeroed stats for backends that do
    /// not track counters.
    fn stats(&self) -> BackendStats {
        BackendStats::default()
    }
}
|
||||
|
||||
/// Backend statistics
///
/// Plain counter snapshot returned by `TransformerBackend::stats`.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Number of models currently loaded
    pub models_loaded: usize,
    /// Total inferences performed
    pub total_inferences: u64,
    /// Total cycles consumed (FPGA only)
    pub total_cycles: u64,
    /// Average latency in nanoseconds (incremental running mean)
    pub avg_latency_ns: u64,
    /// P99 latency in nanoseconds
    /// NOTE(review): not updated by any backend code visible here — confirm
    /// whether this is populated elsewhere or is a placeholder.
    pub p99_latency_ns: u64,
    /// Number of early exits
    pub early_exits: u64,
    /// Number of skipped inferences
    pub skipped: u64,
}
|
||||
|
||||
/// Protocol constants for daemon/PCIe communication
///
/// These values are shared by both sides of the wire; changing any of them
/// is a protocol-breaking change.
pub mod protocol {
    /// Magic number for frame validation
    pub const MAGIC: u32 = 0x5256_5846; // "RVXF" - RuVector FPGA

    /// Current protocol version
    pub const VERSION: u16 = 1;

    /// Frame header size in bytes (must match `RequestFrame`'s packed layout)
    pub const HEADER_SIZE: usize = 24;

    /// Maximum payload size
    pub const MAX_PAYLOAD: usize = 1024 * 1024; // 1MB

    /// Request flags (bitmask, OR-combinable)
    pub mod flags {
        /// Return only top-K predictions
        pub const TOPK_ONLY: u16 = 0x0001;
        /// Use LUT-based softmax
        pub const LUT_SOFTMAX: u16 = 0x0002;
        /// Enable early exit
        pub const EARLY_EXIT: u16 = 0x0004;
        /// Return detailed witness
        pub const WITNESS_DETAIL: u16 = 0x0008;
    }

    /// Response status codes
    pub mod status {
        /// Success
        pub const OK: u16 = 0;
        /// Model not found
        pub const MODEL_NOT_FOUND: u16 = 1;
        /// Shape mismatch
        pub const SHAPE_MISMATCH: u16 = 2;
        /// Gate blocked
        pub const GATE_BLOCKED: u16 = 3;
        /// Internal error
        pub const INTERNAL_ERROR: u16 = 0xFFFF;
    }
}
|
||||
|
||||
/// Request frame for wire protocol
///
/// `repr(C, packed)` fixes the layout for the wire: the fields sum to 24
/// bytes, matching `protocol::HEADER_SIZE`. Serialize with `to_bytes`
/// rather than reading the struct memory directly.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct RequestFrame {
    /// Magic number (MAGIC)
    pub magic: u32,
    /// Protocol version
    pub protocol: u16,
    /// Sequence length
    pub seq_len: u16,
    /// Model dimension
    pub d_model: u16,
    /// Vocabulary size (only 16 bits fit on the wire)
    pub vocab: u16,
    /// Model ID (lower 32 bits)
    pub model_id_low: u32,
    /// Model ID (upper 32 bits)
    pub model_id_high: u32,
    /// Request flags
    pub flags: u16,
    /// Top-K count (if TOPK_ONLY flag set)
    pub topk: u16,
}
|
||||
|
||||
impl RequestFrame {
|
||||
/// Create a new request frame
|
||||
pub fn new(
|
||||
seq_len: u16,
|
||||
d_model: u16,
|
||||
vocab: u32,
|
||||
model_id: &ModelId,
|
||||
flags: u16,
|
||||
topk: u16,
|
||||
) -> Self {
|
||||
let id_bytes = model_id.as_bytes();
|
||||
let model_id_low = u32::from_le_bytes([id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3]]);
|
||||
let model_id_high =
|
||||
u32::from_le_bytes([id_bytes[4], id_bytes[5], id_bytes[6], id_bytes[7]]);
|
||||
|
||||
Self {
|
||||
magic: protocol::MAGIC,
|
||||
protocol: protocol::VERSION,
|
||||
seq_len,
|
||||
d_model,
|
||||
vocab: (vocab & 0xFFFF) as u16,
|
||||
model_id_low,
|
||||
model_id_high,
|
||||
flags,
|
||||
topk,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to bytes
|
||||
pub fn to_bytes(&self) -> [u8; protocol::HEADER_SIZE] {
|
||||
let mut bytes = [0u8; protocol::HEADER_SIZE];
|
||||
bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
|
||||
bytes[4..6].copy_from_slice(&self.protocol.to_le_bytes());
|
||||
bytes[6..8].copy_from_slice(&self.seq_len.to_le_bytes());
|
||||
bytes[8..10].copy_from_slice(&self.d_model.to_le_bytes());
|
||||
bytes[10..12].copy_from_slice(&self.vocab.to_le_bytes());
|
||||
bytes[12..16].copy_from_slice(&self.model_id_low.to_le_bytes());
|
||||
bytes[16..20].copy_from_slice(&self.model_id_high.to_le_bytes());
|
||||
bytes[20..22].copy_from_slice(&self.flags.to_le_bytes());
|
||||
bytes[22..24].copy_from_slice(&self.topk.to_le_bytes());
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
/// Response frame from wire protocol
///
/// `repr(C, packed)` fixes the layout; the fields sum to the 14 bytes
/// consumed by `from_bytes`.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ResponseFrame {
    /// Status code (see `protocol::status`)
    pub status: u16,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Compute cycles
    pub cycles: u32,
    /// Gate decision (packed: 0 = ran full, 1 = early exit, 2 = skipped)
    pub gate_decision: u8,
    /// Exit layer (if early exit)
    pub exit_layer: u8,
    /// Skip reason (if skipped)
    pub skip_reason: u8,
    /// Reserved for future use / alignment
    pub reserved: u8,
}
|
||||
|
||||
impl ResponseFrame {
|
||||
/// Parse from bytes
|
||||
pub fn from_bytes(bytes: &[u8; 14]) -> Self {
|
||||
Self {
|
||||
status: u16::from_le_bytes([bytes[0], bytes[1]]),
|
||||
latency_ns: u32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
|
||||
cycles: u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
|
||||
gate_decision: bytes[10],
|
||||
exit_layer: bytes[11],
|
||||
skip_reason: bytes[12],
|
||||
reserved: bytes[13],
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert gate decision to enum
|
||||
pub fn to_gate_decision(&self) -> crate::types::GateDecision {
|
||||
match self.gate_decision {
|
||||
0 => crate::types::GateDecision::RanFull,
|
||||
1 => crate::types::GateDecision::EarlyExit {
|
||||
layer: self.exit_layer,
|
||||
},
|
||||
2 => crate::types::GateDecision::Skipped {
|
||||
reason: match self.skip_reason {
|
||||
0 => crate::types::SkipReason::LowCoherence,
|
||||
1 => crate::types::SkipReason::PolicyDenied,
|
||||
_ => crate::types::SkipReason::BudgetExceeded,
|
||||
},
|
||||
},
|
||||
_ => crate::types::GateDecision::RanFull,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate CRC32 checksum for frame validation.
///
/// Bitwise CRC-32 (IEEE 802.3 / zlib variant): init 0xFFFFFFFF, reflected
/// polynomial 0xEDB88320, final complement. A table-driven crate such as
/// `crc32fast` would be faster in production.
pub fn crc32(data: &[u8]) -> u32 {
    let folded = data.iter().fold(0xFFFF_FFFFu32, |acc, &byte| {
        (0..8).fold(acc ^ u32::from(byte), |c, _| {
            if c & 1 == 1 {
                (c >> 1) ^ 0xEDB8_8320
            } else {
                c >> 1
            }
        })
    });
    !folded
}
|
||||
|
||||
// ============================================================================
|
||||
// Common utilities shared across backends
|
||||
// ============================================================================
|
||||
|
||||
/// Compute top-K predictions from logits
///
/// Returns up to `k` `(token_id, logit)` pairs sorted descending by logit
/// value (order among equal logits is unspecified). An empty slice or
/// `k == 0` yields an empty result; `k` larger than the input is clamped.
///
/// Uses `select_nth_unstable_by` to partition the k largest elements in
/// O(n) and then sorts only those k. This also fixes the previous
/// hand-rolled min-heap path, which indexed `heap[0]` unconditionally and
/// panicked for `k == 0` on inputs longer than 100 elements.
#[inline]
pub fn compute_topk(logits: &[i16], k: usize) -> Vec<(u16, i16)> {
    if logits.is_empty() || k == 0 {
        return Vec::new();
    }

    // Pair each logit with its token index up front.
    let mut indexed: Vec<(u16, i16)> = logits
        .iter()
        .enumerate()
        .map(|(i, &v)| (i as u16, v))
        .collect();

    let k = k.min(indexed.len());
    if k < indexed.len() {
        // Partition so the k largest (by logit) occupy indexed[..k].
        indexed.select_nth_unstable_by(k - 1, |a, b| b.1.cmp(&a.1));
        indexed.truncate(k);
    }

    // Final descending sort over just the k survivors.
    indexed.sort_unstable_by(|a, b| b.1.cmp(&a.1));
    indexed
}
|
||||
|
||||
/// Helper to safely read from RwLock, returning error on poison
|
||||
pub fn read_lock<T, R>(
|
||||
lock: &std::sync::RwLock<T>,
|
||||
f: impl FnOnce(&T) -> R,
|
||||
) -> crate::error::Result<R> {
|
||||
lock.read()
|
||||
.map(|guard| f(&*guard))
|
||||
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (read)".into()))
|
||||
}
|
||||
|
||||
/// Helper to safely write to RwLock, returning error on poison
|
||||
pub fn write_lock<T, R>(
|
||||
lock: &std::sync::RwLock<T>,
|
||||
f: impl FnOnce(&mut T) -> R,
|
||||
) -> crate::error::Result<R> {
|
||||
lock.write()
|
||||
.map(|mut guard| f(&mut *guard))
|
||||
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (write)".into()))
|
||||
}
|
||||
|
||||
/// Validate token indices against vocabulary size
|
||||
#[inline]
|
||||
pub fn validate_tokens(tokens: &[u16], vocab_size: u32) -> crate::error::Result<()> {
|
||||
for (i, &token) in tokens.iter().enumerate() {
|
||||
if token as u32 >= vocab_size {
|
||||
return Err(crate::error::Error::InvalidConfig(format!(
|
||||
"Token {} at index {} exceeds vocabulary size {}",
|
||||
token, i, vocab_size
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build witness log from inference metadata
///
/// Thin convenience forwarder to [`crate::types::WitnessLog::new`], kept so
/// backends construct witnesses through one shared entry point.
pub fn build_witness(
    model_hash: [u8; 32],
    quant_hash: [u8; 32],
    backend: crate::types::BackendKind,
    cycles: u32,
    latency_ns: u32,
    gate_decision: crate::types::GateDecision,
) -> crate::types::WitnessLog {
    crate::types::WitnessLog::new(
        model_hash,
        quant_hash,
        backend,
        cycles,
        latency_ns,
        gate_decision,
    )
}
|
||||
|
||||
/// Command types for daemon protocol
///
/// One-byte opcodes; values are part of the wire protocol and must not change.
pub mod commands {
    /// Load model command
    pub const LOAD_MODEL: u8 = 0x01;
    /// Unload model command
    pub const UNLOAD_MODEL: u8 = 0x02;
    /// Inference request command
    pub const INFER: u8 = 0x03;
    /// Ping/health check command
    pub const PING: u8 = 0x04;
    /// Get status command
    pub const STATUS: u8 = 0x05;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Header serialization keeps its fixed size and leads with the magic.
    #[test]
    fn test_request_frame_roundtrip() {
        let model_id = ModelId::new([0x42u8; 32]);
        let frame = RequestFrame::new(64, 256, 32000, &model_id, 0, 16);
        let bytes = frame.to_bytes();

        assert_eq!(bytes.len(), protocol::HEADER_SIZE);
        assert_eq!(&bytes[0..4], &protocol::MAGIC.to_le_bytes());
    }

    /// CRC must be deterministic for identical input.
    #[test]
    fn test_crc32() {
        let data = b"test data";
        let crc = crc32(data);
        // CRC should be consistent
        assert_eq!(crc, crc32(data));
    }

    /// Small-input path: exact indices and values, descending.
    #[test]
    fn test_compute_topk() {
        let logits: Vec<i16> = vec![100, 50, 300, 200, 150];
        let topk = compute_topk(&logits, 3);

        assert_eq!(topk.len(), 3);
        assert_eq!(topk[0], (2, 300)); // Index 2, value 300
        assert_eq!(topk[1], (3, 200)); // Index 3, value 200
        assert_eq!(topk[2], (4, 150)); // Index 4, value 150
    }

    /// Large-input path (partial selection): length and descending order.
    #[test]
    fn test_compute_topk_large() {
        let logits: Vec<i16> = (0..1000).map(|i| (i * 7 % 500) as i16).collect();
        let topk = compute_topk(&logits, 10);

        assert_eq!(topk.len(), 10);
        // Should be sorted descending
        for i in 1..topk.len() {
            assert!(topk[i - 1].1 >= topk[i].1);
        }
    }

    /// Boundary behavior: token == vocab_size is out of range.
    #[test]
    fn test_validate_tokens() {
        assert!(validate_tokens(&[0, 1, 2], 100).is_ok());
        assert!(validate_tokens(&[99], 100).is_ok());
        assert!(validate_tokens(&[100], 100).is_err());
        assert!(validate_tokens(&[0, 50, 101], 100).is_err());
    }
}
|
||||
544
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/native_sim.rs
vendored
Normal file
544
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/native_sim.rs
vendored
Normal file
@@ -0,0 +1,544 @@
|
||||
//! Native Rust simulator backend
|
||||
//!
|
||||
//! Provides a pure-Rust implementation of the transformer inference
|
||||
//! for testing, development, and fallback when no FPGA is available.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
compute_topk, read_lock, validate_tokens, write_lock, BackendStats, TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::gating::CoherenceGate;
|
||||
use crate::quant::{dequantize_i8, quantize_i16, softmax_lut};
|
||||
use crate::types::{
|
||||
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
|
||||
QuantSpec, SkipReason, WitnessLog,
|
||||
};
|
||||
|
||||
/// Loaded model data for native simulation
///
/// Weights are held dequantized (f32) because the simulator computes in
/// floating point; only the final logits are re-quantized.
struct LoadedModel {
    /// Model artifact (contains weights and config)
    artifact: ModelArtifact,
    /// Precomputed embedding matrix, row-major `[vocab][d_model]`,
    /// dequantized for sim
    embeddings: Vec<f32>,
    /// Layer weights (simplified for simulation)
    layers: Vec<LayerWeights>,
    /// Output projection, row-major `[d_model][vocab]`
    output_proj: Vec<f32>,
}
|
||||
|
||||
/// Simplified layer weights for simulation
///
/// All matrices are flat row-major `Vec<f32>`; `run_layer` only actually
/// uses `wo`, `w1`, `w2` and the layer-norm vectors (Q/K/V projections are
/// carried but unused by the simplified attention).
struct LayerWeights {
    /// Attention Q projection
    wq: Vec<f32>,
    /// Attention K projection
    wk: Vec<f32>,
    /// Attention V projection
    wv: Vec<f32>,
    /// Attention output projection
    wo: Vec<f32>,
    /// FFN up projection
    w1: Vec<f32>,
    /// FFN down projection
    w2: Vec<f32>,
    /// Layer norm weights (pre-attention)
    ln1_weight: Vec<f32>,
    /// Layer norm weights (pre-FFN)
    ln2_weight: Vec<f32>,
}
|
||||
|
||||
/// Native simulator backend
///
/// Pure-Rust `TransformerBackend` used for testing, development, and as a
/// fallback when no FPGA hardware is present.
pub struct NativeSimBackend {
    /// Loaded models, keyed by model ID; `Arc` lets inference run without
    /// holding the map lock
    models: RwLock<HashMap<ModelId, Arc<LoadedModel>>>,
    /// Coherence gate consulted for skip / early-exit decisions
    gate: Arc<dyn CoherenceGate>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: NativeSimConfig,
}
|
||||
|
||||
/// Configuration for native simulator
#[derive(Debug, Clone)]
pub struct NativeSimConfig {
    /// Maximum models to keep loaded; further loads fail with
    /// `ResourceExhausted`
    pub max_models: usize,
    /// Enable detailed tracing
    pub trace: bool,
    /// Use LUT-based softmax instead of the f32 implementation
    pub lut_softmax: bool,
    /// Number of layers to simulate (0 = all)
    pub max_layers: usize,
}
|
||||
|
||||
impl Default for NativeSimConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_models: 8,
|
||||
trace: false,
|
||||
lut_softmax: true,
|
||||
max_layers: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeSimBackend {
    /// Create a new native simulator backend with default configuration.
    pub fn new(gate: Arc<dyn CoherenceGate>) -> Self {
        Self::with_config(gate, NativeSimConfig::default())
    }

    /// Create with custom configuration
    pub fn with_config(gate: Arc<dyn CoherenceGate>, config: NativeSimConfig) -> Self {
        Self {
            models: RwLock::new(HashMap::new()),
            gate,
            stats: RwLock::new(BackendStats::default()),
            config,
        }
    }

    /// Run the core transformer inference
    ///
    /// Consults the coherence gate before starting (possible skip) and at
    /// every layer boundary (possible early exit), then projects the final
    /// hidden state to quantized logits.
    ///
    /// NOTE(review): assumes `tokens` is non-empty — an empty slice would
    /// make `compute_output` underflow on `seq_len - 1`; presumably
    /// `req.validate()` rejects empty token lists upstream. Confirm.
    fn run_inference(
        &self,
        model: &LoadedModel,
        tokens: &[u16],
        _attn_mask: &[u8],
        gate_hint: &GateHint,
    ) -> Result<(Vec<i16>, GateDecision)> {
        let shape = &model.artifact.manifest.shape;
        let num_layers = model.layers.len();

        // Check preflight gate: a skip short-circuits with zeroed logits.
        let preflight = self.gate.preflight(gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            return Ok((
                vec![0i16; shape.vocab as usize],
                GateDecision::Skipped { reason },
            ));
        }

        // Initialize hidden states from embeddings
        let d_model = shape.d_model as usize;
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];

        // Lookup embeddings; out-of-range rows are silently left at zero.
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }

        // Run through layers, optionally capped by config (0 = all).
        let max_layers = if self.config.max_layers > 0 {
            self.config.max_layers.min(num_layers)
        } else {
            num_layers
        };

        for layer_idx in 0..max_layers {
            let layer = &model.layers[layer_idx];

            // Check layer checkpoint for early exit
            let coherence_signal = self.compute_coherence_signal(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer_idx as u8, coherence_signal) {
                if let GateDecision::EarlyExit { layer } = decision {
                    // Early exit - compute output from current hidden state
                    let logits = self.compute_output(&hidden, &model.output_proj, shape);
                    return Ok((logits, GateDecision::EarlyExit { layer }));
                }
            }

            // Simplified attention + FFN (for simulation purposes)
            hidden = self.run_layer(&hidden, layer, shape);
        }

        // Compute output logits
        let logits = self.compute_output(&hidden, &model.output_proj, shape);

        Ok((logits, GateDecision::RanFull))
    }

    /// Run a single transformer layer (simulation-grade approximation)
    ///
    /// Deliberately NOT a faithful transformer block: "attention" is just
    /// the output projection applied per token, and the FFN is a
    /// SwiGLU-like up/SiLU/down pipeline. Naive nested-loop matmuls are
    /// acceptable here because this backend exists for testing, not speed.
    fn run_layer(&self, hidden: &[f32], layer: &LayerWeights, shape: &FixedShape) -> Vec<f32> {
        let d_model = shape.d_model as usize;
        let seq_len = hidden.len() / d_model;

        // Simplified layer computation
        // In a real implementation, this would do full attention + FFN

        let mut output = hidden.to_vec();

        // Layer norm 1 (applied per token)
        for t in 0..seq_len {
            let start = t * d_model;
            let end = start + d_model;
            layer_norm_inplace(&mut output[start..end], &layer.ln1_weight);
        }

        // Simplified attention (just apply output projection as placeholder)
        // Real implementation would compute Q, K, V, attention scores, etc.
        if !layer.wo.is_empty() {
            let mut attn_out = vec![0.0f32; output.len()];
            for t in 0..seq_len {
                for i in 0..d_model {
                    let mut sum = 0.0f32;
                    // The min() guard tolerates undersized weight matrices.
                    for j in 0..d_model.min(layer.wo.len() / d_model) {
                        sum += output[t * d_model + j] * layer.wo[j * d_model + i];
                    }
                    attn_out[t * d_model + i] = sum;
                }
            }
            // Residual connection
            for i in 0..output.len() {
                output[i] += attn_out[i];
            }
        }

        // Layer norm 2
        for t in 0..seq_len {
            let start = t * d_model;
            let end = start + d_model;
            layer_norm_inplace(&mut output[start..end], &layer.ln2_weight);
        }

        // Simplified FFN (SwiGLU-like)
        if !layer.w1.is_empty() && !layer.w2.is_empty() {
            let ffn_dim = layer.w1.len() / d_model;
            let mut ffn_out = vec![0.0f32; output.len()];

            for t in 0..seq_len {
                // Up projection
                let mut up = vec![0.0f32; ffn_dim];
                for i in 0..ffn_dim {
                    for j in 0..d_model {
                        up[i] += output[t * d_model + j] * layer.w1[j * ffn_dim + i];
                    }
                    // SiLU activation: x * sigmoid(x)
                    up[i] = up[i] * sigmoid(up[i]);
                }

                // Down projection
                for i in 0..d_model {
                    for j in 0..ffn_dim.min(layer.w2.len() / d_model) {
                        ffn_out[t * d_model + i] += up[j] * layer.w2[j * d_model + i];
                    }
                }
            }

            // Residual connection
            for i in 0..output.len() {
                output[i] += ffn_out[i];
            }
        }

        output
    }

    /// Compute output logits from hidden state
    ///
    /// Projects only the LAST token's hidden vector through `output_proj`
    /// (next-token prediction), applies softmax (LUT or f32 per config),
    /// then quantizes to i16.
    fn compute_output(&self, hidden: &[f32], output_proj: &[f32], shape: &FixedShape) -> Vec<i16> {
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        let seq_len = hidden.len() / d_model;

        // Take last token's hidden state
        let last_hidden = &hidden[(seq_len - 1) * d_model..];

        // Compute logits
        let mut logits = vec![0.0f32; vocab];
        if output_proj.len() >= d_model * vocab {
            for v in 0..vocab {
                for d in 0..d_model {
                    logits[v] += last_hidden[d] * output_proj[d * vocab + v];
                }
            }
        } else {
            // Fallback: deterministic synthetic logits for simulation when
            // projection weights are absent or undersized.
            for v in 0..vocab {
                logits[v] = (v as f32 * 0.01).sin();
            }
        }

        // Apply softmax (optional) and quantize
        if self.config.lut_softmax {
            softmax_lut(&mut logits);
        } else {
            softmax_f32(&mut logits);
        }

        // Quantize to i16
        quantize_i16(&logits)
    }

    /// Compute coherence signal for early exit decision
    ///
    /// Uses the variance of the hidden activations as a cheap proxy,
    /// scaled into Q8.8 fixed point and saturated to the i16 range.
    fn compute_coherence_signal(&self, hidden: &[f32]) -> i16 {
        // Simple coherence metric: variance of hidden states
        let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
        let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden.len() as f32;

        // Scale to Q8.8 fixed point
        ((variance * 256.0).clamp(-32768.0, 32767.0)) as i16
    }

    /// Prepare model from artifact (dequantize weights for simulation)
    ///
    /// Only the embedding table comes from the real artifact weights; layer
    /// and output-projection weights are synthetic constants, which is
    /// sufficient to exercise the inference plumbing in tests.
    fn prepare_model(&self, artifact: &ModelArtifact) -> Result<LoadedModel> {
        let shape = &artifact.manifest.shape;
        let quant = &artifact.manifest.quant;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;

        // Dequantize embeddings
        let embedding_size = vocab * d_model;
        let embeddings = if artifact.weights.len() >= embedding_size {
            dequantize_i8(&artifact.weights[..embedding_size], quant)
        } else {
            // Generate deterministic pseudo-random embeddings for testing
            (0..embedding_size)
                .map(|i| ((i as f32 * 0.001).sin() * 0.1))
                .collect()
        };

        // Create simplified layer weights
        let num_layers = 4; // Default for simulation
        let layers: Vec<LayerWeights> = (0..num_layers)
            .map(|_| LayerWeights {
                wq: vec![0.01; d_model * d_model],
                wk: vec![0.01; d_model * d_model],
                wv: vec![0.01; d_model * d_model],
                wo: vec![0.01; d_model * d_model],
                w1: vec![0.01; d_model * 4 * d_model],
                w2: vec![0.01; 4 * d_model * d_model],
                ln1_weight: vec![1.0; d_model],
                ln2_weight: vec![1.0; d_model],
            })
            .collect();

        // Output projection
        let output_proj = vec![0.01; d_model * vocab];

        Ok(LoadedModel {
            artifact: artifact.clone(),
            embeddings,
            layers,
            output_proj,
        })
    }
}
|
||||
|
||||
impl TransformerBackend for NativeSimBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
// Prepare model
|
||||
let model = self.prepare_model(artifact)?;
|
||||
let model_id = artifact.model_id();
|
||||
|
||||
// Check capacity (with poison handling)
|
||||
let at_capacity = read_lock(&self.models, |models| {
|
||||
models.len() >= self.config.max_models && !models.contains_key(&model_id)
|
||||
})?;
|
||||
|
||||
if at_capacity {
|
||||
return Err(Error::ResourceExhausted("Max models reached".into()));
|
||||
}
|
||||
|
||||
// Store model
|
||||
write_lock(&self.models, |models| {
|
||||
models.insert(model_id, Arc::new(model));
|
||||
})?;
|
||||
|
||||
// Update stats
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded += 1;
|
||||
})?;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate request
|
||||
req.validate()?;
|
||||
|
||||
// Get model (with poison handling)
|
||||
let model = read_lock(&self.models, |models| models.get(&req.model).cloned())?
|
||||
.ok_or_else(|| Error::ModelNotFound(req.model))?;
|
||||
|
||||
// Validate shape
|
||||
if model.artifact.manifest.shape != req.shape {
|
||||
return Err(Error::ShapeMismatch {
|
||||
expected: model.artifact.manifest.shape,
|
||||
actual: req.shape,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate tokens against vocabulary
|
||||
validate_tokens(req.tokens, model.artifact.manifest.shape.vocab)?;
|
||||
|
||||
// Run inference
|
||||
let (logits_q, gate_decision) =
|
||||
self.run_inference(&model, req.tokens, req.attn_mask, &req.gate_hint)?;
|
||||
|
||||
let latency_ns = start.elapsed().as_nanos() as u32;
|
||||
|
||||
// Compute top-K using common utility
|
||||
let topk = compute_topk(&logits_q, 16);
|
||||
|
||||
// Create witness
|
||||
let witness = WitnessLog::new(
|
||||
model.artifact.model_hash(),
|
||||
model.artifact.quant_hash(),
|
||||
BackendKind::NativeSim,
|
||||
0, // No cycles for simulator
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
);
|
||||
|
||||
// Update stats (with poison handling)
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.total_inferences += 1;
|
||||
let n = stats.total_inferences;
|
||||
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
|
||||
match gate_decision {
|
||||
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
|
||||
GateDecision::Skipped { .. } => stats.skipped += 1,
|
||||
_ => {}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(InferenceResult::new(logits_q, Some(topk), witness))
|
||||
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
|
||||
|
||||
if removed {
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded = stats.models_loaded.saturating_sub(1);
|
||||
})?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn kind(&self) -> BackendKind {
|
||||
BackendKind::NativeSim
|
||||
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
/// In-place layer normalization with a learned per-element scale.
///
/// Normalizes `x` to zero mean / unit variance (epsilon 1e-5), then
/// multiplies element `i` by `weight[i]`; positions past the end of
/// `weight` default to a scale of 1.0, so a short weight slice is tolerated.
fn layer_norm_inplace(x: &mut [f32], weight: &[f32]) {
    let len = x.len() as f32;
    let mean = x.iter().sum::<f32>() / len;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / len;
    let std = (var + 1e-5).sqrt();

    // Pad the weight slice with 1.0 for any trailing elements of `x`.
    let scales = weight.iter().copied().chain(std::iter::repeat(1.0));
    for (value, scale) in x.iter_mut().zip(scales) {
        *value = (*value - mean) / std * scale;
    }
}
|
||||
|
||||
/// Logistic sigmoid: maps any real input into the open interval (0, 1).
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    1.0 / (1.0 + neg_exp)
}
|
||||
|
||||
/// Numerically stable in-place softmax.
///
/// Subtracts the running maximum before exponentiating so large logits
/// cannot overflow. If the exponential sum is zero (only possible for an
/// empty slice), the values are left untouched rather than divided by zero.
fn softmax_f32(x: &mut [f32]) {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f32;
    for value in x.iter_mut() {
        *value = (*value - peak).exp();
        total += *value;
    }

    if total > 0.0 {
        x.iter_mut().for_each(|value| *value /= total);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;

    /// Build a minimal unsigned artifact with the "micro" shape and zeroed
    /// weights, sufficient for the simulator's load/infer paths.
    /// NOTE(review): the 4096 * 64 weight size presumably matches
    /// `FixedShape::micro()`'s vocab x d_model — confirm against `types`.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };

        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64], // Minimal embedding weights
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }

    /// Loading registers the model; unloading removes it.
    #[test]
    fn test_native_sim_load_unload() {
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);

        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();

        assert!(backend.is_loaded(model_id));

        backend.unload(model_id).unwrap();
        assert!(!backend.is_loaded(model_id));
    }

    /// End-to-end inference on the simulator: 32 tokens with a full
    /// attention mask must yield logits, a top-K list, and a witness
    /// tagged with the NativeSim backend kind.
    #[test]
    fn test_native_sim_inference() {
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);

        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();

        let tokens: Vec<u16> = (0..32).collect();
        let mask = vec![1u8; 32]; // all positions attended

        let req = InferenceRequest::new(
            model_id,
            FixedShape::micro(),
            &tokens,
            &mask,
            GateHint::allow_all(),
        );

        let result = backend.infer(req).unwrap();

        assert!(!result.logits_q.is_empty());
        assert!(result.topk.is_some());
        assert_eq!(result.witness.backend, BackendKind::NativeSim);
    }
}
|
||||
348
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/wasm_sim.rs
vendored
Normal file
348
vendor/ruvector/crates/ruvector-fpga-transformer/src/backend/wasm_sim.rs
vendored
Normal file
@@ -0,0 +1,348 @@
|
||||
//! WASM Simulator backend
|
||||
//!
|
||||
//! Pure Rust implementation that runs in WASM environments.
|
||||
//! Uses RefCell for interior mutability since WASM is single-threaded.
|
||||
|
||||
#![cfg(feature = "wasm")]
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{compute_topk, validate_tokens, BackendStats, TransformerBackend};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::gating::CoherenceGate;
|
||||
use crate::quant::{dequantize_i8, quantize_i16};
|
||||
use crate::types::{
|
||||
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
|
||||
QuantSpec, WitnessLog,
|
||||
};
|
||||
|
||||
/// Loaded model for WASM simulation.
///
/// Holds the original artifact plus data precomputed at load time so the
/// per-inference path avoids re-dequantizing weights.
struct WasmModel {
    /// Model artifact (kept whole for hashes and witness construction)
    artifact: ModelArtifact,
    /// Prepacked embedding table (dequantized to f32 for computation);
    /// laid out as `vocab * d_model` values, row per token id
    embeddings: Vec<f32>,
    /// Number of simulated layers (chosen at load time in `prepare_model`)
    num_layers: usize,
    /// Shape info copied out of the manifest for cheap access
    shape: FixedShape,
}
|
||||
|
||||
/// WASM simulator backend state (interior mutable for single-threaded WASM).
///
/// Grouped into one struct so a single `RefCell` guards both fields,
/// keeping model table and statistics updates consistent.
struct WasmState {
    /// Loaded models keyed by their content-derived id
    models: HashMap<ModelId, WasmModel>,
    /// Statistics accumulated across load/infer/unload calls
    stats: BackendStats,
}
|
||||
|
||||
/// WASM simulator backend
///
/// Uses RefCell for interior mutability since WASM is inherently single-threaded.
/// This allows the TransformerBackend trait to be implemented with &self methods.
pub struct WasmSimBackend {
    /// Interior mutable state (models + stats behind one RefCell)
    state: RefCell<WasmState>,
    /// Coherence gate (immutable, shared; Rc because WASM is single-threaded)
    gate: Rc<dyn CoherenceGate>,
}
|
||||
|
||||
impl WasmSimBackend {
    /// Create a new WASM simulator backend with an empty model table.
    pub fn new(gate: Rc<dyn CoherenceGate>) -> Self {
        Self {
            state: RefCell::new(WasmState {
                models: HashMap::new(),
                stats: BackendStats::default(),
            }),
            gate,
        }
    }

    /// Prepare model from artifact: dequantize the embedding table and pick
    /// the simulated layer count.
    fn prepare_model(&self, artifact: &ModelArtifact) -> Result<WasmModel> {
        let shape = artifact.manifest.shape;
        let quant = &artifact.manifest.quant;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;

        // Dequantize embeddings when the artifact carries enough bytes.
        let embedding_size = vocab * d_model;
        let embeddings = if artifact.weights.len() >= embedding_size {
            dequantize_i8(&artifact.weights[..embedding_size], quant)
        } else {
            // Generate deterministic embeddings for testing (sin-based, no RNG)
            (0..embedding_size)
                .map(|i| ((i as f32 * 0.001).sin() * 0.1))
                .collect()
        };

        // Determine number of layers from artifact or default.
        // NOTE(review): 6 layers when early-exit is enabled (so there are
        // checkpoints to exit from), otherwise 4 — presumably tuned for the
        // simulator only; confirm against the real model config.
        let num_layers = if artifact.manifest.backend.options.early_exit {
            6
        } else {
            4
        };

        Ok(WasmModel {
            artifact: artifact.clone(),
            embeddings,
            num_layers,
            shape,
        })
    }

    /// Run inference for WASM.
    ///
    /// Returns quantized logits plus the gate's decision. The gate may skip
    /// the run entirely (preflight) or request an early exit after any layer.
    /// NOTE(review): the request's attention mask is not consulted here,
    /// unlike the native simulator — confirm whether that is intentional.
    fn run_inference(
        &self,
        model: &WasmModel,
        tokens: &[u16],
        gate_hint: &GateHint,
    ) -> (Vec<i16>, GateDecision) {
        let shape = &model.shape;

        // Check preflight: a skipped run returns all-zero logits.
        let preflight = self.gate.preflight(gate_hint);
        if !preflight.did_run() {
            return (vec![0i16; shape.vocab as usize], preflight);
        }

        let vocab = shape.vocab as usize;
        let d_model = shape.d_model as usize;

        // Initialize hidden state from embeddings.
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];

        // Lookup embeddings with bounds checking; out-of-range token ids are
        // clamped to the last vocabulary row, short tables leave zeros.
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize).min(vocab.saturating_sub(1)) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }

        // Run through simplified layers with early exit support.
        for layer in 0..model.num_layers {
            // Simple layer computation (for WASM we keep it lightweight):
            // a leaky-ReLU-like activation applied position by position.
            for t in 0..seq_len {
                let start = t * d_model;
                for i in 0..d_model {
                    hidden[start + i] =
                        hidden[start + i].max(0.0) * 0.99 + hidden[start + i] * 0.01;
                }
            }

            // Check for early exit at this layer's checkpoint.
            let coherence_signal = compute_coherence(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer as u8, coherence_signal) {
                let logits = self.compute_output(&hidden, model);
                return (logits, decision);
            }
        }

        // Compute output logits after the full pass.
        let logits = self.compute_output(&hidden, model);
        (logits, GateDecision::RanFull)
    }

    /// Compute output logits from the hidden state.
    ///
    /// Uses the last token's hidden vector, projected against the embedding
    /// table (weight tying: embeddings reused as the output matrix), then
    /// softmaxed and quantized to i16.
    fn compute_output(&self, hidden: &[f32], model: &WasmModel) -> Vec<i16> {
        let shape = &model.shape;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        let seq_len = hidden.len() / d_model;

        // Take last token's hidden state (saturating_sub guards seq_len == 0).
        let last_hidden = &hidden[(seq_len.saturating_sub(1)) * d_model..];

        // Compute logits via dot product with embedding matrix (transposed);
        // all index arithmetic is clamped to the table size, so a short
        // embedding table yields zero logits rather than a panic.
        let mut logits_f32 = vec![0.0f32; vocab];
        for v in 0..vocab.min(model.embeddings.len() / d_model) {
            let v_offset = v * d_model;
            let mut dot = 0.0f32;
            for d in 0..d_model.min(last_hidden.len()) {
                if v_offset + d < model.embeddings.len() {
                    dot += last_hidden[d] * model.embeddings[v_offset + d];
                }
            }
            logits_f32[v] = dot;
        }

        // Apply softmax and quantize to the i16 logit format.
        softmax_inplace(&mut logits_f32);
        quantize_i16(&logits_f32)
    }
}
|
||||
|
||||
// Note: WASM is single-threaded, so these trait bounds are satisfied trivially
// by never actually being used across threads
//
// SAFETY: `WasmSimBackend` contains `RefCell` and `Rc`, neither of which is
// thread-safe; these impls are sound ONLY under the single-threaded WASM
// execution model noted above. NOTE(review): if this type were ever reachable
// from multi-threaded native code these impls would be unsound — confirm the
// `wasm` feature gate keeps this file out of threaded builds.
unsafe impl Send for WasmSimBackend {}
unsafe impl Sync for WasmSimBackend {}
|
||||
|
||||
impl TransformerBackend for WasmSimBackend {
    /// Validate, prepare, and register a model artifact.
    /// NOTE(review): unlike the native backend there is no capacity limit here.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact before doing any work.
        artifact.validate()?;

        // Prepare model (dequantize embeddings, pick layer count).
        let model = self.prepare_model(artifact)?;
        let model_id = artifact.model_id();

        // Store in state (single mutable borrow covers table and stats).
        let mut state = self.state.borrow_mut();
        state.models.insert(model_id, model);
        state.stats.models_loaded += 1;

        Ok(model_id)
    }

    /// Run a single inference request.
    ///
    /// Latency is measured with `js_sys::Date::now()` (millisecond clock,
    /// scaled to ns), since `Instant` is unavailable in WASM.
    /// NOTE(review): the request shape is not checked against the model's
    /// shape here, unlike the native backend — confirm this is intentional.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        let start = js_sys::Date::now();

        // Validate request.
        req.validate()?;

        // Get model (immutable borrow of the RefCell state).
        let state = self.state.borrow();
        let model = state
            .models
            .get(&req.model)
            .ok_or_else(|| Error::ModelNotFound(req.model))?;

        // Validate tokens against the model's vocabulary.
        validate_tokens(req.tokens, model.shape.vocab)?;

        // Run inference (infallible in the WASM simulator).
        let (logits, gate_decision) = self.run_inference(model, req.tokens, &req.gate_hint);

        // Date::now() is in ms; convert to ns for the witness.
        let latency_ns = ((js_sys::Date::now() - start) * 1_000_000.0) as u32;

        // Compute top-K.
        let topk = compute_topk(&logits, 16);

        // Build witness.
        let witness = WitnessLog::new(
            model.artifact.model_hash(),
            model.artifact.quant_hash(),
            BackendKind::WasmSim,
            0, // No cycles for WASM sim
            latency_ns,
            gate_decision,
        );

        drop(state); // Release borrow before mutable borrow (RefCell would panic otherwise)

        // Update stats under a fresh mutable borrow, scoped so it ends
        // before the function returns.
        {
            let mut state = self.state.borrow_mut();
            state.stats.total_inferences += 1;
            let n = state.stats.total_inferences; // n >= 1 after the increment
            state.stats.avg_latency_ns =
                (state.stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
            match gate_decision {
                GateDecision::EarlyExit { .. } => state.stats.early_exits += 1,
                GateDecision::Skipped { .. } => state.stats.skipped += 1,
                _ => {}
            }
        }

        Ok(InferenceResult::new(logits, Some(topk), witness))
    }

    /// Remove a loaded model; errors with `ModelNotFound` if absent.
    fn unload(&self, model: ModelId) -> Result<()> {
        let mut state = self.state.borrow_mut();
        if state.models.remove(&model).is_some() {
            state.stats.models_loaded = state.stats.models_loaded.saturating_sub(1);
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }

    /// Check whether a model id is currently loaded.
    fn is_loaded(&self, model: ModelId) -> bool {
        self.state.borrow().models.contains_key(&model)
    }

    fn kind(&self) -> BackendKind {
        BackendKind::WasmSim
    }

    /// Snapshot of accumulated backend statistics.
    fn stats(&self) -> BackendStats {
        self.state.borrow().stats.clone()
    }
}
|
||||
|
||||
/// Compute coherence signal from a hidden state.
///
/// Returns the variance of the activations scaled into Q8.8 fixed point and
/// saturated to the `i16` range; an empty slice yields 0.
fn compute_coherence(hidden: &[f32]) -> i16 {
    if hidden.is_empty() {
        return 0;
    }
    let count = hidden.len() as f32;
    let mean = hidden.iter().sum::<f32>() / count;
    let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count;
    (variance * 256.0).clamp(-32768.0, 32767.0) as i16
}
|
||||
|
||||
/// In-place, numerically stable softmax; an empty slice is a no-op.
fn softmax_inplace(x: &mut [f32]) {
    if x.is_empty() {
        return;
    }

    // Subtract the maximum before exp() so large logits cannot overflow.
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f32;
    for value in x.iter_mut() {
        *value = (*value - peak).exp();
        total += *value;
    }

    // Defensive: exp(max - max) = 1 guarantees total >= 1 for non-empty input,
    // but keep the guard so a zero sum can never divide.
    if total > 0.0 {
        x.iter_mut().for_each(|value| *value /= total);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;

    /// Build a minimal unsigned artifact with the "micro" shape and zeroed
    /// weights for exercising the WASM simulator's load path.
    /// NOTE(review): 4096 * 64 weight bytes presumably matches micro's
    /// vocab x d_model — confirm against `FixedShape::micro()`.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "wasm_test".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };

        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64], // minimal embedding weights
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }

    /// `prepare_model` must carry the manifest shape through and produce a
    /// non-empty embedding table.
    #[test]
    fn test_wasm_sim_prepare_model() {
        let gate = Rc::new(DefaultCoherenceGate::new());
        let backend = WasmSimBackend::new(gate);
        let artifact = create_test_artifact();

        let model = backend.prepare_model(&artifact).unwrap();
        assert_eq!(model.shape.seq_len, 32);
        assert!(!model.embeddings.is_empty());
    }
}
|
||||
136
vendor/ruvector/crates/ruvector-fpga-transformer/src/error.rs
vendored
Normal file
136
vendor/ruvector/crates/ruvector-fpga-transformer/src/error.rs
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
//! Error types for FPGA Transformer backend
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for FPGA Transformer operations
pub type Result<T> = std::result::Result<T, Error>;

/// FPGA Transformer error types
///
/// Display strings come from the `thiserror` `#[error]` attributes below.
/// See `Error::is_recoverable` / `Error::is_config_error` for the coarse
/// classification callers use to decide on retry vs. fail-fast.
#[derive(Error, Debug)]
pub enum Error {
    /// Model artifact is invalid or corrupted
    #[error("Invalid artifact: {0}")]
    InvalidArtifact(String),

    /// Artifact signature verification failed
    #[error("Signature verification failed: {0}")]
    SignatureError(String),

    /// Test vectors failed validation (errors are quantized-logit deltas)
    #[error("Test vector validation failed: expected max error {expected}, got {actual}")]
    TestVectorError { expected: i32, actual: i32 },

    /// Model not found or not loaded
    #[error("Model not found: {0:?}")]
    ModelNotFound(crate::types::ModelId),

    /// Shape mismatch between request and model
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: crate::types::FixedShape,
        actual: crate::types::FixedShape,
    },

    /// Input length does not match expected sequence length
    #[error("Input length mismatch: expected {expected}, got {actual}")]
    InputLengthMismatch { expected: usize, actual: usize },

    /// Backend communication error
    #[error("Backend error: {0}")]
    BackendError(String),

    /// Daemon connection failed
    #[error("Daemon connection failed: {0}")]
    DaemonConnectionError(String),

    /// PCIe communication error
    #[error("PCIe error: {0}")]
    PcieError(String),

    /// DMA operation failed
    #[error("DMA error: {0}")]
    DmaError(String),

    /// Gating decision blocked inference
    #[error("Inference blocked by gate: {reason:?}")]
    GateBlocked { reason: crate::types::SkipReason },

    /// Quantization error
    #[error("Quantization error: {0}")]
    QuantizationError(String),

    /// Overflow during fixed-point computation
    #[error("Fixed-point overflow at {location}")]
    FixedPointOverflow { location: &'static str },

    /// Invalid configuration
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// IO error (auto-converted via `#[from]`, so `?` works on io results)
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// JSON parsing error (auto-converted via `#[from]`)
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Checksum mismatch
    #[error("Checksum mismatch: expected {expected:08x}, got {actual:08x}")]
    ChecksumMismatch { expected: u32, actual: u32 },

    /// Protocol version mismatch
    #[error("Protocol version mismatch: expected {expected}, got {actual}")]
    ProtocolMismatch { expected: u16, actual: u16 },

    /// Timeout waiting for response
    #[error("Timeout after {ms}ms")]
    Timeout { ms: u64 },

    /// Resource exhausted (memory, slots, etc.)
    #[error("Resource exhausted: {0}")]
    ResourceExhausted(String),

    /// Feature not available in this build
    #[error("Feature not available: {0}")]
    FeatureNotAvailable(String),
}
|
||||
|
||||
impl Error {
    /// Create a new InvalidArtifact error from any string-convertible message.
    pub fn invalid_artifact(msg: impl Into<String>) -> Self {
        Self::InvalidArtifact(msg.into())
    }

    /// Create a new BackendError from any string-convertible message.
    pub fn backend(msg: impl Into<String>) -> Self {
        Self::BackendError(msg.into())
    }

    /// Create a new DaemonConnectionError from any string-convertible message.
    pub fn daemon_connection(msg: impl Into<String>) -> Self {
        Self::DaemonConnectionError(msg.into())
    }

    /// Check if this error is recoverable, i.e. transient conditions where a
    /// retry (or a fallback backend) may succeed: timeouts, connection and
    /// backend failures, and gate-blocked requests.
    pub fn is_recoverable(&self) -> bool {
        matches!(
            self,
            Error::Timeout { .. }
                | Error::DaemonConnectionError(_)
                | Error::BackendError(_)
                | Error::GateBlocked { .. }
        )
    }

    /// Check if this error indicates a configuration problem — conditions a
    /// retry cannot fix because the request or build itself is wrong.
    pub fn is_config_error(&self) -> bool {
        matches!(
            self,
            Error::InvalidConfig(_)
                | Error::ShapeMismatch { .. }
                | Error::InputLengthMismatch { .. }
                | Error::FeatureNotAvailable(_)
        )
    }
}
|
||||
302
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/c_abi.rs
vendored
Normal file
302
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/c_abi.rs
vendored
Normal file
@@ -0,0 +1,302 @@
|
||||
//! C ABI bindings for FFI integration
|
||||
//!
|
||||
//! Provides a stable C interface for linking from other languages.
|
||||
|
||||
use std::ffi::{c_char, c_int, c_void, CStr};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::backend::native_sim::NativeSimBackend;
|
||||
use crate::backend::TransformerBackend;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
|
||||
|
||||
/// Opaque engine handle
///
/// Created by `fpga_engine_create` and released by `fpga_engine_destroy`;
/// C callers only ever see it behind a pointer.
pub struct FpgaEngine {
    /// The active backend (currently always the native simulator)
    backend: Box<dyn TransformerBackend>,
}
|
||||
|
||||
/// Result code returned across the C ABI.
///
/// `#[repr(C)]` fixes the layout so C callers can match on the values.
#[repr(C)]
pub enum FpgaResult {
    /// Operation succeeded
    Ok = 0,
    /// A required pointer was null or an argument was malformed
    InvalidArgument = 1,
    /// The referenced model id is not loaded
    ModelNotFound = 2,
    /// Inference itself failed
    InferenceFailed = 3,
    /// A buffer allocation failed
    AllocationFailed = 4,
    /// The artifact could not be unpacked or validated
    InvalidArtifact = 5,
}
|
||||
|
||||
/// Inference result structure returned by value across the C ABI.
///
/// Heap buffers (`logits`, `topk`) are owned by the caller and must be
/// released exactly once with `fpga_result_free`.
#[repr(C)]
pub struct FpgaInferenceResult {
    /// Status code
    pub status: FpgaResult,
    /// Logits (caller must free with fpga_result_free); may be null on
    /// allocation failure even when `status` is Ok
    pub logits: *mut i16,
    /// Number of logits
    pub logits_len: usize,
    /// Top-K results laid out as interleaved (token_id, logit) u32 pairs
    pub topk: *mut u32,
    /// Number of top-K pairs (buffer holds 2 * topk_len u32 values)
    pub topk_len: usize,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Compute cycles (0 for simulator backends)
    pub cycles: u32,
    /// Gate decision (0=full, 1=early_exit, 2=skipped)
    pub gate_decision: u8,
    /// Exit layer (meaningful only when gate_decision == 1)
    pub exit_layer: u8,
}
|
||||
|
||||
/// Create a new FPGA engine with native simulator backend
|
||||
///
|
||||
/// Returns a handle that must be freed with `fpga_engine_destroy`
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_engine_create() -> *mut FpgaEngine {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = Box::new(NativeSimBackend::new(gate));
|
||||
|
||||
let engine = Box::new(FpgaEngine { backend });
|
||||
Box::into_raw(engine)
|
||||
}
|
||||
|
||||
/// Destroy an FPGA engine
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_engine_destroy(engine: *mut FpgaEngine) {
|
||||
if !engine.is_null() {
|
||||
unsafe {
|
||||
drop(Box::from_raw(engine));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a model artifact
///
/// On success writes the 32-byte model ID into `model_id_out` and returns
/// `FpgaResult::Ok`; on failure returns an error code and leaves
/// `model_id_out` untouched.
///
/// # Safety (caller contract)
/// `artifact_bytes` must point to `artifact_len` readable bytes and
/// `model_id_out` to at least 32 writable bytes for the duration of the call.
#[no_mangle]
pub extern "C" fn fpga_load_artifact(
    engine: *mut FpgaEngine,
    artifact_bytes: *const u8,
    artifact_len: usize,
    model_id_out: *mut u8,
) -> FpgaResult {
    if engine.is_null() || artifact_bytes.is_null() || model_id_out.is_null() {
        return FpgaResult::InvalidArgument;
    }

    // SAFETY: null checks above; lengths/validity are the caller's contract.
    let engine = unsafe { &mut *engine };
    let artifact_slice = unsafe { std::slice::from_raw_parts(artifact_bytes, artifact_len) };

    // Deserialize + verify the artifact container.
    let artifact = match crate::artifact::unpack_artifact(artifact_slice) {
        Ok(a) => a,
        Err(_) => return FpgaResult::InvalidArtifact,
    };

    match engine.backend.load(&artifact) {
        Ok(model_id) => {
            // SAFETY: model IDs are 32 bytes; out buffer is caller-guaranteed.
            unsafe {
                ptr::copy_nonoverlapping(model_id.as_bytes().as_ptr(), model_id_out, 32);
            }
            FpgaResult::Ok
        }
        // NOTE(review): load failures (e.g. capacity) are also reported as
        // InvalidArtifact — consider a distinct code if callers need to tell
        // them apart.
        Err(_) => FpgaResult::InvalidArtifact,
    }
}
|
||||
|
||||
/// Run inference
///
/// Result must be freed with `fpga_result_free`
///
/// # Safety (caller contract)
/// `model_id` must point to 32 readable bytes; `tokens` to `tokens_len`
/// u16 values; `mask` to `mask_len` bytes; `engine` to a live handle.
/// The returned `logits`/`topk` buffers are owned by the caller.
#[no_mangle]
pub extern "C" fn fpga_infer(
    engine: *mut FpgaEngine,
    model_id: *const u8,
    tokens: *const u16,
    tokens_len: usize,
    mask: *const u8,
    mask_len: usize,
    coherence_score: i16,
    boundary_crossed: bool,
    max_compute_class: u8,
) -> FpgaInferenceResult {
    // Template for all failure paths: null buffers, gate_decision=2 (skipped).
    let error_result = || FpgaInferenceResult {
        status: FpgaResult::InvalidArgument,
        logits: ptr::null_mut(),
        logits_len: 0,
        topk: ptr::null_mut(),
        topk_len: 0,
        latency_ns: 0,
        cycles: 0,
        gate_decision: 2,
        exit_layer: 0,
    };

    if engine.is_null() || model_id.is_null() || tokens.is_null() || mask.is_null() {
        return error_result();
    }

    // SAFETY: null-checked above; lengths are the caller's contract.
    let engine = unsafe { &mut *engine };

    // Parse model ID (fixed 32 bytes).
    let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
    let mut id_bytes = [0u8; 32];
    id_bytes.copy_from_slice(id_slice);
    let model = ModelId::new(id_bytes);

    // Parse tokens and mask.
    let tokens_slice = unsafe { std::slice::from_raw_parts(tokens, tokens_len) };
    let mask_slice = unsafe { std::slice::from_raw_parts(mask, mask_len) };

    // Build shape (micro for C API).
    // NOTE(review): the C ABI hard-codes the micro shape; loading a model
    // with any other shape will fail the backend's shape check.
    let shape = FixedShape::micro();

    // Build gate hint; unknown class bytes fall back to Deliberative.
    let compute_class =
        ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
    let gate_hint = GateHint::new(coherence_score, boundary_crossed, compute_class);

    // Create request
    let req = InferenceRequest::new(model, shape, tokens_slice, mask_slice, gate_hint);

    // Run inference
    match engine.backend.infer(req) {
        Ok(result) => {
            // Allocate logits with checked allocation (prevents panic on overflow).
            // The layout here MUST match the dealloc in `fpga_result_free`.
            let logits_len = result.logits_q.len();
            let logits = if logits_len > 0 {
                match std::alloc::Layout::array::<i16>(logits_len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut i16 };
                        if !ptr.is_null() {
                            unsafe {
                                ptr::copy_nonoverlapping(result.logits_q.as_ptr(), ptr, logits_len);
                            }
                        }
                        ptr
                    }
                    _ => ptr::null_mut(), // Return null on allocation failure
                }
            } else {
                ptr::null_mut()
            };

            // Allocate top-K with checked allocation: buffer of 2 * k u32s,
            // interleaved (token, logit) pairs.
            let (topk, topk_len) = if let Some(ref tk) = result.topk {
                let len = tk.len() * 2; // (token, logit) pairs
                match std::alloc::Layout::array::<u32>(len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
                        if !ptr.is_null() {
                            for (i, (token, logit)) in tk.iter().enumerate() {
                                // SAFETY: i * 2 + 1 < len by construction.
                                unsafe {
                                    *ptr.add(i * 2) = *token as u32;
                                    *ptr.add(i * 2 + 1) = *logit as u32;
                                }
                            }
                        }
                        (ptr, tk.len())
                    }
                    _ => (ptr::null_mut(), 0), // Return null on allocation failure
                }
            } else {
                (ptr::null_mut(), 0)
            };

            // Encode gate decision into the flat (code, layer) C representation.
            let (gate_decision, exit_layer) = match result.witness.gate_decision {
                crate::types::GateDecision::RanFull => (0, 0),
                crate::types::GateDecision::EarlyExit { layer } => (1, layer),
                crate::types::GateDecision::Skipped { .. } => (2, 0),
            };

            FpgaInferenceResult {
                status: FpgaResult::Ok,
                logits,
                logits_len,
                topk,
                topk_len,
                latency_ns: result.witness.latency_ns,
                cycles: result.witness.cycles,
                gate_decision,
                exit_layer,
            }
        }
        Err(_) => {
            let mut result = error_result();
            result.status = FpgaResult::InferenceFailed;
            result
        }
    }
}
|
||||
|
||||
/// Free inference result
///
/// Releases the `logits` and `topk` buffers allocated by `fpga_infer` and
/// nulls the pointers so a double call is harmless. Safe to call with a
/// null `result`.
///
/// # Safety (caller contract)
/// `result` must come from `fpga_infer` with its lengths unmodified — the
/// dealloc layouts below must match the alloc layouts exactly.
#[no_mangle]
pub extern "C" fn fpga_result_free(result: *mut FpgaInferenceResult) {
    if result.is_null() {
        return;
    }

    // SAFETY: pointer validity and buffer provenance are the caller's
    // contract (see above); pointers are nulled after dealloc.
    unsafe {
        let r = &mut *result;

        if !r.logits.is_null() && r.logits_len > 0 {
            std::alloc::dealloc(
                r.logits as *mut u8,
                // Matches the Layout::array::<i16>(logits_len) used at alloc;
                // unwrap cannot fail for a layout that already allocated.
                std::alloc::Layout::array::<i16>(r.logits_len).unwrap(),
            );
            r.logits = ptr::null_mut();
        }

        if !r.topk.is_null() && r.topk_len > 0 {
            std::alloc::dealloc(
                r.topk as *mut u8,
                // topk buffer holds 2 * topk_len u32s (token, logit pairs).
                std::alloc::Layout::array::<u32>(r.topk_len * 2).unwrap(),
            );
            r.topk = ptr::null_mut();
        }
    }
}
|
||||
|
||||
/// Unload a model
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_unload(engine: *mut FpgaEngine, model_id: *const u8) -> FpgaResult {
|
||||
if engine.is_null() || model_id.is_null() {
|
||||
return FpgaResult::InvalidArgument;
|
||||
}
|
||||
|
||||
let engine = unsafe { &mut *engine };
|
||||
let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(id_slice);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
match engine.backend.unload(model) {
|
||||
Ok(()) => FpgaResult::Ok,
|
||||
Err(_) => FpgaResult::ModelNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a model is loaded
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_is_loaded(engine: *const FpgaEngine, model_id: *const u8) -> bool {
|
||||
if engine.is_null() || model_id.is_null() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let engine = unsafe { &*engine };
|
||||
let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(id_slice);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
engine.backend.is_loaded(model)
|
||||
}
|
||||
|
||||
/// Get version string
///
/// Returns a pointer to a static NUL-terminated string; it lives for the
/// whole program and must not be freed by the caller.
#[no_mangle]
pub extern "C" fn fpga_version() -> *const c_char {
    // NUL-terminated so C callers can treat it as a regular C string.
    const VERSION: &[u8] = b"0.1.0\0";
    VERSION.as_ptr() as *const c_char
}
|
||||
8
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/mod.rs
vendored
Normal file
8
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/mod.rs
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
//! Foreign function interfaces for FPGA Transformer
|
||||
//!
|
||||
//! Provides C ABI and WASM bindings.
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_bindgen;
|
||||
|
||||
pub mod c_abi;
|
||||
282
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
vendored
Normal file
282
vendor/ruvector/crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
vendored
Normal file
@@ -0,0 +1,282 @@
|
||||
//! WASM bindings via wasm-bindgen
|
||||
//!
|
||||
//! Provides the same API shape for browser and Node.js environments.
|
||||
|
||||
#![cfg(feature = "wasm")]
|
||||
|
||||
use js_sys::{Array, Int16Array, Object, Reflect, Uint16Array, Uint8Array};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::artifact::{unpack_artifact, ModelArtifact};
|
||||
use crate::backend::native_sim::{NativeSimBackend, NativeSimConfig};
|
||||
use crate::backend::TransformerBackend;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// WASM Engine for transformer inference
///
/// Wraps the pure-Rust simulator backend so the same inference API is
/// available from browser and Node.js environments.
#[wasm_bindgen]
pub struct WasmEngine {
    // Pure-Rust simulator backend; no FPGA hardware access from WASM.
    backend: NativeSimBackend,
    // IDs of artifacts loaded through this engine, in load order.
    loaded_models: Vec<ModelId>,
    // Witness of the most recent inference, if any (exposed via getWitness).
    last_witness: Option<crate::types::WitnessLog>,
}
|
||||
|
||||
#[wasm_bindgen]
impl WasmEngine {
    /// Create a new WASM engine
    ///
    /// Uses a permissive simulator configuration (up to 4 models,
    /// LUT-based softmax, no tracing, unlimited layers).
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        // Use permissive config for WASM
        let config = NativeSimConfig {
            max_models: 4,
            trace: false,
            lut_softmax: true,
            // 0 presumably means "no layer cap" — TODO confirm in NativeSimConfig.
            max_layers: 0,
        };

        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::with_config(gate, config);

        Self {
            backend,
            loaded_models: Vec::new(),
            last_witness: None,
        }
    }

    /// Load a model artifact from bytes
    ///
    /// Returns the model ID as a Uint8Array on success
    #[wasm_bindgen(js_name = loadArtifact)]
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<Uint8Array, JsValue> {
        let artifact = unpack_artifact(artifact_bytes)
            .map_err(|e| JsValue::from_str(&format!("Failed to unpack artifact: {}", e)))?;

        let model_id = self
            .backend
            .load(&artifact)
            .map_err(|e| JsValue::from_str(&format!("Failed to load model: {}", e)))?;

        self.loaded_models.push(model_id);

        // Return model ID as Uint8Array
        let id_array = Uint8Array::new_with_length(32);
        id_array.copy_from(model_id.as_bytes());
        Ok(id_array)
    }

    /// Run inference
    ///
    /// Returns an object with logits, topk, and witness
    ///
    /// `model_id` must be exactly 32 bytes; `tokens` must match the
    /// shape's sequence length. NOTE(review): `mask` length is not
    /// validated here — presumably checked downstream; confirm.
    #[wasm_bindgen]
    pub fn infer(
        &mut self,
        model_id: &[u8],
        tokens: &[u16],
        mask: &[u8],
        coherence_score_q: i16,
        boundary_crossed: bool,
        max_compute_class: u8,
    ) -> Result<JsValue, JsValue> {
        // Parse model ID
        if model_id.len() != 32 {
            return Err(JsValue::from_str("Model ID must be 32 bytes"));
        }
        let mut id_bytes = [0u8; 32];
        id_bytes.copy_from_slice(model_id);
        let model = ModelId::new(id_bytes);

        // Get shape from loaded model
        // For WASM, we use micro shape by default
        // NOTE(review): the shape is hard-coded regardless of what artifact
        // was actually loaded — models with a different shape will be
        // rejected by the length check below. Confirm this is intentional.
        let shape = FixedShape::micro();

        // Validate input length
        if tokens.len() != shape.seq_len as usize {
            return Err(JsValue::from_str(&format!(
                "Token length mismatch: expected {}, got {}",
                shape.seq_len,
                tokens.len()
            )));
        }

        // Build gate hint; unknown compute-class bytes fall back to Deliberative.
        let compute_class =
            ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
        let gate_hint = GateHint::new(coherence_score_q, boundary_crossed, compute_class);

        // Create request
        let req = InferenceRequest::new(model, shape, tokens, mask, gate_hint);

        // Run inference
        let result = self
            .backend
            .infer(req)
            .map_err(|e| JsValue::from_str(&format!("Inference failed: {}", e)))?;

        // Store witness so getWitness() can return it later.
        self.last_witness = Some(result.witness.clone());

        // Build result object
        let obj = Object::new();

        // Add logits (quantized i16 values, copied into a JS Int16Array)
        let logits = Int16Array::new_with_length(result.logits_q.len() as u32);
        logits.copy_from(&result.logits_q);
        Reflect::set(&obj, &"logits".into(), &logits)?;

        // Add top-K if available, as an array of [token, logit] pairs
        if let Some(topk) = &result.topk {
            let topk_array = Array::new();
            for (token, logit) in topk {
                let pair = Array::new();
                pair.push(&JsValue::from(*token));
                pair.push(&JsValue::from(*logit));
                topk_array.push(&pair);
            }
            Reflect::set(&obj, &"topk".into(), &topk_array)?;
        }

        // Add witness info (enums rendered via Debug formatting)
        let witness = Object::new();
        Reflect::set(
            &witness,
            &"backend".into(),
            &format!("{:?}", result.witness.backend).into(),
        )?;
        Reflect::set(
            &witness,
            &"cycles".into(),
            &JsValue::from(result.witness.cycles),
        )?;
        Reflect::set(
            &witness,
            &"latency_ns".into(),
            &JsValue::from(result.witness.latency_ns),
        )?;
        Reflect::set(
            &witness,
            &"gate_decision".into(),
            &format!("{:?}", result.witness.gate_decision).into(),
        )?;
        Reflect::set(&obj, &"witness".into(), &witness)?;

        Ok(obj.into())
    }

    /// Get the last witness log as JSON
    ///
    /// Returns a JSON string, or JS `null` if no inference has run yet.
    #[wasm_bindgen(js_name = getWitness)]
    pub fn get_witness(&self) -> Result<JsValue, JsValue> {
        match &self.last_witness {
            Some(witness) => {
                let json = serde_json::to_string(witness)
                    .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))?;
                Ok(JsValue::from_str(&json))
            }
            None => Ok(JsValue::NULL),
        }
    }

    /// Get list of loaded model IDs
    ///
    /// Returns a JS array of 32-byte Uint8Arrays, in load order.
    #[wasm_bindgen(js_name = getLoadedModels)]
    pub fn get_loaded_models(&self) -> Array {
        let arr = Array::new();
        for id in &self.loaded_models {
            let id_array = Uint8Array::new_with_length(32);
            id_array.copy_from(id.as_bytes());
            arr.push(&id_array);
        }
        arr
    }

    /// Unload a model
    ///
    /// `model_id` must be exactly 32 bytes. Removes the model from both
    /// the backend and this engine's bookkeeping list.
    #[wasm_bindgen]
    pub fn unload(&mut self, model_id: &[u8]) -> Result<(), JsValue> {
        if model_id.len() != 32 {
            return Err(JsValue::from_str("Model ID must be 32 bytes"));
        }
        let mut id_bytes = [0u8; 32];
        id_bytes.copy_from_slice(model_id);
        let model = ModelId::new(id_bytes);

        self.backend
            .unload(model)
            .map_err(|e| JsValue::from_str(&format!("Unload failed: {}", e)))?;

        self.loaded_models.retain(|id| *id != model);
        Ok(())
    }

    /// Get backend statistics
    ///
    /// Counters are widened to f64 so large values survive the trip
    /// into JS numbers.
    #[wasm_bindgen(js_name = getStats)]
    pub fn get_stats(&self) -> Result<JsValue, JsValue> {
        let stats = self.backend.stats();
        let obj = Object::new();

        Reflect::set(
            &obj,
            &"models_loaded".into(),
            &JsValue::from(stats.models_loaded as u32),
        )?;
        Reflect::set(
            &obj,
            &"total_inferences".into(),
            &JsValue::from(stats.total_inferences as f64),
        )?;
        Reflect::set(
            &obj,
            &"avg_latency_ns".into(),
            &JsValue::from(stats.avg_latency_ns as f64),
        )?;
        Reflect::set(
            &obj,
            &"early_exits".into(),
            &JsValue::from(stats.early_exits as f64),
        )?;
        Reflect::set(
            &obj,
            &"skipped".into(),
            &JsValue::from(stats.skipped as f64),
        )?;

        Ok(obj.into())
    }
}
|
||||
|
||||
impl Default for WasmEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to create a micro shape configuration
|
||||
#[wasm_bindgen(js_name = microShape)]
|
||||
pub fn micro_shape() -> Result<JsValue, JsValue> {
|
||||
let shape = FixedShape::micro();
|
||||
let obj = Object::new();
|
||||
|
||||
Reflect::set(&obj, &"seq_len".into(), &JsValue::from(shape.seq_len))?;
|
||||
Reflect::set(&obj, &"d_model".into(), &JsValue::from(shape.d_model))?;
|
||||
Reflect::set(&obj, &"heads".into(), &JsValue::from(shape.heads))?;
|
||||
Reflect::set(&obj, &"d_head".into(), &JsValue::from(shape.d_head))?;
|
||||
Reflect::set(&obj, &"vocab".into(), &JsValue::from(shape.vocab))?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Utility function to validate an artifact without loading
|
||||
#[wasm_bindgen(js_name = validateArtifact)]
|
||||
pub fn validate_artifact(artifact_bytes: &[u8]) -> Result<JsValue, JsValue> {
|
||||
let artifact = unpack_artifact(artifact_bytes)
|
||||
.map_err(|e| JsValue::from_str(&format!("Invalid artifact: {}", e)))?;
|
||||
|
||||
artifact
|
||||
.validate()
|
||||
.map_err(|e| JsValue::from_str(&format!("Validation failed: {}", e)))?;
|
||||
|
||||
let obj = Object::new();
|
||||
Reflect::set(&obj, &"name".into(), &artifact.manifest.name.into())?;
|
||||
Reflect::set(&obj, &"valid".into(), &JsValue::TRUE)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
301
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/coherence_gate.rs
vendored
Normal file
301
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/coherence_gate.rs
vendored
Normal file
@@ -0,0 +1,301 @@
|
||||
//! Coherence-based gating for inference control
|
||||
|
||||
use crate::types::{ComputeClass, GateDecision, GateHint, SkipReason};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Trait for coherence-based gating
///
/// Implementations decide, from coherence signals, whether inference
/// runs at all, whether it may exit early, and whether its result may
/// be written back to memory systems.
pub trait CoherenceGate: Send + Sync {
    /// Preflight check before inference
    ///
    /// Returns a gate decision based on coherence signals.
    fn preflight(&self, hint: &GateHint) -> GateDecision;

    /// Layer checkpoint for early exit decisions
    ///
    /// Called after each layer to determine if early exit is appropriate.
    /// Returns Some(decision) to exit early, None to continue.
    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision>;

    /// Check if write is allowed based on witness
    ///
    /// Used to gate state changes in memory systems.
    fn allow_write(&self, witness: &WitnessLog) -> bool;
}
|
||||
|
||||
/// Configuration for coherence gate
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CoherenceConfig {
|
||||
/// Minimum coherence score to run (Q8.8)
|
||||
pub min_coherence: i16,
|
||||
/// Coherence threshold for early exit
|
||||
pub early_exit_threshold: i16,
|
||||
/// Enable early exit
|
||||
pub early_exit_enabled: bool,
|
||||
/// Minimum layers before early exit
|
||||
pub min_layers: u8,
|
||||
/// Require stable coherence for writes
|
||||
pub require_stable_for_write: bool,
|
||||
/// Minimum coherence for writes (Q8.8)
|
||||
pub min_write_coherence: i16,
|
||||
}
|
||||
|
||||
impl Default for CoherenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_coherence: -256, // -1.0 in Q8.8, very permissive
|
||||
early_exit_threshold: 512, // 2.0 in Q8.8
|
||||
early_exit_enabled: true,
|
||||
min_layers: 2,
|
||||
require_stable_for_write: true,
|
||||
min_write_coherence: 0, // Require non-negative coherence
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceConfig {
|
||||
/// Create a strict configuration
|
||||
pub fn strict() -> Self {
|
||||
Self {
|
||||
min_coherence: 0,
|
||||
early_exit_threshold: 256,
|
||||
early_exit_enabled: true,
|
||||
min_layers: 4,
|
||||
require_stable_for_write: true,
|
||||
min_write_coherence: 128,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permissive configuration (always allows)
|
||||
pub fn permissive() -> Self {
|
||||
Self {
|
||||
min_coherence: i16::MIN,
|
||||
early_exit_threshold: i16::MAX,
|
||||
early_exit_enabled: false,
|
||||
min_layers: 0,
|
||||
require_stable_for_write: false,
|
||||
min_write_coherence: i16::MIN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default coherence gate implementation
|
||||
pub struct DefaultCoherenceGate {
|
||||
config: CoherenceConfig,
|
||||
}
|
||||
|
||||
impl DefaultCoherenceGate {
|
||||
/// Create with default config
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CoherenceConfig::default())
|
||||
}
|
||||
|
||||
/// Create with custom config
|
||||
pub fn with_config(config: CoherenceConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Check if compute class allows operation
|
||||
fn check_compute_class(&self, hint: &GateHint) -> bool {
|
||||
// Reflex class can always run (fast path)
|
||||
// Higher classes require sufficient coherence
|
||||
match hint.max_compute_class {
|
||||
ComputeClass::Reflex => true,
|
||||
ComputeClass::Associative => hint.coherence_score_q >= self.config.min_coherence / 2,
|
||||
ComputeClass::Deliberative => hint.coherence_score_q >= self.config.min_coherence,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultCoherenceGate {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for DefaultCoherenceGate {
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// Check minimum coherence
|
||||
if hint.coherence_score_q < self.config.min_coherence {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence,
|
||||
};
|
||||
}
|
||||
|
||||
// Check compute class restrictions
|
||||
if !self.check_compute_class(hint) {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::BudgetExceeded,
|
||||
};
|
||||
}
|
||||
|
||||
// Allow full inference
|
||||
GateDecision::RanFull
|
||||
}
|
||||
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
if !self.config.early_exit_enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
if layer < self.config.min_layers {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if coherence signal is high enough to exit early
|
||||
if signal_q >= self.config.early_exit_threshold {
|
||||
return Some(GateDecision::EarlyExit { layer });
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Skip writes if inference was skipped
|
||||
if !witness.gate_decision.did_run() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we require stable coherence, only allow writes after full run
|
||||
if self.config.require_stable_for_write {
|
||||
matches!(witness.gate_decision, GateDecision::RanFull)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mincut-aware coherence gate
///
/// Uses mincut signals to make more informed gating decisions.
pub struct MincutCoherenceGate {
    // Underlying default gate providing the baseline decisions.
    base: DefaultCoherenceGate,
    /// Minimum lambda (mincut value) for inference
    /// NOTE(review): this field is never read in this file's preflight/
    /// checkpoint logic — confirm whether it is consumed elsewhere or
    /// the check was left unimplemented.
    pub min_lambda: i16,
    /// Lambda threshold for early exit
    pub lambda_exit_threshold: i16,
}
|
||||
|
||||
impl MincutCoherenceGate {
|
||||
/// Create a new mincut-aware gate
|
||||
pub fn new(config: CoherenceConfig, min_lambda: i16, lambda_exit_threshold: i16) -> Self {
|
||||
Self {
|
||||
base: DefaultCoherenceGate::with_config(config),
|
||||
min_lambda,
|
||||
lambda_exit_threshold,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for MincutCoherenceGate {
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// Use base coherence check
|
||||
let base_decision = self.base.preflight(hint);
|
||||
if !base_decision.did_run() {
|
||||
return base_decision;
|
||||
}
|
||||
|
||||
// Additional mincut check
|
||||
// If boundary was crossed and coherence is low, skip
|
||||
if hint.boundary_crossed && hint.coherence_score_q < 0 {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence,
|
||||
};
|
||||
}
|
||||
|
||||
GateDecision::RanFull
|
||||
}
|
||||
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
// Use base checkpoint with mincut-adjusted threshold
|
||||
let adjusted_threshold = if signal_q > self.lambda_exit_threshold {
|
||||
// High lambda suggests stable state, lower exit threshold
|
||||
self.base.config.early_exit_threshold / 2
|
||||
} else {
|
||||
self.base.config.early_exit_threshold
|
||||
};
|
||||
|
||||
if layer >= self.base.config.min_layers && signal_q >= adjusted_threshold {
|
||||
return Some(GateDecision::EarlyExit { layer });
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Use base write check
|
||||
self.base.allow_write(witness)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Default config under test: min_coherence = -256,
    // early_exit_threshold = 512, min_layers = 2.

    #[test]
    fn test_default_gate_preflight() {
        let gate = DefaultCoherenceGate::new();

        // High coherence should pass
        let hint = GateHint::new(256, false, ComputeClass::Deliberative);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));

        // Low coherence should fail (below -256 floor)
        let hint = GateHint::new(-512, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }

    #[test]
    fn test_early_exit_checkpoint() {
        let gate = DefaultCoherenceGate::new();

        // Layer 0 - too early (min_layers is 2)
        assert!(gate.checkpoint(0, 1000).is_none());

        // Layer 4 with high signal - should exit (threshold is 512)
        let decision = gate.checkpoint(4, 1000);
        assert!(matches!(
            decision,
            Some(GateDecision::EarlyExit { layer: 4 })
        ));
    }

    #[test]
    fn test_write_gating() {
        let gate = DefaultCoherenceGate::new();

        // Full run should allow writes
        // (presumes WitnessLog::empty() carries a RanFull decision —
        //  confirm in the witness module)
        let witness = crate::witness::WitnessLog::empty();
        assert!(gate.allow_write(&witness));

        // Skipped should not allow writes
        let mut skipped_witness = crate::witness::WitnessLog::empty();
        skipped_witness.gate_decision = GateDecision::Skipped {
            reason: SkipReason::LowCoherence,
        };
        assert!(!gate.allow_write(&skipped_witness));
    }

    #[test]
    fn test_strict_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::strict());

        // Strict should require positive coherence (min_coherence is 0)
        let hint = GateHint::new(-1, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }

    #[test]
    fn test_permissive_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::permissive());

        // Permissive should allow anything, even the worst-case hint
        let hint = GateHint::new(i16::MIN, true, ComputeClass::Reflex);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
    }
}
|
||||
95
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/mod.rs
vendored
Normal file
95
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/mod.rs
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
//! Gating subsystem for coherence-based inference control
|
||||
//!
|
||||
//! Provides preflight and postflight gates that integrate mincut signals
|
||||
//! and write policies for memory safety.
|
||||
|
||||
pub mod coherence_gate;
|
||||
pub mod policy_gate;
|
||||
|
||||
pub use coherence_gate::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate};
|
||||
pub use policy_gate::{DefaultPolicyGate, PolicyGate, WritePolicy};
|
||||
|
||||
use crate::types::{GateDecision, GateHint, SkipReason};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Combined gate that checks both coherence and policy
///
/// Policy is consulted first during preflight; writes must satisfy
/// both gates.
pub struct CombinedGate {
    // Coherence-signal gate (preflight, checkpoints, write gating).
    coherence: Box<dyn CoherenceGate>,
    // Policy gate (budget/rate limits, write policy).
    policy: Box<dyn PolicyGate>,
}
|
||||
|
||||
impl CombinedGate {
|
||||
/// Create a new combined gate
|
||||
pub fn new(coherence: Box<dyn CoherenceGate>, policy: Box<dyn PolicyGate>) -> Self {
|
||||
Self { coherence, policy }
|
||||
}
|
||||
|
||||
/// Create with default implementations
|
||||
pub fn default_gates() -> Self {
|
||||
Self::new(
|
||||
Box::new(DefaultCoherenceGate::new()),
|
||||
Box::new(DefaultPolicyGate::new()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Preflight check before inference
|
||||
pub fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// First check policy
|
||||
if !self.policy.allow_inference(hint) {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::PolicyDenied,
|
||||
};
|
||||
}
|
||||
|
||||
// Then check coherence
|
||||
self.coherence.preflight(hint)
|
||||
}
|
||||
|
||||
/// Checkpoint during inference
|
||||
pub fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
self.coherence.checkpoint(layer, signal_q)
|
||||
}
|
||||
|
||||
/// Check if write is allowed after inference
|
||||
pub fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
self.coherence.allow_write(witness) && self.policy.allow_write(witness)
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for CombinedGate {
    // Trait impl delegating to CombinedGate's inherent methods of the
    // same names. The fully qualified `CombinedGate::` paths resolve to
    // the inherent methods (inherent impls take precedence in path
    // resolution), so these calls do not recurse into themselves.
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        CombinedGate::preflight(self, hint)
    }

    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        CombinedGate::checkpoint(self, layer, signal_q)
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        CombinedGate::allow_write(self, witness)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combined_gate_preflight() {
        let gate = CombinedGate::default_gates();

        // Allow-all hint should pass both the policy and coherence gates
        let decision = gate.preflight(&GateHint::allow_all());
        assert!(matches!(decision, GateDecision::RanFull));
    }

    #[test]
    fn test_combined_gate_low_coherence() {
        let gate = CombinedGate::default_gates();

        // Very low coherence should skip (below the default -256 floor)
        let hint = GateHint::new(-1000, false, crate::types::ComputeClass::Reflex);
        let decision = gate.preflight(&hint);
        assert!(matches!(decision, GateDecision::Skipped { .. }));
    }
}
|
||||
305
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/policy_gate.rs
vendored
Normal file
305
vendor/ruvector/crates/ruvector-fpga-transformer/src/gating/policy_gate.rs
vendored
Normal file
@@ -0,0 +1,305 @@
|
||||
//! Policy-based gating for access control and resource management
|
||||
|
||||
use crate::types::{ComputeClass, GateHint};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Trait for policy-based gating
///
/// Covers admission control (budgets, rate limits) and write policy,
/// independent of coherence signals.
pub trait PolicyGate: Send + Sync {
    /// Check if inference is allowed
    fn allow_inference(&self, hint: &GateHint) -> bool;

    /// Check if write is allowed after inference
    fn allow_write(&self, witness: &WitnessLog) -> bool;

    /// Get remaining compute budget
    ///
    /// `None` means no budget is configured (unlimited).
    fn remaining_budget(&self) -> Option<u64>;

    /// Record compute usage
    fn record_usage(&self, cycles: u32);
}
|
||||
|
||||
/// Write policy configuration
#[derive(Debug, Clone)]
pub struct WritePolicy {
    /// Allow writes after early exit
    pub allow_early_exit_writes: bool,
    /// Maximum latency (ns) for write eligibility
    /// (compared against the witness's recorded latency)
    pub max_latency_ns: u32,
    /// Require specific backend
    /// (`None` means any backend qualifies)
    pub required_backend: Option<crate::types::BackendKind>,
    /// Minimum compute class for writes
    /// NOTE(review): not consulted by DefaultPolicyGate::allow_write in
    /// this file — confirm where this field is enforced.
    pub min_compute_class: ComputeClass,
}
|
||||
|
||||
impl Default for WritePolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: false,
|
||||
max_latency_ns: u32::MAX,
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WritePolicy {
|
||||
/// Create a strict write policy
|
||||
pub fn strict() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: false,
|
||||
max_latency_ns: 10_000_000, // 10ms
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Deliberative,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permissive write policy
|
||||
pub fn permissive() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: true,
|
||||
max_latency_ns: u32::MAX,
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
|
||||
/// Require FPGA backend for writes
|
||||
pub fn require_fpga(mut self) -> Self {
|
||||
self.required_backend = Some(crate::types::BackendKind::FpgaPcie);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Default policy gate implementation
///
/// Combines a write policy with an optional cycle budget tracked via
/// atomics so the gate can be shared across threads.
pub struct DefaultPolicyGate {
    // Policy applied in allow_write.
    write_policy: WritePolicy,
    /// Compute budget (total cycles allowed, 0 = unlimited)
    budget_cycles: std::sync::atomic::AtomicU64,
    /// Used cycles
    used_cycles: std::sync::atomic::AtomicU64,
}
|
||||
|
||||
impl DefaultPolicyGate {
|
||||
/// Create with default policy
|
||||
pub fn new() -> Self {
|
||||
Self::with_policy(WritePolicy::default())
|
||||
}
|
||||
|
||||
/// Create with custom write policy
|
||||
pub fn with_policy(write_policy: WritePolicy) -> Self {
|
||||
Self {
|
||||
write_policy,
|
||||
budget_cycles: std::sync::atomic::AtomicU64::new(0),
|
||||
used_cycles: std::sync::atomic::AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set compute budget
|
||||
pub fn set_budget(&self, cycles: u64) {
|
||||
self.budget_cycles
|
||||
.store(cycles, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Reset used cycles
|
||||
pub fn reset_usage(&self) {
|
||||
self.used_cycles
|
||||
.store(0, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultPolicyGate {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyGate for DefaultPolicyGate {
|
||||
fn allow_inference(&self, hint: &GateHint) -> bool {
|
||||
// Check compute budget
|
||||
let budget = self.budget_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if budget > 0 {
|
||||
let used = self.used_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if used >= budget {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check compute class restrictions
|
||||
// Always allow reflex, check others based on config
|
||||
hint.max_compute_class >= ComputeClass::Reflex
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Check if inference ran
|
||||
if !witness.gate_decision.did_run() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check early exit policy
|
||||
if !self.write_policy.allow_early_exit_writes {
|
||||
if let crate::types::GateDecision::EarlyExit { .. } = witness.gate_decision {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check latency
|
||||
if witness.latency_ns > self.write_policy.max_latency_ns {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check backend requirement
|
||||
if let Some(required) = self.write_policy.required_backend {
|
||||
if witness.backend != required {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn remaining_budget(&self) -> Option<u64> {
|
||||
let budget = self.budget_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if budget == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let used = self.used_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
Some(budget.saturating_sub(used))
|
||||
}
|
||||
|
||||
fn record_usage(&self, cycles: u32) {
|
||||
self.used_cycles
|
||||
.fetch_add(cycles as u64, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate-limited policy gate
///
/// Wraps [`DefaultPolicyGate`] with a fixed one-second-window rate
/// limit on inference admissions.
pub struct RateLimitedPolicyGate {
    // Underlying budget/write-policy gate.
    base: DefaultPolicyGate,
    /// Maximum inferences per second
    max_inferences_per_sec: u32,
    /// Inference count in current window
    inference_count: std::sync::atomic::AtomicU32,
    /// Window start time
    window_start: std::sync::RwLock<std::time::Instant>,
}
|
||||
|
||||
impl RateLimitedPolicyGate {
|
||||
/// Create with rate limit
|
||||
pub fn new(max_inferences_per_sec: u32, write_policy: WritePolicy) -> Self {
|
||||
Self {
|
||||
base: DefaultPolicyGate::with_policy(write_policy),
|
||||
max_inferences_per_sec,
|
||||
inference_count: std::sync::atomic::AtomicU32::new(0),
|
||||
window_start: std::sync::RwLock::new(std::time::Instant::now()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check and update rate limit
|
||||
fn check_rate_limit(&self) -> bool {
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// Check if we need to reset the window
|
||||
{
|
||||
let window_start = self.window_start.read().unwrap();
|
||||
if now.duration_since(*window_start).as_secs() >= 1 {
|
||||
drop(window_start);
|
||||
let mut window_start = self.window_start.write().unwrap();
|
||||
*window_start = now;
|
||||
self.inference_count
|
||||
.store(0, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
// Check count
|
||||
let count = self
|
||||
.inference_count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
count < self.max_inferences_per_sec
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyGate for RateLimitedPolicyGate {
|
||||
fn allow_inference(&self, hint: &GateHint) -> bool {
|
||||
// Check rate limit first
|
||||
if !self.check_rate_limit() {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.base.allow_inference(hint)
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
self.base.allow_write(witness)
|
||||
}
|
||||
|
||||
fn remaining_budget(&self) -> Option<u64> {
|
||||
self.base.remaining_budget()
|
||||
}
|
||||
|
||||
fn record_usage(&self, cycles: u32) {
|
||||
self.base.record_usage(cycles);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy_allows_inference() {
        // No budget configured: everything is admitted.
        let gate = DefaultPolicyGate::new();
        let hint = GateHint::allow_all();
        assert!(gate.allow_inference(&hint));
    }

    #[test]
    fn test_budget_limiting() {
        let gate = DefaultPolicyGate::new();
        gate.set_budget(1000);

        let hint = GateHint::allow_all();

        // Should allow initially
        assert!(gate.allow_inference(&hint));

        // Record usage exceeding budget
        gate.record_usage(1500);

        // Should deny now
        assert!(!gate.allow_inference(&hint));

        // Reset and check again
        gate.reset_usage();
        assert!(gate.allow_inference(&hint));
    }

    #[test]
    fn test_write_policy_early_exit() {
        let gate = DefaultPolicyGate::with_policy(WritePolicy::default());

        let mut witness = crate::witness::WitnessLog::empty();
        witness.gate_decision = crate::types::GateDecision::EarlyExit { layer: 3 };

        // Default policy denies early exit writes
        assert!(!gate.allow_write(&witness));

        // Permissive policy allows
        let permissive = DefaultPolicyGate::with_policy(WritePolicy::permissive());
        assert!(permissive.allow_write(&witness));
    }

    #[test]
    fn test_write_policy_latency() {
        // Latency at or below the ceiling is writable; above it is not.
        let mut policy = WritePolicy::default();
        policy.max_latency_ns = 1000;
        let gate = DefaultPolicyGate::with_policy(policy);

        let mut witness = crate::witness::WitnessLog::empty();
        witness.latency_ns = 500;
        assert!(gate.allow_write(&witness));

        witness.latency_ns = 2000;
        assert!(!gate.allow_write(&witness));
    }
}
|
||||
331
vendor/ruvector/crates/ruvector-fpga-transformer/src/lib.rs
vendored
Normal file
331
vendor/ruvector/crates/ruvector-fpga-transformer/src/lib.rs
vendored
Normal file
@@ -0,0 +1,331 @@
|
||||
//! # FPGA Transformer Backend
|
||||
//!
|
||||
//! Ultra low latency transformer inference with FPGA acceleration,
|
||||
//! coherence gating, and deterministic execution.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Deterministic latency paths**: Fixed shape inference with bounded timing
|
||||
//! - **Quantization first design**: Explicit INT4/INT8 quantization with reproducible math
|
||||
//! - **Zero allocation hot path**: No heap allocations during inference
|
||||
//! - **Coherence gating**: Mincut-integrated gate decisions
|
||||
//! - **Multiple backends**: FPGA PCIe, FPGA Daemon, Native Sim, WASM Sim
|
||||
//! - **Witness logging**: Auditable inference with ReasoningBank integration
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use ruvector_fpga_transformer::{Engine, artifact::ModelArtifact};
|
||||
//! use ruvector_fpga_transformer::backend::native_sim::NativeSimBackend;
|
||||
//! use ruvector_fpga_transformer::gating::DefaultCoherenceGate;
|
||||
//! use ruvector_fpga_transformer::types::{InferenceRequest, GateHint, FixedShape};
|
||||
//! use std::sync::Arc;
|
||||
//!
|
||||
//! // Create backend and gate
|
||||
//! let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
//! let backend = NativeSimBackend::new(gate.clone());
|
||||
//!
|
||||
//! // Create engine
|
||||
//! let mut engine = Engine::new(Box::new(backend), gate);
|
||||
//!
|
||||
//! // Load artifact (from file or bytes)
|
||||
//! // let model_id = engine.load_artifact(&artifact_bytes)?;
|
||||
//!
|
||||
//! // Run inference
|
||||
//! // let result = engine.infer(request)?;
|
||||
//! ```
|
||||
//!
|
||||
//! ## Backend Selection
|
||||
//!
|
||||
//! The crate supports multiple backends selected at runtime:
|
||||
//!
|
||||
//! - `FpgaPcie`: Direct PCIe access to FPGA (requires `pcie` feature)
|
||||
//! - `FpgaDaemon`: Communication via local daemon (requires `daemon` feature)
|
||||
//! - `NativeSim`: Pure Rust simulator (requires `native_sim` feature)
|
||||
//! - `WasmSim`: WASM-compatible simulator (requires `wasm` feature)
|
||||
//!
|
||||
//! ## Artifact Format
|
||||
//!
|
||||
//! Models are packaged as signed artifacts containing:
|
||||
//! - Manifest with shape and quantization metadata
|
||||
//! - Quantized weights
|
||||
//! - Optional FPGA bitstream
|
||||
//! - Test vectors for validation
|
||||
//! - Ed25519 signature
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(feature = "wasm", allow(unused_imports))]
|
||||
|
||||
pub mod artifact;
|
||||
pub mod backend;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
pub mod gating;
|
||||
pub mod quant;
|
||||
pub mod types;
|
||||
pub mod witness;
|
||||
|
||||
pub use artifact::ModelArtifact;
|
||||
pub use backend::TransformerBackend;
|
||||
pub use error::{Error, Result};
|
||||
pub use gating::CoherenceGate;
|
||||
pub use types::{
|
||||
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
|
||||
InferenceResult, Layout, ModelId, QuantSpec, QuantizedTensor, SkipReason, WitnessLog,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Crate version
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Main engine for FPGA transformer inference
///
/// The engine combines a backend (FPGA, simulator, etc.) with a coherence gate
/// for controlled inference execution. It tracks loaded models and maintains
/// aggregate inference statistics (see [`EngineStats`]).
pub struct Engine {
    /// Backend for inference execution
    backend: Box<dyn TransformerBackend>,
    /// Coherence gate for decision making
    gate: Arc<dyn CoherenceGate>,
    /// Loaded models, keyed by the id the backend assigned at load time
    models: std::collections::HashMap<ModelId, ModelInfo>,
    /// Inference statistics
    stats: EngineStats,
}

/// Information about a loaded model
///
/// Cached at load time from the artifact's manifest so shape/quant queries
/// don't need to go back to the backend.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model artifact
    pub artifact: ModelArtifact,
    /// Shape configuration
    pub shape: FixedShape,
    /// Quantization spec
    pub quant: QuantSpec,
}

/// Engine statistics
///
/// Counters are monotonically increasing until [`Engine::reset_stats`] is
/// called. `skipped` counts both preflight-gate rejections and backend-side
/// skips; `early_exits` is a subset of `successful`.
#[derive(Debug, Default, Clone)]
pub struct EngineStats {
    /// Total inferences
    pub total_inferences: u64,
    /// Successful inferences
    pub successful: u64,
    /// Skipped inferences
    pub skipped: u64,
    /// Early exits
    pub early_exits: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
}
|
||||
|
||||
impl Engine {
    /// Create a new engine with the specified backend and gate
    pub fn new(backend: Box<dyn TransformerBackend>, gate: Arc<dyn CoherenceGate>) -> Self {
        Self {
            backend,
            gate,
            models: std::collections::HashMap::new(),
            stats: EngineStats::default(),
        }
    }

    /// Create with default native simulator backend
    ///
    /// Convenience constructor using `DefaultCoherenceGate` and the pure-Rust
    /// simulator; only available with the `native_sim` feature.
    #[cfg(feature = "native_sim")]
    pub fn native_sim() -> Self {
        let gate = Arc::new(gating::DefaultCoherenceGate::new());
        let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
        Self::new(backend, gate)
    }

    /// Load a model artifact from bytes
    ///
    /// Unpacks the serialized artifact, then delegates to [`Engine::load`].
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<ModelId> {
        let artifact = artifact::unpack_artifact(artifact_bytes)?;
        self.load(&artifact)
    }

    /// Load a model artifact
    ///
    /// Validates the artifact, hands it to the backend, and caches its
    /// manifest metadata for later `shape`/`model_info` queries.
    pub fn load(&mut self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact
        artifact.validate()?;

        // Load into backend
        let model_id = self.backend.load(artifact)?;

        // Store info
        self.models.insert(
            model_id,
            ModelInfo {
                artifact: artifact.clone(),
                shape: artifact.manifest.shape,
                quant: artifact.manifest.quant,
            },
        );

        Ok(model_id)
    }

    /// Run inference
    ///
    /// Consults the coherence gate first; a preflight `Skipped` decision is
    /// surfaced as `Error::GateBlocked` (and counted as skipped). Otherwise
    /// the backend runs and stats are updated from the returned witness.
    pub fn infer(&mut self, req: InferenceRequest) -> Result<InferenceResult> {
        self.stats.total_inferences += 1;

        // Check preflight gate
        let preflight = self.gate.preflight(&req.gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            self.stats.skipped += 1;
            return Err(Error::GateBlocked { reason });
        }

        // Run inference
        let result = self.backend.infer(req)?;

        // Update stats
        // NOTE(review): latency is accumulated even when the backend reports
        // Skipped, but avg_latency_ns divides by `successful` only — confirm
        // this is intended.
        self.stats.total_latency_ns += result.witness.latency_ns as u64;
        match result.witness.gate_decision {
            GateDecision::RanFull => self.stats.successful += 1,
            GateDecision::EarlyExit { .. } => {
                self.stats.successful += 1;
                self.stats.early_exits += 1;
            }
            GateDecision::Skipped { .. } => self.stats.skipped += 1,
        }

        Ok(result)
    }

    /// Unload a model
    ///
    /// Backend unload happens first; the local cache entry is then dropped.
    pub fn unload(&mut self, model: ModelId) -> Result<()> {
        self.backend.unload(model)?;
        self.models.remove(&model);
        Ok(())
    }

    /// Get model shape
    ///
    /// Errors with `ModelNotFound` if the model was never loaded (or was
    /// unloaded).
    pub fn shape(&self, model: ModelId) -> Result<FixedShape> {
        self.models
            .get(&model)
            .map(|info| info.shape)
            .ok_or_else(|| Error::ModelNotFound(model))
    }

    /// Get model info
    pub fn model_info(&self, model: ModelId) -> Option<&ModelInfo> {
        self.models.get(&model)
    }

    /// Check if model is loaded
    pub fn is_loaded(&self, model: ModelId) -> bool {
        self.models.contains_key(&model)
    }

    /// Get list of loaded models
    pub fn loaded_models(&self) -> Vec<ModelId> {
        self.models.keys().copied().collect()
    }

    /// Get engine statistics
    pub fn stats(&self) -> &EngineStats {
        &self.stats
    }

    /// Get backend statistics
    pub fn backend_stats(&self) -> backend::BackendStats {
        self.backend.stats()
    }

    /// Get backend kind
    pub fn backend_kind(&self) -> BackendKind {
        self.backend.kind()
    }

    /// Check if write is allowed based on witness
    ///
    /// Delegates to the coherence gate's write policy.
    pub fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.gate.allow_write(witness)
    }

    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.stats = EngineStats::default();
    }
}
|
||||
|
||||
impl EngineStats {
|
||||
/// Get average latency in nanoseconds
|
||||
pub fn avg_latency_ns(&self) -> f64 {
|
||||
if self.successful == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_latency_ns as f64 / self.successful as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average latency in milliseconds
|
||||
pub fn avg_latency_ms(&self) -> f64 {
|
||||
self.avg_latency_ns() / 1_000_000.0
|
||||
}
|
||||
|
||||
/// Get success rate
|
||||
pub fn success_rate(&self) -> f64 {
|
||||
if self.total_inferences == 0 {
|
||||
1.0
|
||||
} else {
|
||||
self.successful as f64 / self.total_inferences as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get early exit rate
|
||||
pub fn early_exit_rate(&self) -> f64 {
|
||||
if self.successful == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.early_exits as f64 / self.successful as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prelude for convenient imports
///
/// `use ruvector_fpga_transformer::prelude::*;` brings the engine, error
/// types, core traits, and the commonly used request/result types into scope.
pub mod prelude {
    pub use crate::{
        artifact::ModelArtifact,
        backend::TransformerBackend,
        gating::CoherenceGate,
        types::{
            BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
            InferenceResult, ModelId, QuantSpec, SkipReason, WitnessLog,
        },
        Engine, Error, Result,
    };
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: a fresh engine (native_sim builds only) starts with no
    // loaded models.
    #[test]
    fn test_engine_creation() {
        let gate = Arc::new(gating::DefaultCoherenceGate::new());

        #[cfg(feature = "native_sim")]
        {
            let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
            let engine = Engine::new(backend, gate);
            assert!(engine.loaded_models().is_empty());
        }
    }

    // Derived-rate arithmetic: 80/100 success, 20/80 early exits,
    // 8_000_000 ns / 80 successes = 100_000 ns average.
    #[test]
    fn test_engine_stats() {
        let stats = EngineStats {
            total_inferences: 100,
            successful: 80,
            skipped: 10,
            early_exits: 20,
            total_latency_ns: 8_000_000,
        };

        assert!((stats.success_rate() - 0.8).abs() < 0.01);
        assert!((stats.early_exit_rate() - 0.25).abs() < 0.01);
        assert!((stats.avg_latency_ns() - 100_000.0).abs() < 1.0);
    }
}
|
||||
303
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/calib.rs
vendored
Normal file
303
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/calib.rs
vendored
Normal file
@@ -0,0 +1,303 @@
|
||||
//! Calibration data for quantization
|
||||
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Calibration data for a model
///
/// Aggregated activation/weight statistics gathered by running representative
/// inputs through the model; consumed when choosing quantization scales.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationData {
    /// Layer-wise activation statistics
    pub layers: Vec<LayerCalibration>,
    /// Global input statistics
    pub input_stats: ActivationStats,
    /// Number of calibration samples used
    pub num_samples: usize,
    /// Calibration method used
    pub method: CalibrationMethod,
}

/// Per-layer calibration data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerCalibration {
    /// Layer index
    pub layer_idx: usize,
    /// Layer name
    pub name: String,
    /// Activation statistics after this layer
    pub activation_stats: ActivationStats,
    /// Weight statistics for this layer
    pub weight_stats: WeightStats,
    /// Optimal scale for activations (Q16.16 fixed point)
    pub act_scale: i32,
    /// Optimal scale for weights (Q16.16 fixed point)
    pub weight_scale: i32,
}
|
||||
|
||||
/// Activation statistics
///
/// All fields default to zero; `(min, max) == (0.0, 0.0)` is treated by
/// `update` as the "no data yet" state.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ActivationStats {
    /// Minimum value seen
    pub min: f32,
    /// Maximum value seen
    pub max: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Histogram bins (for entropy calibration)
    #[serde(default)]
    pub histogram: Vec<u32>,
    /// Histogram bin edges (expected length: histogram.len() + 1)
    #[serde(default)]
    pub bin_edges: Vec<f32>,
}
|
||||
|
||||
impl ActivationStats {
|
||||
/// Create empty stats
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Update stats with a batch of values
|
||||
pub fn update(&mut self, values: &[f32]) {
|
||||
if values.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
for &v in values {
|
||||
if v < self.min || self.min == 0.0 {
|
||||
self.min = v;
|
||||
}
|
||||
if v > self.max {
|
||||
self.max = v;
|
||||
}
|
||||
}
|
||||
|
||||
// Update running mean and std
|
||||
let n = values.len() as f32;
|
||||
let batch_mean = values.iter().sum::<f32>() / n;
|
||||
let batch_var = values.iter().map(|v| (v - batch_mean).powi(2)).sum::<f32>() / n;
|
||||
|
||||
// Simple update (not online algorithm)
|
||||
self.mean = batch_mean;
|
||||
self.std = batch_var.sqrt();
|
||||
}
|
||||
|
||||
/// Compute optimal scale for symmetric quantization to n bits
|
||||
pub fn optimal_scale(&self, bits: u8) -> f32 {
|
||||
let max_range = self.max.abs().max(self.min.abs());
|
||||
let qmax = (1 << (bits - 1)) as f32 - 1.0;
|
||||
max_range / qmax
|
||||
}
|
||||
}
|
||||
|
||||
/// Weight statistics
///
/// Summary of a layer's weight tensor used during calibration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WeightStats {
    /// Min weight value
    pub min: f32,
    /// Max weight value
    pub max: f32,
    /// Sparsity (fraction of near-zero weights, |w| < 1e-6)
    pub sparsity: f32,
}
|
||||
|
||||
impl WeightStats {
|
||||
/// Compute from weight tensor
|
||||
pub fn from_weights(weights: &[f32]) -> Self {
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
let mut zeros = 0usize;
|
||||
|
||||
for &w in weights {
|
||||
if w < min {
|
||||
min = w;
|
||||
}
|
||||
if w > max {
|
||||
max = w;
|
||||
}
|
||||
if w.abs() < 1e-6 {
|
||||
zeros += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
min,
|
||||
max,
|
||||
sparsity: zeros as f32 / weights.len() as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calibration method
///
/// Strategy used to pick quantization ranges from observed statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationMethod {
    /// Use min/max of observed values
    MinMax,
    /// Use percentile clipping; value is permille, e.g. 999 = 99.9%
    Percentile(u32), // 999 = 99.9%
    /// Entropy-based calibration (KL divergence)
    Entropy,
    /// Mean-squared error minimization
    Mse,
}

impl Default for CalibrationMethod {
    // MinMax is the simplest, assumption-free default.
    fn default() -> Self {
        Self::MinMax
    }
}
|
||||
|
||||
impl CalibrationData {
    /// Create empty calibration data for the given method.
    pub fn new(method: CalibrationMethod) -> Self {
        Self {
            layers: Vec::new(),
            input_stats: ActivationStats::new(),
            num_samples: 0,
            method,
        }
    }

    /// Append one layer's calibration record.
    pub fn add_layer(&mut self, calib: LayerCalibration) {
        self.layers.push(calib);
    }

    /// Serialize to bytes (JSON via serde_json).
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        Ok(serde_json::to_vec(self)?)
    }

    /// Deserialize from bytes previously produced by [`Self::to_bytes`].
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        Ok(serde_json::from_slice(data)?)
    }
}
|
||||
|
||||
/// Calibrate a model by collecting activation statistics
///
/// Runs `run_inference` over each calibration input, accumulates per-layer
/// activation statistics, then derives a Q16.16 activation scale per layer.
///
/// `run_inference` must return one `Vec<f32>` of activations per layer;
/// layers beyond `num_layers` are ignored. Weight stats are left at their
/// defaults and weight_scale at 65536 (1.0 in Q16.16).
///
/// NOTE(review): for `Percentile(_)` the scale is simply the MinMax scale
/// times a fixed 0.99 — the permille argument is not used here; confirm this
/// placeholder is intentional.
pub fn calibrate_model<F>(
    run_inference: F,
    calibration_inputs: &[Vec<u16>],
    num_layers: usize,
    method: CalibrationMethod,
) -> Result<CalibrationData>
where
    F: Fn(&[u16]) -> Result<Vec<Vec<f32>>>, // Returns activations per layer
{
    let mut calibration = CalibrationData::new(method);

    // Initialize layer stats
    let mut layer_stats: Vec<ActivationStats> =
        (0..num_layers).map(|_| ActivationStats::new()).collect();

    // Run calibration passes
    for input in calibration_inputs {
        // Run inference and collect activations
        let activations = run_inference(input)?;

        // Update statistics
        for (layer_idx, layer_act) in activations.iter().enumerate() {
            if layer_idx < num_layers {
                layer_stats[layer_idx].update(layer_act);
            }
        }

        calibration.num_samples += 1;
    }

    // Create layer calibrations
    for (layer_idx, stats) in layer_stats.into_iter().enumerate() {
        // All methods currently target 8-bit symmetric quantization.
        let act_scale = match method {
            CalibrationMethod::MinMax => stats.optimal_scale(8),
            CalibrationMethod::Percentile(_) => stats.optimal_scale(8) * 0.99,
            CalibrationMethod::Entropy => stats.optimal_scale(8),
            CalibrationMethod::Mse => stats.optimal_scale(8),
        };

        calibration.add_layer(LayerCalibration {
            layer_idx,
            name: format!("layer_{}", layer_idx),
            activation_stats: stats,
            weight_stats: WeightStats::default(),
            act_scale: (act_scale * 65536.0) as i32, // float scale -> Q16.16
            weight_scale: 65536,                     // Default 1.0
        });
    }

    Ok(calibration)
}
|
||||
|
||||
/// Apply percentile clipping to calibration
|
||||
pub fn apply_percentile(stats: &ActivationStats, percentile: f32) -> (f32, f32) {
|
||||
if stats.histogram.is_empty() || stats.bin_edges.len() < 2 {
|
||||
return (stats.min, stats.max);
|
||||
}
|
||||
|
||||
let total: u32 = stats.histogram.iter().sum();
|
||||
let target_low = (total as f32 * (1.0 - percentile) / 2.0) as u32;
|
||||
let target_high = (total as f32 * (1.0 + percentile) / 2.0) as u32;
|
||||
|
||||
let mut cumsum = 0u32;
|
||||
let mut low_idx = 0;
|
||||
let mut high_idx = stats.histogram.len() - 1;
|
||||
|
||||
for (i, &count) in stats.histogram.iter().enumerate() {
|
||||
cumsum += count;
|
||||
if cumsum >= target_low && low_idx == 0 {
|
||||
low_idx = i;
|
||||
}
|
||||
if cumsum >= target_high {
|
||||
high_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(stats.bin_edges[low_idx], stats.bin_edges[high_idx + 1])
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A single batch seeds min/max/mean from scratch.
    #[test]
    fn test_activation_stats_update() {
        let mut stats = ActivationStats::new();
        stats.update(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.mean - 3.0).abs() < 0.01);
    }

    // Symmetric 8-bit scale for range [-1, 1] is 1/127.
    #[test]
    fn test_optimal_scale() {
        let mut stats = ActivationStats::new();
        stats.min = -1.0;
        stats.max = 1.0;

        let scale = stats.optimal_scale(8);
        // For 8-bit, qmax = 127, so scale should be 1.0/127 ≈ 0.00787
        assert!((scale - 1.0 / 127.0).abs() < 0.001);
    }

    // from_weights: min/max extremes and zero fraction (2 of 6 here).
    #[test]
    fn test_weight_stats() {
        let weights = vec![0.0, 0.1, -0.1, 0.5, -0.5, 0.0];
        let stats = WeightStats::from_weights(&weights);

        assert_eq!(stats.min, -0.5);
        assert_eq!(stats.max, 0.5);
        assert!((stats.sparsity - 2.0 / 6.0).abs() < 0.01);
    }

    // JSON round trip preserves the calibration method.
    #[test]
    fn test_calibration_serialization() {
        let calib = CalibrationData::new(CalibrationMethod::MinMax);
        let bytes = calib.to_bytes().unwrap();
        let restored = CalibrationData::from_bytes(&bytes).unwrap();

        assert_eq!(calib.method, restored.method);
    }
}
|
||||
336
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/lut.rs
vendored
Normal file
336
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/lut.rs
vendored
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Lookup table implementations for fixed-point operations
|
||||
//!
|
||||
//! Provides LUT-based exp, log, and softmax for deterministic computation.
|
||||
|
||||
/// LUT-based exponential function.
/// Input is treated as Q8.8 fixed point; the table's effective domain is
/// x ∈ [-4, 4) (see `generate_exp_lut`, which divides the Q8.8 index by 32).
/// Output: Q0.16 fixed point in [0, 1), normalized by 1 + exp(4).
const EXP_LUT_SIZE: usize = 256;
const EXP_LUT_SHIFT: i32 = 8; // Q8.8 input

/// Precomputed exp LUT; index i maps to x = (i - 128) / 32, i.e. x in [-4, 4)
static EXP_LUT: [u16; EXP_LUT_SIZE] = generate_exp_lut();

/// Generate exp LUT at compile time
const fn generate_exp_lut() -> [u16; EXP_LUT_SIZE] {
    let mut lut = [0u16; EXP_LUT_SIZE];
    let mut i = 0;
    while i < EXP_LUT_SIZE {
        // Center the index: x_q in [-128, 128)
        let x_q = (i as i32) - 128;
        // Scale so the table covers a usable exp range
        let x_f = (x_q as f64) / 32.0; // x in [-4, 4)

        // Compute exp and scale to Q0.16
        let exp_val = const_exp(x_f);
        // Normalize by 1 + exp(4) so the largest entry stays below 1.0.
        // NOTE: const_exp is a 5-term Taylor series, so accuracy degrades
        // toward the |x| = 4 ends of the range.
        let scaled = exp_val / (1.0 + const_exp(4.0)); // Normalize

        // Convert to u16, clamping to [0, 65535]
        lut[i] = if scaled > 1.0 {
            65535
        } else if scaled < 0.0 {
            0
        } else {
            (scaled * 65535.0) as u16
        };

        i += 1;
    }
    lut
}

/// Const-compatible exp approximation using a 5-term Taylor series.
/// Only reasonable for small |x|; used here because `f64::exp` is not const.
const fn const_exp(x: f64) -> f64 {
    // exp(x) ≈ 1 + x + x²/2 + x³/6 + x⁴/24 + x⁵/120
    let x2 = x * x;
    let x3 = x2 * x;
    let x4 = x3 * x;
    let x5 = x4 * x;

    1.0 + x + x2 / 2.0 + x3 / 6.0 + x4 / 24.0 + x5 / 120.0
}

/// LUT-based exponential
/// Input: i16 in Q8.8 format (clamped to the table's domain)
/// Output: u16 in Q0.16 format (normalized, monotonically increasing)
#[inline]
pub fn exp_lut(x: i16) -> u16 {
    // Clamp to LUT range: [-32768, 32512] in Q8.8
    let clamped = x.clamp(-128 * 256, 127 * 256);
    // Scale to LUT index: drop the fractional Q8.8 bits, re-center to 0..=255
    let idx = ((clamped >> EXP_LUT_SHIFT) + 128) as usize;
    EXP_LUT[idx.min(EXP_LUT_SIZE - 1)]
}
|
||||
|
||||
/// Log LUT for Q0.16 input; entry i holds ln(i/256) in Q8.8
static LOG_LUT: [i16; 256] = generate_log_lut();

const fn generate_log_lut() -> [i16; 256] {
    let mut lut = [0i16; 256];
    let mut i = 1;
    while i < 256 {
        // Input is scaled by 256, so x = i/256 in [0.004, 1)
        let x = (i as f64) / 256.0;
        // log(x) in Q8.8 format
        // NOTE: const_ln's series converges slowly near x = 0, so the
        // smallest-index entries are coarse approximations.
        let log_val = const_ln(x);
        lut[i] = (log_val * 256.0) as i16;
        i += 1;
    }
    lut[0] = i16::MIN; // log(0) = -inf, use min value
    lut
}

/// Const-compatible natural log approximation (artanh series, 5 terms).
/// Accurate near x = 1; error grows as x approaches 0.
const fn const_ln(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    // Use series expansion around x=1: ln(x) = 2 * sum((x-1)/(x+1))^(2n+1)/(2n+1)
    let y = (x - 1.0) / (x + 1.0);
    let y2 = y * y;

    // ln(x) ≈ 2 * (y + y³/3 + y⁵/5 + y⁷/7 + y⁹/9)
    let y3 = y2 * y;
    let y5 = y3 * y2;
    let y7 = y5 * y2;
    let y9 = y7 * y2;

    2.0 * (y + y3 / 3.0 + y5 / 5.0 + y7 / 7.0 + y9 / 9.0)
}

/// LUT-based natural log
/// Input: u16 in Q0.16 format (0 to 65535 = 0.0 to ~1.0)
/// Output: i16 in Q8.8 format
/// Note: inputs below 256 (i.e. < 1/256) also land in bin 0 and thus return
/// i16::MIN, the same sentinel used for log(0).
#[inline]
pub fn log_lut(x: u16) -> i16 {
    if x == 0 {
        return i16::MIN;
    }
    // Scale to LUT index: keep the top 8 bits of the Q0.16 input
    let idx = (x >> 8) as usize;
    LOG_LUT[idx.min(255)]
}
|
||||
|
||||
/// Softmax using LUT-based exp
/// Operates in-place on Q8.8 logits, outputs Q0.16 probabilities.
///
/// NOTE(review): probabilities are computed as u16 (0..=65535) and then
/// stored via `as i16`, so values above 32767 appear negative when read as
/// i16 — callers presumably reinterpret the bits as u16; confirm.
pub fn softmax_lut_q(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }

    // Find max for numerical stability
    let max = *logits.iter().max().unwrap_or(&0);

    // Compute exp(x - max) using LUT; shifted values are <= 0 so saturating
    // subtraction cannot overflow.
    let mut sum: u32 = 0;
    let mut exp_values: Vec<u16> = Vec::with_capacity(logits.len());

    for &logit in logits.iter() {
        let shifted = logit.saturating_sub(max);
        let exp_val = exp_lut(shifted);
        exp_values.push(exp_val);
        sum += exp_val as u32;
    }

    // Normalize; guard against an all-zero LUT output to avoid div-by-zero
    if sum == 0 {
        sum = 1;
    }

    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = ((exp_values[i] as u64 * 65535) / sum as u64) as u16;
        *logit = prob as i16;
    }
}
|
||||
|
||||
/// Softmax on f32 values (for compatibility)
///
/// In-place, numerically stable: subtracts the maximum before exponentiating
/// so large logits cannot overflow. A no-op on an empty slice; if the
/// exponent sum is not positive the buffer is left unnormalized.
pub fn softmax_lut(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }

    // Stabilize around the peak logit.
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Exponentiate in place while accumulating the partition sum.
    let mut total = 0.0f32;
    for value in logits.iter_mut() {
        *value = (*value - peak).exp();
        total += *value;
    }

    // Divide through so the entries sum to 1.
    if total > 0.0 {
        logits.iter_mut().for_each(|value| *value /= total);
    }
}
|
||||
|
||||
/// Piecewise linear softmax approximation
/// More accurate than a pure LUT but still deterministic.
///
/// Operates in place on Q8.8 logits; outputs Q0.16 probabilities stored via
/// `as i16` (values above 32767 appear negative — callers reinterpret the
/// bit pattern as u16, matching `softmax_lut_q`).
///
/// Fix: the shifted logit is now computed in i32. The previous
/// `(logit - max) as i32` subtracted in i16 first, which overflows (panics in
/// debug builds) for extreme inputs such as `logit = i16::MIN` with
/// `max = i16::MAX`.
pub fn softmax_pwl(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }

    let max = *logits.iter().max().unwrap_or(&0);

    // Piecewise linear exp approximation (in Q8.8):
    //   exp(x) ≈ 1 + x           for x in [-1, 0]
    //   quadratic decay          for x in [-8, -1)
    //   ≈ 0 (floor of 1 LSB)     for x < -8
    let mut sum: i64 = 0;
    let mut exp_values: Vec<i32> = Vec::with_capacity(logits.len());

    for &logit in logits.iter() {
        // Widen BEFORE subtracting so i16::MIN - i16::MAX cannot overflow.
        let x = logit as i32 - max as i32; // x <= 0

        let exp_val = if x >= -256 {
            // x in [-1, 0] -> linear: 1 + x
            (256 + x).max(0)
        } else if x >= -2048 {
            // x in [-8, -1) -> quadratic decay
            let shifted = (x + 2048) >> 3; // Scale to 0-256 range
            (shifted * shifted / 256).max(1)
        } else {
            // x < -8 -> essentially zero (keep 1 LSB so sum stays positive)
            1
        };

        exp_values.push(exp_val);
        sum += exp_val as i64;
    }

    // Normalize to Q0.16
    if sum == 0 {
        sum = 1;
    }

    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = (exp_values[i] as i64 * 65535 / sum) as i16;
        *logit = prob;
    }
}
|
||||
|
||||
/// GELU approximation using LUT
|
||||
/// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
pub fn gelu_lut(x: i16) -> i16 {
|
||||
// Simplified approximation: GELU(x) ≈ x * sigmoid(1.702 * x)
|
||||
let scaled = ((x as i32 * 435) >> 8) as i16; // 1.702 * x in Q8.8
|
||||
let sigmoid_val = sigmoid_lut(scaled);
|
||||
((x as i32 * sigmoid_val as i32) >> 15) as i16
|
||||
}
|
||||
|
||||
/// Sigmoid LUT; entry i holds sigmoid((i - 128) / 16) in Q0.16
static SIGMOID_LUT: [u16; 256] = generate_sigmoid_lut();

const fn generate_sigmoid_lut() -> [u16; 256] {
    let mut lut = [0u16; 256];
    let mut i = 0;
    while i < 256 {
        // Map index to x in [-8, 8)
        let x = ((i as i32) - 128) as f64 / 16.0;
        // sigmoid(x) = 1 / (1 + exp(-x))
        // NOTE: const_exp is a 5-term Taylor series, so entries far from the
        // center of the range are coarse approximations.
        let sig = 1.0 / (1.0 + const_exp(-x));
        lut[i] = (sig * 65535.0) as u16; // saturating f64 -> u16 cast
        i += 1;
    }
    lut
}
|
||||
|
||||
/// LUT-based sigmoid
|
||||
/// Input: i16 in Q8.8 format
|
||||
/// Output: u16 in Q0.16 format
|
||||
#[inline]
|
||||
pub fn sigmoid_lut(x: i16) -> u16 {
|
||||
// Scale to LUT range
|
||||
let idx = (((x >> 5) + 128) as usize).min(255);
|
||||
SIGMOID_LUT[idx]
|
||||
}
|
||||
|
||||
/// SiLU (Swish) using sigmoid LUT
/// SiLU(x) = x * sigmoid(x)
///
/// Input/output: i16 in Q8.8. The Q8.8 × Q0.16 product is Q8.24; `>> 16`
/// brings it back to Q8.8.
#[inline]
pub fn silu_lut(x: i16) -> i16 {
    let sigmoid_val = sigmoid_lut(x);
    ((x as i32 * sigmoid_val as i32) >> 16) as i16
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // exp LUT: positive at 0 and monotone over representative Q8.8 points.
    #[test]
    fn test_exp_lut() {
        // exp(0) should return a non-zero value
        let result = exp_lut(0);
        assert!(result > 0, "exp(0) should be positive");

        // exp is monotonically increasing
        let result_neg = exp_lut(-256); // -1.0 in Q8.8
        let result_zero = exp_lut(0);
        let result_pos = exp_lut(256); // 1.0 in Q8.8
        assert!(
            result_neg <= result_zero,
            "exp should be monotonically increasing"
        );
        assert!(
            result_zero <= result_pos,
            "exp should be monotonically increasing"
        );
    }

    // sigmoid LUT: midpoint near 0.5 (loose tolerance for the Taylor-based
    // table) and strictly increasing across sample points.
    #[test]
    fn test_sigmoid_lut() {
        // sigmoid(0) = 0.5
        let result = sigmoid_lut(0);
        let expected = 32768u16; // 0.5 in Q0.16
        assert!(
            (result as i32 - expected as i32).abs() < 5000,
            "sigmoid(0) ≈ 0.5"
        );

        // sigmoid is monotonically increasing
        let result_neg = sigmoid_lut(-1024);
        let result_zero = sigmoid_lut(0);
        let result_pos = sigmoid_lut(1024);
        assert!(
            result_neg < result_zero,
            "sigmoid should be monotonically increasing"
        );
        assert!(
            result_zero < result_pos,
            "sigmoid should be monotonically increasing"
        );
    }

    // f32 softmax: normalizes to 1 and preserves the ordering of the logits.
    #[test]
    fn test_softmax_lut() {
        let mut logits = vec![1.0f32, 2.0, 3.0, 4.0];
        softmax_lut(&mut logits);

        // Sum should be 1.0
        let sum: f32 = logits.iter().sum();
        assert!((sum - 1.0).abs() < 0.01);

        // Should be increasing
        for i in 1..logits.len() {
            assert!(logits[i] > logits[i - 1]);
        }
    }

    // GELU: near zero at 0 and sign-preserving on either side.
    #[test]
    fn test_gelu_lut() {
        // GELU(0) should be approximately 0
        assert!(gelu_lut(0).abs() < 100);

        // GELU maintains sign
        let neg_result = gelu_lut(-256);
        assert!(neg_result <= 0, "GELU of negative should be non-positive");

        // GELU of positive values should be positive
        let pos_result = gelu_lut(256);
        assert!(pos_result > 0, "GELU of positive should be positive");
    }
}
|
||||
238
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/mod.rs
vendored
Normal file
238
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/mod.rs
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
//! Quantization subsystem
|
||||
//!
|
||||
//! Explicit, reproducible quantization for weights and activations.
|
||||
|
||||
pub mod calib;
|
||||
pub mod lut;
|
||||
pub mod qformat;
|
||||
|
||||
pub use calib::{calibrate_model, CalibrationData};
|
||||
pub use lut::{exp_lut, log_lut, softmax_lut};
|
||||
pub use qformat::{dequantize_i16, dequantize_i8, quantize_i16, quantize_i8};
|
||||
|
||||
use crate::types::QuantSpec;
|
||||
|
||||
/// Fixed-point Q15 format (1.15)
/// Range: [-1.0, 1.0 - 2^-15]
/// Resolution: 2^-15 ≈ 3.05e-5
pub type Q15 = i16;

/// Fixed-point Q16.16 format
/// Range: [-32768.0, 32767.999...]
/// Resolution: 2^-16 ≈ 1.53e-5
pub type Q16_16 = i32;

/// Convert f32 to Q15, clamping to the representable range [-1.0, 1.0).
#[inline]
pub fn f32_to_q15(x: f32) -> Q15 {
    (x.clamp(-1.0, 1.0 - f32::EPSILON) * 32768.0) as Q15
}

/// Convert Q15 to f32.
#[inline]
pub fn q15_to_f32(x: Q15) -> f32 {
    x as f32 / 32768.0
}

/// Convert f32 to Q16.16 (truncating cast; out-of-range inputs saturate).
#[inline]
pub fn f32_to_q16_16(x: f32) -> Q16_16 {
    (x * 65536.0) as Q16_16
}

/// Convert Q16.16 to f32.
#[inline]
pub fn q16_16_to_f32(x: Q16_16) -> f32 {
    x as f32 / 65536.0
}

/// Fixed-point multiplication Q15 * Q15 -> Q15
/// Rounds to nearest (adds half an LSB before the shift) and clamps the
/// single overflow case (-1.0 * -1.0) to Q15 max.
#[inline]
pub fn q15_mul(a: Q15, b: Q15) -> Q15 {
    // Multiply with proper rounding
    let product = (a as i32 * b as i32 + 0x4000) >> 15;
    product.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}

/// Fixed-point dot product with an i32 accumulator.
/// Note: For very large vectors (>65536 elements), use q15_dot_saturating
/// to avoid accumulator overflow.
#[inline]
pub fn q15_dot(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum()
}

/// Saturating fixed-point dot product (prevents overflow for large vectors).
#[inline]
pub fn q15_dot_saturating(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter().zip(b.iter()).fold(0i32, |acc, (&x, &y)| {
        acc.saturating_add((x as i32).saturating_mul(y as i32))
    })
}

/// Fixed-point dot product normalized back to Q15 by an arithmetic right
/// shift of `shift` bits (must be < 32), rounding to nearest.
///
/// Fix: the rounding term `1 << (shift - 1)` underflowed (panic in debug
/// builds) when `shift == 0`; a zero shift now simply skips rounding.
#[inline]
pub fn q15_dot_normalized(a: &[Q15], b: &[Q15], shift: u8) -> Q15 {
    let sum = q15_dot(a, b);
    // Half-LSB rounding term; nothing to round when shift == 0.
    let rounding = if shift == 0 { 0 } else { 1i32 << (shift - 1) };
    let shifted = (sum + rounding) >> shift;
    shifted.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}
|
||||
|
||||
/// Quantization context for a layer
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantContext {
|
||||
/// Weight quantization spec
|
||||
pub weight_spec: QuantSpec,
|
||||
/// Input scale (Q16.16)
|
||||
pub input_scale: Q16_16,
|
||||
/// Output scale (Q16.16)
|
||||
pub output_scale: Q16_16,
|
||||
/// Accumulator bit width
|
||||
pub acc_bits: u8,
|
||||
/// Right shift for normalization
|
||||
pub norm_shift: u8,
|
||||
}
|
||||
|
||||
impl QuantContext {
|
||||
/// Create from QuantSpec
|
||||
pub fn from_spec(spec: &QuantSpec) -> Self {
|
||||
Self {
|
||||
weight_spec: *spec,
|
||||
input_scale: spec.scale_q,
|
||||
output_scale: spec.scale_q,
|
||||
acc_bits: 32,
|
||||
norm_shift: 15,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the required accumulator bits to avoid overflow
|
||||
pub fn required_acc_bits(input_bits: u8, weight_bits: u8, vector_len: usize) -> u8 {
|
||||
// Each multiply produces input_bits + weight_bits
|
||||
// Sum of vector_len terms adds log2(vector_len) bits
|
||||
let product_bits = input_bits + weight_bits;
|
||||
let sum_bits = (vector_len as f64).log2().ceil() as u8;
|
||||
product_bits + sum_bits + 1 // +1 for sign
|
||||
}
|
||||
}
|
||||
|
||||
/// Bit-packing helpers for sub-byte quantized weights.
///
/// The signed unpackers sign-extend by shifting the field into the top of an
/// `i8` and arithmetically shifting back down — equivalent to mask-and-or
/// sign extension.
pub mod packing {
    /// Pack two signed 4-bit values into one byte (`a` in the low nibble).
    #[inline]
    pub fn pack_int4(a: i8, b: i8) -> u8 {
        let lo = (a & 0x0F) as u8;
        let hi = ((b & 0x0F) as u8) << 4;
        lo | hi
    }

    /// Unpack a byte into two sign-extended 4-bit values.
    #[inline]
    pub fn unpack_int4(packed: u8) -> (i8, i8) {
        // Move each nibble into the sign position, then arithmetic-shift back.
        let a = ((packed << 4) as i8) >> 4;
        let b = (packed as i8) >> 4;
        (a, b)
    }

    /// Pack four signed 2-bit values into one byte (`a` in bits 0-1, `d` in bits 6-7).
    #[inline]
    pub fn pack_int2(a: i8, b: i8, c: i8, d: i8) -> u8 {
        [a, b, c, d]
            .iter()
            .enumerate()
            .fold(0u8, |acc, (i, &v)| acc | (((v & 0x03) as u8) << (2 * i)))
    }

    /// Unpack a byte into four sign-extended 2-bit values.
    #[inline]
    pub fn unpack_int2(packed: u8) -> (i8, i8, i8, i8) {
        // Field k lives in bits 2k..2k+2; shift it to the top, then back down.
        let field = |pos: u8| ((packed << (6 - 2 * pos)) as i8) >> 6;
        (field(0), field(1), field(2), field(3))
    }

    /// Pack eight booleans into one byte, least-significant bit first.
    #[inline]
    pub fn pack_binary(bits: &[bool; 8]) -> u8 {
        let mut byte = 0u8;
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                byte |= 1 << i;
            }
        }
        byte
    }

    /// Unpack a byte into eight booleans, least-significant bit first.
    #[inline]
    pub fn unpack_binary(packed: u8) -> [bool; 8] {
        std::array::from_fn(|i| packed & (1 << i) != 0)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Exact values for the easy cases, plus a round-trip bound well above
    // the Q15 resolution of ~3.05e-5.
    #[test]
    fn test_q15_conversion() {
        assert_eq!(f32_to_q15(0.0), 0);
        assert_eq!(f32_to_q15(0.5), 16384);
        assert_eq!(f32_to_q15(-0.5), -16384);

        let x = 0.123f32;
        let q = f32_to_q15(x);
        let back = q15_to_f32(q);
        assert!((x - back).abs() < 0.001);
    }

    // 0.5 * 0.5 must come back as ~0.25 after fixed-point rounding.
    #[test]
    fn test_q15_mul() {
        let a = f32_to_q15(0.5);
        let b = f32_to_q15(0.5);
        let c = q15_mul(a, b);
        let result = q15_to_f32(c);
        assert!((result - 0.25).abs() < 0.01);
    }

    // Signed nibbles must round-trip exactly through pack/unpack.
    #[test]
    fn test_packing_int4() {
        let (a, b) = (5i8, -3i8);
        let packed = packing::pack_int4(a, b);
        let (ua, ub) = packing::unpack_int4(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
    }

    // 2-bit fields cover the full signed range [-2, 1].
    #[test]
    fn test_packing_int2() {
        let (a, b, c, d) = (1i8, -1i8, 0i8, -2i8);
        let packed = packing::pack_int2(a, b, c, d);
        let (ua, ub, uc, ud) = packing::unpack_int2(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
        assert_eq!(c, uc);
        // -2 in 2-bit is 10 binary, which unpacks to -2 (sign extended)
        assert_eq!(-2i8, ud);
    }

    // Bit packing is LSB-first and must round-trip all eight flags.
    #[test]
    fn test_packing_binary() {
        let bits = [true, false, true, true, false, false, true, false];
        let packed = packing::pack_binary(&bits);
        let unpacked = packing::unpack_binary(packed);
        assert_eq!(bits, unpacked);
    }
}
|
||||
237
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/qformat.rs
vendored
Normal file
237
vendor/ruvector/crates/ruvector-fpga-transformer/src/quant/qformat.rs
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
//! Quantization format operations
|
||||
|
||||
use crate::types::QuantSpec;
|
||||
|
||||
/// Quantize f32 values to i8
|
||||
pub fn quantize_i8(values: &[f32], spec: &QuantSpec) -> Vec<i8> {
|
||||
let scale = spec.scale_q as f32 / 65536.0;
|
||||
let zero = spec.zero_q as f32 / 65536.0;
|
||||
|
||||
values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let quantized = ((v - zero) / scale).round();
|
||||
quantized.clamp(-128.0, 127.0) as i8
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Quantize f32 values to i16, stretching the observed [min, max] range over
/// the full 16-bit span.
///
/// A constant (or empty) input has no range to scale and yields all zeros.
pub fn quantize_i16(values: &[f32]) -> Vec<i16> {
    let (min, max) = values
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });

    // Degenerate range: every value identical.
    if (max - min).abs() < f32::EPSILON {
        return vec![0i16; values.len()];
    }

    let scale = (max - min) / 65535.0;

    values
        .iter()
        .map(|&v| {
            // Map min..max onto 0..65535, then recentre to -32768..32767.
            ((v - min) / scale - 32768.0).round().clamp(-32768.0, 32767.0) as i16
        })
        .collect()
}
|
||||
|
||||
/// Dequantize i8 values to f32
|
||||
pub fn dequantize_i8(values: &[u8], spec: &QuantSpec) -> Vec<f32> {
|
||||
let scale = spec.scale_q as f32 / 65536.0;
|
||||
let zero = spec.zero_q as f32 / 65536.0;
|
||||
|
||||
values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let signed = v as i8;
|
||||
signed as f32 * scale + zero
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Dequantize i16 values to f32 via the affine map `v * scale + zero`.
pub fn dequantize_i16(values: &[i16], scale: f32, zero: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(values.len());
    for &v in values {
        out.push(v as f32 * scale + zero);
    }
    out
}
|
||||
|
||||
/// Symmetric quantization (zero point = 0).
///
/// Maps the largest absolute value onto ±127 and returns the quantized
/// values together with the scale. All-zero (or empty) input gets a unit
/// scale so dequantization stays well defined.
pub fn quantize_symmetric_i8(values: &[f32]) -> (Vec<i8>, f32) {
    let abs_max = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));

    if abs_max < f32::EPSILON {
        return (vec![0i8; values.len()], 1.0);
    }

    let scale = abs_max / 127.0;
    let quantized = values
        .iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();

    (quantized, scale)
}
|
||||
|
||||
/// Asymmetric (affine) quantization using the full unsigned 8-bit range.
///
/// Returns `(quantized, scale, zero_point)` such that
/// `v ≈ (q - zero_point) * scale`. Constant input maps everything to 0 with
/// a unit scale.
pub fn quantize_asymmetric_i8(values: &[f32]) -> (Vec<u8>, f32, i32) {
    let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

    if (max - min).abs() < f32::EPSILON {
        return (vec![0u8; values.len()], 1.0, 0);
    }

    let scale = (max - min) / 255.0;
    let zero_point = (-min / scale).round() as i32;

    let quantized = values
        .iter()
        .map(|&v| ((v / scale).round() as i32 + zero_point).clamp(0, 255) as u8)
        .collect();

    (quantized, scale, zero_point)
}
|
||||
|
||||
/// Per-channel quantization for weights
|
||||
pub fn quantize_per_channel_i8(weights: &[f32], out_channels: usize) -> (Vec<i8>, Vec<f32>) {
|
||||
let in_features = weights.len() / out_channels;
|
||||
let mut quantized = Vec::with_capacity(weights.len());
|
||||
let mut scales = Vec::with_capacity(out_channels);
|
||||
|
||||
for c in 0..out_channels {
|
||||
let start = c * in_features;
|
||||
let end = start + in_features;
|
||||
let channel_weights = &weights[start..end];
|
||||
|
||||
let (q, scale) = quantize_symmetric_i8(channel_weights);
|
||||
quantized.extend(q);
|
||||
scales.push(scale);
|
||||
}
|
||||
|
||||
(quantized, scales)
|
||||
}
|
||||
|
||||
/// Blocked quantization for hardware efficiency
|
||||
pub fn quantize_blocked_i8(values: &[f32], block_size: usize) -> (Vec<i8>, Vec<f32>, Vec<i8>) {
|
||||
let num_blocks = (values.len() + block_size - 1) / block_size;
|
||||
let mut quantized = Vec::with_capacity(values.len());
|
||||
let mut scales = Vec::with_capacity(num_blocks);
|
||||
let mut zeros = Vec::with_capacity(num_blocks);
|
||||
|
||||
for block_idx in 0..num_blocks {
|
||||
let start = block_idx * block_size;
|
||||
let end = (start + block_size).min(values.len());
|
||||
let block = &values[start..end];
|
||||
|
||||
let (q, scale) = quantize_symmetric_i8(block);
|
||||
quantized.extend(q);
|
||||
scales.push(scale);
|
||||
zeros.push(0i8);
|
||||
}
|
||||
|
||||
(quantized, scales, zeros)
|
||||
}
|
||||
|
||||
/// Matrix quantization for GEMM
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizedMatrix {
|
||||
/// Quantized values
|
||||
pub data: Vec<i8>,
|
||||
/// Rows
|
||||
pub rows: usize,
|
||||
/// Columns
|
||||
pub cols: usize,
|
||||
/// Per-row scales (for per-channel quantization)
|
||||
pub scales: Vec<f32>,
|
||||
/// Per-row zero points
|
||||
pub zeros: Vec<i8>,
|
||||
}
|
||||
|
||||
impl QuantizedMatrix {
|
||||
/// Quantize a matrix with per-row scaling
|
||||
pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> Self {
|
||||
assert_eq!(data.len(), rows * cols);
|
||||
|
||||
let (quantized, scales) = quantize_per_channel_i8(data, rows);
|
||||
|
||||
Self {
|
||||
data: quantized,
|
||||
rows,
|
||||
cols,
|
||||
scales,
|
||||
zeros: vec![0i8; rows],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a row
|
||||
pub fn row(&self, idx: usize) -> &[i8] {
|
||||
let start = idx * self.cols;
|
||||
&self.data[start..start + self.cols]
|
||||
}
|
||||
|
||||
/// Dequantize to f32
|
||||
pub fn to_f32(&self) -> Vec<f32> {
|
||||
let mut result = Vec::with_capacity(self.rows * self.cols);
|
||||
|
||||
for r in 0..self.rows {
|
||||
let scale = self.scales[r];
|
||||
let zero = self.zeros[r] as f32;
|
||||
for &v in self.row(r) {
|
||||
result.push((v as f32 - zero) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::QuantSpec;

    // Symmetric round-trip: each dequantized value must land within one
    // scale step (0.1 is generous for abs_max = 1.0).
    #[test]
    fn test_quantize_symmetric() {
        let values = vec![1.0, -1.0, 0.5, -0.5, 0.0];
        let (quantized, scale) = quantize_symmetric_i8(&values);

        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = q as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }

    // Asymmetric round-trip over a non-negative range; dequantization
    // subtracts the zero point before scaling.
    #[test]
    fn test_quantize_asymmetric() {
        let values = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let (quantized, scale, zero) = quantize_asymmetric_i8(&values);

        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = (q as i32 - zero) as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }

    // Per-row matrix quantization: shape metadata survives and the
    // round-trip error stays within ~2 quantization steps.
    #[test]
    fn test_quantized_matrix() {
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1 - 3.2).collect();
        let matrix = QuantizedMatrix::from_f32(&data, 8, 8);

        assert_eq!(matrix.rows, 8);
        assert_eq!(matrix.cols, 8);
        assert_eq!(matrix.scales.len(), 8);

        let dequantized = matrix.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.2);
        }
    }
}
|
||||
624
vendor/ruvector/crates/ruvector-fpga-transformer/src/types.rs
vendored
Normal file
624
vendor/ruvector/crates/ruvector-fpga-transformer/src/types.rs
vendored
Normal file
@@ -0,0 +1,624 @@
|
||||
//! Core types for FPGA Transformer backend
|
||||
//!
|
||||
//! All types are designed for deterministic, allocation-free inference
|
||||
//! with explicit quantization metadata.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Unique identifier for a loaded model (SHA-256 hash)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ModelId(pub [u8; 32]);
|
||||
|
||||
impl ModelId {
|
||||
/// Create a new ModelId from bytes
|
||||
pub const fn new(bytes: [u8; 32]) -> Self {
|
||||
Self(bytes)
|
||||
}
|
||||
|
||||
/// Create a zero ModelId (for testing)
|
||||
pub const fn zero() -> Self {
|
||||
Self([0u8; 32])
|
||||
}
|
||||
|
||||
/// Get the bytes of the ModelId
|
||||
pub const fn as_bytes(&self) -> &[u8; 32] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Convert to hex string
|
||||
pub fn to_hex(&self) -> String {
|
||||
self.0.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
|
||||
/// Parse from hex string
|
||||
pub fn from_hex(s: &str) -> Option<Self> {
|
||||
if s.len() != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut bytes = [0u8; 32];
|
||||
for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
|
||||
let hex_str = std::str::from_utf8(chunk).ok()?;
|
||||
bytes[i] = u8::from_str_radix(hex_str, 16).ok()?;
|
||||
}
|
||||
Some(Self(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.to_hex())
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixed shape specification for transformer inference
|
||||
///
|
||||
/// All dimensions are compile-time or model-time constants.
|
||||
/// This enables zero-allocation inference paths.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct FixedShape {
|
||||
/// Maximum sequence length
|
||||
pub seq_len: u16,
|
||||
/// Model/hidden dimension
|
||||
pub d_model: u16,
|
||||
/// Number of attention heads
|
||||
pub heads: u8,
|
||||
/// Dimension per head
|
||||
pub d_head: u16,
|
||||
/// Vocabulary size
|
||||
pub vocab: u32,
|
||||
}
|
||||
|
||||
impl FixedShape {
|
||||
/// Create a new FixedShape
|
||||
pub const fn new(seq_len: u16, d_model: u16, heads: u8, d_head: u16, vocab: u32) -> Self {
|
||||
Self {
|
||||
seq_len,
|
||||
d_model,
|
||||
heads,
|
||||
d_head,
|
||||
vocab,
|
||||
}
|
||||
}
|
||||
|
||||
/// Micro configuration for edge/WASM deployment
|
||||
pub const fn micro() -> Self {
|
||||
Self {
|
||||
seq_len: 32,
|
||||
d_model: 64,
|
||||
heads: 4,
|
||||
d_head: 16,
|
||||
vocab: 4096,
|
||||
}
|
||||
}
|
||||
|
||||
/// Small configuration for embedded
|
||||
pub const fn small() -> Self {
|
||||
Self {
|
||||
seq_len: 64,
|
||||
d_model: 128,
|
||||
heads: 4,
|
||||
d_head: 32,
|
||||
vocab: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
/// Baseline configuration
|
||||
pub const fn baseline() -> Self {
|
||||
Self {
|
||||
seq_len: 128,
|
||||
d_model: 256,
|
||||
heads: 8,
|
||||
d_head: 32,
|
||||
vocab: 32000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate total parameters for embedding layer
|
||||
pub const fn embedding_params(&self) -> usize {
|
||||
self.vocab as usize * self.d_model as usize
|
||||
}
|
||||
|
||||
/// Calculate parameters per attention layer
|
||||
pub const fn attention_params(&self) -> usize {
|
||||
// Q, K, V projections + output projection
|
||||
4 * (self.d_model as usize * self.d_model as usize)
|
||||
}
|
||||
|
||||
/// Calculate parameters per FFN layer (assuming 4x expansion)
|
||||
pub const fn ffn_params(&self) -> usize {
|
||||
2 * (self.d_model as usize * 4 * self.d_model as usize)
|
||||
}
|
||||
|
||||
/// Validate shape consistency
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.d_model as usize != self.heads as usize * self.d_head as usize {
|
||||
return Err(format!(
|
||||
"d_model ({}) must equal heads ({}) * d_head ({})",
|
||||
self.d_model, self.heads, self.d_head
|
||||
));
|
||||
}
|
||||
if self.seq_len == 0 {
|
||||
return Err("seq_len must be > 0".into());
|
||||
}
|
||||
if self.vocab == 0 {
|
||||
return Err("vocab must be > 0".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FixedShape {
|
||||
fn default() -> Self {
|
||||
Self::baseline()
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantization specification
|
||||
///
|
||||
/// Explicit quantization metadata ensuring reproducible inference.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct QuantSpec {
|
||||
/// Weight bit width (1, 2, 4, 8)
|
||||
pub w_bits: u8,
|
||||
/// Activation bit width (4, 8, 16)
|
||||
pub a_bits: u8,
|
||||
/// Scale factor (Q16.16 fixed point)
|
||||
pub scale_q: i32,
|
||||
/// Zero point (Q16.16 fixed point)
|
||||
pub zero_q: i32,
|
||||
/// Memory layout
|
||||
pub layout: Layout,
|
||||
}
|
||||
|
||||
impl QuantSpec {
|
||||
/// Create a new QuantSpec
|
||||
pub const fn new(w_bits: u8, a_bits: u8, scale_q: i32, zero_q: i32, layout: Layout) -> Self {
|
||||
Self {
|
||||
w_bits,
|
||||
a_bits,
|
||||
scale_q,
|
||||
zero_q,
|
||||
layout,
|
||||
}
|
||||
}
|
||||
|
||||
/// INT4 weights, INT8 activations (common for edge)
|
||||
pub const fn int4_int8() -> Self {
|
||||
Self {
|
||||
w_bits: 4,
|
||||
a_bits: 8,
|
||||
scale_q: 1 << 16, // 1.0 in Q16.16
|
||||
zero_q: 0,
|
||||
layout: Layout::Blocked { block: 32 },
|
||||
}
|
||||
}
|
||||
|
||||
/// INT8 weights and activations
|
||||
pub const fn int8() -> Self {
|
||||
Self {
|
||||
w_bits: 8,
|
||||
a_bits: 8,
|
||||
scale_q: 1 << 16,
|
||||
zero_q: 0,
|
||||
layout: Layout::RowMajor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Bytes per weight element
|
||||
pub const fn bytes_per_weight(&self) -> usize {
|
||||
match self.w_bits {
|
||||
1 => 1, // Packed 8 per byte, but minimum 1 byte
|
||||
2 => 1, // Packed 4 per byte
|
||||
4 => 1, // Packed 2 per byte
|
||||
8 => 1,
|
||||
16 => 2,
|
||||
_ => 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights packed per byte
|
||||
pub const fn weights_per_byte(&self) -> usize {
|
||||
match self.w_bits {
|
||||
1 => 8,
|
||||
2 => 4,
|
||||
4 => 2,
|
||||
_ => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QuantSpec {
|
||||
fn default() -> Self {
|
||||
Self::int8()
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory layout for quantized tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Layout {
    /// Standard row-major layout
    RowMajor,
    /// Blocked layout for SIMD/hardware efficiency; `block` is the element
    /// count per block
    Blocked { block: u16 },
    /// Heads interleaved for attention computation
    InterleavedHeads,
}

impl Default for Layout {
    // RowMajor is the least surprising layout, so it is the default.
    fn default() -> Self {
        Self::RowMajor
    }
}
|
||||
|
||||
/// Hint for gating decisions
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct GateHint {
|
||||
/// Coherence score (Q8.8 fixed point, higher = more coherent)
|
||||
pub coherence_score_q: i16,
|
||||
/// Whether a boundary was crossed in the input
|
||||
pub boundary_crossed: bool,
|
||||
/// Maximum compute class allowed
|
||||
pub max_compute_class: ComputeClass,
|
||||
}
|
||||
|
||||
impl GateHint {
|
||||
/// Create a new GateHint
|
||||
pub const fn new(
|
||||
coherence_score_q: i16,
|
||||
boundary_crossed: bool,
|
||||
max_compute_class: ComputeClass,
|
||||
) -> Self {
|
||||
Self {
|
||||
coherence_score_q,
|
||||
boundary_crossed,
|
||||
max_compute_class,
|
||||
}
|
||||
}
|
||||
|
||||
/// Default hint allowing full computation
|
||||
pub const fn allow_all() -> Self {
|
||||
Self {
|
||||
coherence_score_q: i16::MAX,
|
||||
boundary_crossed: false,
|
||||
max_compute_class: ComputeClass::Deliberative,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reflex-only hint for fast path
|
||||
pub const fn reflex_only() -> Self {
|
||||
Self {
|
||||
coherence_score_q: 0,
|
||||
boundary_crossed: false,
|
||||
max_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GateHint {
|
||||
fn default() -> Self {
|
||||
Self::allow_all()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute class for tiered inference, ordered from cheapest to most
/// thorough (the derived `Ord` follows the discriminant values).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum ComputeClass {
    /// Fastest path, minimal computation (1-2 layers)
    Reflex = 0,
    /// Medium path, associative memory (4-6 layers)
    Associative = 1,
    /// Full deliberative computation (all layers)
    Deliberative = 2,
}

impl ComputeClass {
    /// Decode from the wire/byte representation; `None` for unknown values.
    pub const fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(Self::Reflex),
            1 => Some(Self::Associative),
            2 => Some(Self::Deliberative),
            _ => None,
        }
    }
}

impl Default for ComputeClass {
    // Full computation is the safe default when no gating hint applies.
    fn default() -> Self {
        Self::Deliberative
    }
}
|
||||
|
||||
/// Inference request
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InferenceRequest<'a> {
|
||||
/// Model to use
|
||||
pub model: ModelId,
|
||||
/// Expected shape
|
||||
pub shape: FixedShape,
|
||||
/// Input token IDs (length = seq_len)
|
||||
pub tokens: &'a [u16],
|
||||
/// Attention mask (length = seq_len or seq_len^2)
|
||||
pub attn_mask: &'a [u8],
|
||||
/// Gating hint for coherence control
|
||||
pub gate_hint: GateHint,
|
||||
}
|
||||
|
||||
impl<'a> InferenceRequest<'a> {
|
||||
/// Create a new InferenceRequest
|
||||
pub fn new(
|
||||
model: ModelId,
|
||||
shape: FixedShape,
|
||||
tokens: &'a [u16],
|
||||
attn_mask: &'a [u8],
|
||||
gate_hint: GateHint,
|
||||
) -> Self {
|
||||
Self {
|
||||
model,
|
||||
shape,
|
||||
tokens,
|
||||
attn_mask,
|
||||
gate_hint,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the request
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if self.tokens.len() != self.shape.seq_len as usize {
|
||||
return Err(crate::error::Error::InputLengthMismatch {
|
||||
expected: self.shape.seq_len as usize,
|
||||
actual: self.tokens.len(),
|
||||
});
|
||||
}
|
||||
if self.attn_mask.len() != self.shape.seq_len as usize
|
||||
&& self.attn_mask.len() != (self.shape.seq_len as usize).pow(2)
|
||||
{
|
||||
return Err(crate::error::Error::InputLengthMismatch {
|
||||
expected: self.shape.seq_len as usize,
|
||||
actual: self.attn_mask.len(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Inference result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InferenceResult {
|
||||
/// Full logits (quantized, length = vocab) or empty if topk_only
|
||||
pub logits_q: Vec<i16>,
|
||||
/// Top-K predictions (token_id, logit_q)
|
||||
pub topk: Option<Vec<(u16, i16)>>,
|
||||
/// Witness log for audit trail
|
||||
pub witness: WitnessLog,
|
||||
}
|
||||
|
||||
impl InferenceResult {
|
||||
/// Create a new InferenceResult
|
||||
pub fn new(logits_q: Vec<i16>, topk: Option<Vec<(u16, i16)>>, witness: WitnessLog) -> Self {
|
||||
Self {
|
||||
logits_q,
|
||||
topk,
|
||||
witness,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the argmax token
|
||||
pub fn argmax(&self) -> Option<u16> {
|
||||
if let Some(ref topk) = self.topk {
|
||||
topk.first().map(|(token, _)| *token)
|
||||
} else if !self.logits_q.is_empty() {
|
||||
self.logits_q
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by_key(|(_, &v)| v)
|
||||
.map(|(i, _)| i as u16)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Witness log for audit trail and ReasoningBank integration.
///
/// Records exactly which model and quantization produced a result, on which
/// backend, and what the gate decided, so the inference can be audited
/// after the fact.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WitnessLog {
    /// Hash of the model used
    pub model_hash: [u8; 32],
    /// Hash of quantization parameters used
    pub quant_hash: [u8; 32],
    /// Backend that executed the inference
    pub backend: BackendKind,
    /// Compute cycles used (FPGA) or 0 (sim)
    pub cycles: u32,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Gate decision made
    pub gate_decision: GateDecision,
}

impl WitnessLog {
    /// Create a new WitnessLog from explicit fields.
    pub fn new(
        model_hash: [u8; 32],
        quant_hash: [u8; 32],
        backend: BackendKind,
        cycles: u32,
        latency_ns: u32,
        gate_decision: GateDecision,
    ) -> Self {
        Self {
            model_hash,
            quant_hash,
            backend,
            cycles,
            latency_ns,
            gate_decision,
        }
    }

    /// Create an empty witness (for testing): all-zero hashes, the native
    /// simulator backend and a `RanFull` decision.
    pub fn empty() -> Self {
        Self {
            model_hash: [0u8; 32],
            quant_hash: [0u8; 32],
            backend: BackendKind::NativeSim,
            cycles: 0,
            latency_ns: 0,
            gate_decision: GateDecision::RanFull,
        }
    }
}
|
||||
|
||||
/// Backend types
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum BackendKind {
|
||||
/// PCIe-connected FPGA
|
||||
FpgaPcie,
|
||||
/// FPGA via local daemon
|
||||
FpgaDaemon,
|
||||
/// WASM simulator
|
||||
WasmSim,
|
||||
/// Native Rust simulator
|
||||
NativeSim,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::FpgaPcie => write!(f, "fpga_pcie"),
|
||||
Self::FpgaDaemon => write!(f, "fpga_daemon"),
|
||||
Self::WasmSim => write!(f, "wasm_sim"),
|
||||
Self::NativeSim => write!(f, "native_sim"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gate decision outcome
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GateDecision {
|
||||
/// Full inference completed
|
||||
RanFull,
|
||||
/// Early exit at specified layer
|
||||
EarlyExit { layer: u8 },
|
||||
/// Inference was skipped
|
||||
Skipped { reason: SkipReason },
|
||||
}
|
||||
|
||||
impl GateDecision {
|
||||
/// Check if inference actually ran
|
||||
pub const fn did_run(&self) -> bool {
|
||||
!matches!(self, Self::Skipped { .. })
|
||||
}
|
||||
|
||||
/// Get the exit layer (full = max layers)
|
||||
pub const fn exit_layer(&self, max_layers: u8) -> u8 {
|
||||
match self {
|
||||
Self::RanFull => max_layers,
|
||||
Self::EarlyExit { layer } => *layer,
|
||||
Self::Skipped { .. } => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reason for skipping inference
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SkipReason {
|
||||
/// Coherence score too low
|
||||
LowCoherence,
|
||||
/// Policy denied the inference
|
||||
PolicyDenied,
|
||||
/// Compute budget exceeded
|
||||
BudgetExceeded,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SkipReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::LowCoherence => write!(f, "low_coherence"),
|
||||
Self::PolicyDenied => write!(f, "policy_denied"),
|
||||
Self::BudgetExceeded => write!(f, "budget_exceeded"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantized tensor wrapper
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizedTensor {
|
||||
/// Raw quantized data
|
||||
pub data: Vec<u8>,
|
||||
/// Quantization specification
|
||||
pub spec: QuantSpec,
|
||||
/// Tensor shape (row-major)
|
||||
pub shape: Vec<usize>,
|
||||
}
|
||||
|
||||
impl QuantizedTensor {
|
||||
/// Create a new quantized tensor
|
||||
pub fn new(data: Vec<u8>, spec: QuantSpec, shape: Vec<usize>) -> Self {
|
||||
Self { data, spec, shape }
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
pub fn numel(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
}
|
||||
|
||||
/// Expected data size in bytes
|
||||
pub fn expected_bytes(&self) -> usize {
|
||||
let numel = self.numel();
|
||||
(numel + self.spec.weights_per_byte() - 1) / self.spec.weights_per_byte()
|
||||
}
|
||||
|
||||
/// Validate tensor integrity
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
let expected = self.expected_bytes();
|
||||
if self.data.len() != expected {
|
||||
return Err(crate::error::Error::QuantizationError(format!(
|
||||
"Data size mismatch: expected {} bytes, got {}",
|
||||
expected,
|
||||
self.data.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Hex encode/decode must be lossless.
    #[test]
    fn test_model_id_hex_roundtrip() {
        let bytes = [0x12u8; 32];
        let id = ModelId::new(bytes);
        let hex = id.to_hex();
        let parsed = ModelId::from_hex(&hex).unwrap();
        assert_eq!(id, parsed);
    }

    // d_model must factor exactly into heads * d_head.
    #[test]
    fn test_fixed_shape_validate() {
        let valid = FixedShape::new(64, 256, 8, 32, 32000);
        assert!(valid.validate().is_ok());

        let invalid = FixedShape::new(64, 256, 8, 16, 32000); // 8*16 != 256
        assert!(invalid.validate().is_err());
    }

    // Sub-byte widths pack multiple weights per byte.
    #[test]
    fn test_quant_spec_bytes() {
        assert_eq!(QuantSpec::int8().weights_per_byte(), 1);
        assert_eq!(QuantSpec::int4_int8().weights_per_byte(), 2);
    }

    // Only the Skipped variant counts as "did not run".
    #[test]
    fn test_gate_decision() {
        assert!(GateDecision::RanFull.did_run());
        assert!(GateDecision::EarlyExit { layer: 3 }.did_run());
        assert!(!GateDecision::Skipped {
            reason: SkipReason::LowCoherence
        }
        .did_run());
    }
}
|
||||
215
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/hash.rs
vendored
Normal file
215
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/hash.rs
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
//! Witness hashing for integrity verification
|
||||
|
||||
use crate::types::WitnessLog;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Compute a hash of the witness log for integrity verification
|
||||
pub fn compute_witness_hash(witness: &WitnessLog) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
hasher.update(&witness.model_hash);
|
||||
hasher.update(&witness.quant_hash);
|
||||
hasher.update(&[witness.backend as u8]);
|
||||
hasher.update(&witness.cycles.to_le_bytes());
|
||||
hasher.update(&witness.latency_ns.to_le_bytes());
|
||||
|
||||
// Hash gate decision
|
||||
match witness.gate_decision {
|
||||
crate::types::GateDecision::RanFull => {
|
||||
hasher.update(&[0u8]);
|
||||
}
|
||||
crate::types::GateDecision::EarlyExit { layer } => {
|
||||
hasher.update(&[1u8, layer]);
|
||||
}
|
||||
crate::types::GateDecision::Skipped { reason } => {
|
||||
hasher.update(&[2u8, reason as u8]);
|
||||
}
|
||||
}
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Verify a witness hash
|
||||
pub fn verify_witness_hash(witness: &WitnessLog, expected: &[u8; 32]) -> bool {
|
||||
let computed = compute_witness_hash(witness);
|
||||
computed == *expected
|
||||
}
|
||||
|
||||
/// Compute a combined hash for a sequence of witnesses
|
||||
/// Useful for verifying an entire inference chain
|
||||
pub fn compute_chain_hash(witnesses: &[WitnessLog]) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
for witness in witnesses {
|
||||
let witness_hash = compute_witness_hash(witness);
|
||||
hasher.update(&witness_hash);
|
||||
}
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Witness proof for verification
///
/// Binds a witness digest to its creation time and, optionally, to an
/// Ed25519 signature over the 40-byte `hash || timestamp_ns` message
/// (see `WitnessProof::signed` / `verify_signature`).
#[derive(Debug, Clone)]
pub struct WitnessProof {
    /// Hash of the witness (SHA-256, as produced by `compute_witness_hash`)
    pub hash: [u8; 32],
    /// Timestamp when proof was created, in nanoseconds since the UNIX
    /// epoch (0 if the system clock reported a pre-epoch time)
    pub timestamp_ns: u64,
    /// Optional signature over `hash || timestamp_ns` (Ed25519, 64 bytes)
    pub signature: Option<[u8; 64]>,
}
|
||||
|
||||
impl WitnessProof {
|
||||
/// Create a new proof from a witness
|
||||
pub fn new(witness: &WitnessLog) -> Self {
|
||||
Self {
|
||||
hash: compute_witness_hash(witness),
|
||||
timestamp_ns: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0),
|
||||
signature: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a proof with signature
|
||||
#[cfg(feature = "sign")]
|
||||
pub fn signed(witness: &WitnessLog, secret_key: &[u8; 32]) -> Self {
|
||||
use ed25519_dalek::{Signer, SigningKey};
|
||||
|
||||
let hash = compute_witness_hash(witness);
|
||||
let timestamp_ns = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Create message to sign
|
||||
let mut message = [0u8; 40];
|
||||
message[..32].copy_from_slice(&hash);
|
||||
message[32..40].copy_from_slice(×tamp_ns.to_le_bytes());
|
||||
|
||||
let signing_key = SigningKey::from_bytes(secret_key);
|
||||
let signature = signing_key.sign(&message);
|
||||
|
||||
Self {
|
||||
hash,
|
||||
timestamp_ns,
|
||||
signature: Some(signature.to_bytes()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the proof against a witness
|
||||
pub fn verify(&self, witness: &WitnessLog) -> bool {
|
||||
verify_witness_hash(witness, &self.hash)
|
||||
}
|
||||
|
||||
/// Verify the signature
|
||||
#[cfg(feature = "sign")]
|
||||
pub fn verify_signature(&self, pubkey: &[u8; 32]) -> bool {
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
|
||||
let Some(sig_bytes) = self.signature else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let Ok(verifying_key) = VerifyingKey::from_bytes(pubkey) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let signature = Signature::from_bytes(&sig_bytes);
|
||||
|
||||
let mut message = [0u8; 40];
|
||||
message[..32].copy_from_slice(&self.hash);
|
||||
message[32..40].copy_from_slice(&self.timestamp_ns.to_le_bytes());
|
||||
|
||||
verifying_key.verify(&message, &signature).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BackendKind, GateDecision};

    /// Build a full-run witness whose only varying field is the cycle count.
    fn sample_witness(cycles: u32) -> WitnessLog {
        WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            cycles,
            50000,
            GateDecision::RanFull,
        )
    }

    #[test]
    fn test_witness_hash_deterministic() {
        // Hashing the same witness twice must give the same digest.
        let witness = sample_witness(1000);
        assert_eq!(compute_witness_hash(&witness), compute_witness_hash(&witness));
    }

    #[test]
    fn test_witness_hash_changes() {
        // Witnesses differing only in cycle count must hash differently.
        let first = sample_witness(1000);
        let second = sample_witness(1001);
        assert_ne!(compute_witness_hash(&first), compute_witness_hash(&second));
    }

    #[test]
    fn test_verify_witness_hash() {
        let witness = WitnessLog::empty();
        let good = compute_witness_hash(&witness);

        assert!(verify_witness_hash(&witness, &good));
        assert!(!verify_witness_hash(&witness, &[0u8; 32]));
    }

    #[test]
    fn test_chain_hash() {
        let chain: Vec<WitnessLog> = (0..5)
            .map(|i| {
                WitnessLog::new(
                    [i as u8; 32],
                    [0u8; 32],
                    BackendKind::NativeSim,
                    i * 100,
                    i * 1000,
                    GateDecision::RanFull,
                )
            })
            .collect();

        // Chain hashing must be deterministic over the same sequence.
        assert_eq!(compute_chain_hash(&chain), compute_chain_hash(&chain));
    }

    #[test]
    fn test_witness_proof() {
        let witness = WitnessLog::empty();
        let proof = WitnessProof::new(&witness);

        assert!(proof.verify(&witness));
        assert!(proof.timestamp_ns > 0);
    }
}
|
||||
311
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/log.rs
vendored
Normal file
311
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/log.rs
vendored
Normal file
@@ -0,0 +1,311 @@
|
||||
//! Witness log builder and utilities
|
||||
|
||||
use crate::types::{BackendKind, GateDecision, WitnessLog};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Builder for creating witness logs
///
/// Captures `Instant::now()` at construction; `build` records the time
/// elapsed since then as the witness latency.
pub struct WitnessBuilder {
    // Model weights hash (all zeros until set via `model_hash`).
    model_hash: [u8; 32],
    // Quantization config hash (all zeros until set via `quant_hash`).
    quant_hash: [u8; 32],
    // Backend that executed (or will report) the inference.
    backend: BackendKind,
    // Start of the measured latency interval, taken in `new`.
    start_time: Instant,
    // Compute cycles consumed (0 until set via `cycles`).
    cycles: u32,
    // Gate decision for this inference (defaults to `RanFull`).
    gate_decision: GateDecision,
}
|
||||
|
||||
impl WitnessBuilder {
|
||||
/// Start building a new witness
|
||||
pub fn new(backend: BackendKind) -> Self {
|
||||
Self {
|
||||
model_hash: [0u8; 32],
|
||||
quant_hash: [0u8; 32],
|
||||
backend,
|
||||
start_time: Instant::now(),
|
||||
cycles: 0,
|
||||
gate_decision: GateDecision::RanFull,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set model hash
|
||||
pub fn model_hash(mut self, hash: [u8; 32]) -> Self {
|
||||
self.model_hash = hash;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set quantization hash
|
||||
pub fn quant_hash(mut self, hash: [u8; 32]) -> Self {
|
||||
self.quant_hash = hash;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set compute cycles
|
||||
pub fn cycles(mut self, cycles: u32) -> Self {
|
||||
self.cycles = cycles;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set gate decision
|
||||
pub fn gate_decision(mut self, decision: GateDecision) -> Self {
|
||||
self.gate_decision = decision;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the witness log
|
||||
pub fn build(self) -> WitnessLog {
|
||||
let latency_ns = self.start_time.elapsed().as_nanos() as u32;
|
||||
|
||||
WitnessLog::new(
|
||||
self.model_hash,
|
||||
self.quant_hash,
|
||||
self.backend,
|
||||
self.cycles,
|
||||
latency_ns,
|
||||
self.gate_decision,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl WitnessLog {
|
||||
/// Convert to compact bytes for storage
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let mut bytes = Vec::with_capacity(80);
|
||||
|
||||
bytes.extend_from_slice(&self.model_hash);
|
||||
bytes.extend_from_slice(&self.quant_hash);
|
||||
bytes.push(self.backend as u8);
|
||||
bytes.extend_from_slice(&self.cycles.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.latency_ns.to_le_bytes());
|
||||
|
||||
// Encode gate decision
|
||||
match self.gate_decision {
|
||||
GateDecision::RanFull => {
|
||||
bytes.push(0);
|
||||
bytes.push(0);
|
||||
}
|
||||
GateDecision::EarlyExit { layer } => {
|
||||
bytes.push(1);
|
||||
bytes.push(layer);
|
||||
}
|
||||
GateDecision::Skipped { reason } => {
|
||||
bytes.push(2);
|
||||
bytes.push(reason as u8);
|
||||
}
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Parse from bytes
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
if bytes.len() < 75 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let model_hash: [u8; 32] = bytes[0..32].try_into().ok()?;
|
||||
let quant_hash: [u8; 32] = bytes[32..64].try_into().ok()?;
|
||||
|
||||
let backend = match bytes[64] {
|
||||
0 => BackendKind::FpgaPcie,
|
||||
1 => BackendKind::FpgaDaemon,
|
||||
2 => BackendKind::WasmSim,
|
||||
3 => BackendKind::NativeSim,
|
||||
_ => BackendKind::NativeSim,
|
||||
};
|
||||
|
||||
let cycles = u32::from_le_bytes(bytes[65..69].try_into().ok()?);
|
||||
let latency_ns = u32::from_le_bytes(bytes[69..73].try_into().ok()?);
|
||||
|
||||
let gate_decision = match bytes[73] {
|
||||
0 => GateDecision::RanFull,
|
||||
1 => GateDecision::EarlyExit { layer: bytes[74] },
|
||||
2 => GateDecision::Skipped {
|
||||
reason: match bytes[74] {
|
||||
0 => crate::types::SkipReason::LowCoherence,
|
||||
1 => crate::types::SkipReason::PolicyDenied,
|
||||
_ => crate::types::SkipReason::BudgetExceeded,
|
||||
},
|
||||
},
|
||||
_ => GateDecision::RanFull,
|
||||
};
|
||||
|
||||
Some(Self {
|
||||
model_hash,
|
||||
quant_hash,
|
||||
backend,
|
||||
cycles,
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get latency in microseconds
|
||||
pub fn latency_us(&self) -> f64 {
|
||||
self.latency_ns as f64 / 1000.0
|
||||
}
|
||||
|
||||
/// Get latency in milliseconds
|
||||
pub fn latency_ms(&self) -> f64 {
|
||||
self.latency_ns as f64 / 1_000_000.0
|
||||
}
|
||||
|
||||
/// Check if this was a successful full inference
|
||||
pub fn is_full_inference(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::RanFull)
|
||||
}
|
||||
|
||||
/// Check if this was an early exit
|
||||
pub fn is_early_exit(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::EarlyExit { .. })
|
||||
}
|
||||
|
||||
/// Check if inference was skipped
|
||||
pub fn is_skipped(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::Skipped { .. })
|
||||
}
|
||||
}
|
||||
|
||||
/// Witness log aggregator for collecting statistics
///
/// Feed witnesses via `add`; statistics (averages, rates, standard
/// deviation) are derived on demand from the running sums.
#[derive(Debug, Default)]
pub struct WitnessAggregator {
    /// Total inferences
    pub count: u64,
    /// Total cycles
    pub total_cycles: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
    /// Full inference count
    pub full_count: u64,
    /// Early exit count
    pub early_exit_count: u64,
    /// Skipped count
    pub skipped_count: u64,
    /// Sum of squares for variance calculation
    /// (u128 so per-witness `latency_ns^2` cannot overflow the accumulator)
    latency_sq_sum: u128,
}
|
||||
|
||||
impl WitnessAggregator {
|
||||
/// Create a new aggregator
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Add a witness to the aggregate
|
||||
pub fn add(&mut self, witness: &WitnessLog) {
|
||||
self.count += 1;
|
||||
self.total_cycles += witness.cycles as u64;
|
||||
self.total_latency_ns += witness.latency_ns as u64;
|
||||
self.latency_sq_sum += (witness.latency_ns as u128).pow(2);
|
||||
|
||||
match witness.gate_decision {
|
||||
GateDecision::RanFull => self.full_count += 1,
|
||||
GateDecision::EarlyExit { .. } => self.early_exit_count += 1,
|
||||
GateDecision::Skipped { .. } => self.skipped_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average latency (ns)
|
||||
pub fn avg_latency_ns(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_latency_ns as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average cycles
|
||||
pub fn avg_cycles(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_cycles as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get latency standard deviation (ns)
|
||||
pub fn latency_std_ns(&self) -> f64 {
|
||||
if self.count <= 1 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = self.avg_latency_ns();
|
||||
let variance = (self.latency_sq_sum as f64 / self.count as f64) - (mean * mean);
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
/// Get early exit rate
|
||||
pub fn early_exit_rate(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.early_exit_count as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get skip rate
|
||||
pub fn skip_rate(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.skipped_count as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_witness_builder() {
        let log = WitnessBuilder::new(BackendKind::NativeSim)
            .model_hash([1u8; 32])
            .quant_hash([2u8; 32])
            .cycles(1000)
            .gate_decision(GateDecision::RanFull)
            .build();

        assert_eq!(log.model_hash, [1u8; 32]);
        assert_eq!(log.backend, BackendKind::NativeSim);
        assert_eq!(log.cycles, 1000);
    }

    #[test]
    fn test_witness_bytes_roundtrip() {
        let original = WitnessLog::new(
            [0x42u8; 32],
            [0x24u8; 32],
            BackendKind::FpgaDaemon,
            5000,
            100_000,
            GateDecision::EarlyExit { layer: 4 },
        );

        // Encode then decode; every encoded field must survive.
        let decoded = WitnessLog::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original.model_hash, decoded.model_hash);
        assert_eq!(original.quant_hash, decoded.quant_hash);
        assert_eq!(original.backend, decoded.backend);
        assert_eq!(original.cycles, decoded.cycles);
        assert_eq!(original.latency_ns, decoded.latency_ns);
    }

    #[test]
    fn test_witness_aggregator() {
        let mut stats = WitnessAggregator::new();

        for i in 0..10u32 {
            let mut w = WitnessLog::empty();
            w.latency_ns = 1000 * (i + 1);
            w.cycles = 100 * (i + 1);
            // First three are early exits; the rest keep the default RanFull.
            if i < 3 {
                w.gate_decision = GateDecision::EarlyExit { layer: 2 };
            }
            stats.add(&w);
        }

        assert_eq!(stats.count, 10);
        assert_eq!(stats.early_exit_count, 3);
        assert!((stats.early_exit_rate() - 0.3).abs() < 0.01);
    }
}
|
||||
12
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/mod.rs
vendored
Normal file
12
vendor/ruvector/crates/ruvector-fpga-transformer/src/witness/mod.rs
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Witness logging for audit trails and ReasoningBank integration
|
||||
//!
|
||||
//! Every inference produces a small witness bundle that records
|
||||
//! what happened and enables verification and replay.
|
||||
|
||||
pub mod hash;
|
||||
pub mod log;
|
||||
|
||||
// Re-export WitnessLog from types as the canonical location
|
||||
pub use crate::types::WitnessLog;
|
||||
pub use hash::{compute_witness_hash, verify_witness_hash};
|
||||
pub use log::{WitnessAggregator, WitnessBuilder};
|
||||
Reference in New Issue
Block a user