Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
94
crates/ruvector-fpga-transformer/Cargo.toml
Normal file
94
crates/ruvector-fpga-transformer/Cargo.toml
Normal file
@@ -0,0 +1,94 @@
|
||||
[package]
|
||||
name = "ruvector-fpga-transformer"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.77"
|
||||
authors = ["RuVector Team"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "FPGA Transformer backend with deterministic latency, quantization-first design, and coherence gating"
|
||||
repository = "https://github.com/ruvnet/ruvector"
|
||||
keywords = ["fpga", "transformer", "inference", "quantization", "low-latency"]
|
||||
categories = ["algorithms", "embedded", "hardware-support"]
|
||||
readme = "README.md"
|
||||
|
||||
[lib]
|
||||
crate-type = ["rlib"]
|
||||
|
||||
[features]
|
||||
default = ["daemon", "native_sim", "witness"]
|
||||
|
||||
# Backend selection
|
||||
daemon = []
|
||||
native_sim = []
|
||||
pcie = ["memmap2"]
|
||||
|
||||
# WASM support
|
||||
wasm = ["wasm-bindgen", "getrandom/js", "js-sys"]
|
||||
|
||||
# Verification
|
||||
strict_verify = []
|
||||
witness = []
|
||||
|
||||
# Inference options
|
||||
topk_only = []
|
||||
lut_softmax = []
|
||||
pwl_softmax = []
|
||||
|
||||
# Development
|
||||
trace = []
|
||||
|
||||
[dependencies]
|
||||
# Core
|
||||
thiserror = "2.0"
|
||||
anyhow = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
# Crypto for artifact verification
|
||||
sha2 = "0.10"
|
||||
ed25519-dalek = { version = "2.1", features = ["rand_core"] }
|
||||
rand_core = { version = "0.6", features = ["getrandom"] }
|
||||
rand = "0.8"
|
||||
hex = "0.4"
|
||||
serde_bytes = "0.11"
|
||||
|
||||
# Async for daemon communication
|
||||
tokio = { version = "1.41", features = ["io-util", "net", "sync", "rt"], optional = true }
|
||||
|
||||
# Memory mapping for PCIe
|
||||
memmap2 = { version = "0.9", optional = true }
|
||||
|
||||
# WASM bindings
|
||||
wasm-bindgen = { version = "0.2", optional = true }
|
||||
js-sys = { version = "0.3", optional = true }
|
||||
getrandom = { version = "0.2", optional = true }
|
||||
|
||||
# Tracing for development
|
||||
tracing = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
proptest = "1.5"
|
||||
rand = "0.8"
|
||||
tokio = { version = "1.41", features = ["rt-multi-thread", "macros"] }
|
||||
|
||||
[[bench]]
|
||||
name = "latency"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "correctness"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "gating"
|
||||
harness = false
|
||||
|
||||
[[example]]
|
||||
name = "basic_inference"
|
||||
path = "examples/basic_inference.rs"
|
||||
|
||||
[[example]]
|
||||
name = "daemon_client"
|
||||
path = "examples/daemon_client.rs"
|
||||
required-features = ["daemon"]
|
||||
301
crates/ruvector-fpga-transformer/README.md
Normal file
301
crates/ruvector-fpga-transformer/README.md
Normal file
@@ -0,0 +1,301 @@
|
||||
# FPGA Transformer
|
||||
|
||||
**Run AI models on specialized hardware with predictable, ultra-low latency.**
|
||||
|
||||
FPGA Transformer is a Rust library that lets you run transformer neural networks (like those used in ChatGPT, code completion, and other AI applications) on FPGA hardware instead of GPUs. This gives you consistent, predictable response times - essential for real-time applications.
|
||||
|
||||
## Why Use This?
|
||||
|
||||
| Problem | Solution |
|
||||
|---------|----------|
|
||||
| GPU inference has unpredictable latency spikes | FPGAs provide deterministic, bounded timing |
|
||||
| Cloud AI is too slow for edge devices | Run models locally on low-power FPGAs |
|
||||
| Need to verify AI didn't hallucinate | Witness logging proves what computation ran |
|
||||
| Want to skip unnecessary computation | Coherence gating exits early when confident |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Add to your Cargo.toml
|
||||
cargo add ruvector-fpga-transformer
|
||||
```
|
||||
|
||||
```rust
|
||||
use ruvector_fpga_transformer::prelude::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Create an engine (uses CPU simulator by default)
|
||||
let mut engine = Engine::native_sim();
|
||||
|
||||
// Load your model
|
||||
let model_bytes = std::fs::read("model.rvt")?;
|
||||
let model_id = engine.load_artifact(&model_bytes)?;
|
||||
|
||||
// Prepare input tokens
|
||||
let tokens: Vec<u16> = vec![1, 2, 3, 4, 5]; // Your tokenized input
|
||||
let mask = vec![1u8; tokens.len()]; // Attention mask
|
||||
|
||||
// Run inference
|
||||
let request = InferenceRequest::new(
|
||||
model_id,
|
||||
FixedShape::micro(), // 32 seq, 64 dim, 4096 vocab
|
||||
&tokens,
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
|
||||
let result = engine.infer(request)?;
|
||||
|
||||
// Get predictions
|
||||
println!("Top prediction: token {}", result.topk.unwrap()[0].0);
|
||||
println!("Latency: {}ns", result.witness.latency_ns);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Core Capabilities
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| **Deterministic Latency** | Fixed execution time - no surprise slowdowns |
|
||||
| **Quantization-First** | INT4/INT8 math for 4-8x memory savings |
|
||||
| **Zero Allocation Hot Path** | No garbage collection pauses during inference |
|
||||
| **Early Exit** | Stop computation when the model is already confident |
|
||||
| **Witness Logging** | Cryptographic proof of what ran and when |
|
||||
|
||||
### Supported Backends
|
||||
|
||||
| Backend | Use Case | Feature Flag |
|
||||
|---------|----------|--------------|
|
||||
| **NativeSim** | Development & testing on any CPU | `native_sim` |
|
||||
| **WasmSim** | Run in web browsers | `wasm` |
|
||||
| **FpgaDaemon** | Connect to FPGA via network | `daemon` |
|
||||
| **FpgaPcie** | Direct PCIe access (fastest) | `pcie` |
|
||||
|
||||
## Model Shapes
|
||||
|
||||
Pre-defined configurations for common use cases:
|
||||
|
||||
| Shape | Sequence | Dimensions | Vocab | Use Case |
|
||||
|-------|----------|------------|-------|----------|
|
||||
| `micro()` | 32 | 64 | 4,096 | Testing, tiny models |
|
||||
| `small()` | 128 | 256 | 32,768 | Edge devices |
|
||||
| `medium()` | 512 | 512 | 50,257 | Standard inference |
|
||||
| `large()` | 2,048 | 1,024 | 50,257 | High-quality output |
|
||||
|
||||
```rust
|
||||
// Use predefined shapes
|
||||
let shape = FixedShape::small();
|
||||
|
||||
// Or create custom
|
||||
let custom = FixedShape {
|
||||
seq_len: 256,
|
||||
d_model: 384,
|
||||
vocab: 16000,
|
||||
};
|
||||
```
|
||||
|
||||
## Coherence Gating
|
||||
|
||||
Skip unnecessary computation when the model is already confident:
|
||||
|
||||
```rust
|
||||
use ruvector_fpga_transformer::gating::{GatingConfig, PolicyGate};
|
||||
|
||||
// Configure early exit behavior
|
||||
let config = GatingConfig {
|
||||
min_coherence: 0.7, // Require 70% confidence to exit early
|
||||
max_compute_class: 3, // Allow up to 3 layers before forcing exit
|
||||
allow_writes: true, // Allow writes if confidence is high
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let gate = PolicyGate::new(config);
|
||||
```
|
||||
|
||||
**Gate Decisions:**
|
||||
- `RanFull` - Model ran all layers
|
||||
- `EarlyExit { layer }` - Exited early at specified layer
|
||||
- `Skipped { reason }` - Computation was blocked
|
||||
|
||||
## Quantization
|
||||
|
||||
Convert floating-point models to efficient integer math:
|
||||
|
||||
| Format | Bits | Memory Savings | Use Case |
|
||||
|--------|------|----------------|----------|
|
||||
| INT8 | 8 | 4x | General purpose |
|
||||
| INT4 | 4 | 8x | Memory-constrained |
|
||||
| Binary | 1 | 32x | Ultra-compact |
|
||||
|
||||
```rust
|
||||
// INT8 quantization (recommended)
|
||||
let quant = QuantSpec::int8();
|
||||
|
||||
// INT4 for memory savings
|
||||
let quant = QuantSpec::int4();
|
||||
|
||||
// Custom quantization
|
||||
let quant = QuantSpec {
|
||||
bits: 8,
|
||||
scale: 127.0,
|
||||
zero_point: 0,
|
||||
symmetric: true,
|
||||
};
|
||||
```
|
||||
|
||||
## Witness Logging
|
||||
|
||||
Every inference produces a cryptographic witness proving:
|
||||
- Which model ran (by hash)
|
||||
- What quantization was used
|
||||
- Which backend executed it
|
||||
- Exact cycle count and latency
|
||||
- Gate decision made
|
||||
|
||||
```rust
|
||||
let result = engine.infer(request)?;
|
||||
let witness = &result.witness;
|
||||
|
||||
println!("Model hash: {}", hex::encode(&witness.model_hash));
|
||||
println!("Backend: {:?}", witness.backend);
|
||||
println!("Cycles: {}", witness.cycles);
|
||||
println!("Decision: {:?}", witness.gate_decision);
|
||||
|
||||
// Verify witness authenticity
|
||||
assert!(witness.verify());
|
||||
```
|
||||
|
||||
## Backend Selection
|
||||
|
||||
### Native Simulator (Default)
|
||||
Best for development and testing:
|
||||
|
||||
```rust
|
||||
let engine = Engine::native_sim();
|
||||
```
|
||||
|
||||
### FPGA Daemon
|
||||
Connect to a remote FPGA over network:
|
||||
|
||||
```rust
|
||||
use ruvector_fpga_transformer::backend::fpga_daemon::{FpgaDaemonBackend, DaemonConnection};
|
||||
|
||||
let backend = FpgaDaemonBackend::with_connection(
|
||||
DaemonConnection::tcp("192.168.1.100:9000"),
|
||||
Default::default(),
|
||||
);
|
||||
```
|
||||
|
||||
### FPGA PCIe (Fastest)
|
||||
Direct hardware access:
|
||||
|
||||
```rust
|
||||
use ruvector_fpga_transformer::backend::fpga_pcie::{FpgaPcieBackend, PcieConfig};
|
||||
|
||||
let config = PcieConfig {
|
||||
device_path: "/dev/ruvector0".into(),
|
||||
ring_slots: 16,
|
||||
dma_timeout_ms: 100,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let backend = FpgaPcieBackend::new(config)?;
|
||||
```
|
||||
|
||||
## Feature Flags
|
||||
|
||||
Enable only what you need:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-fpga-transformer = { version = "0.1", default-features = false, features = ["native_sim"] }
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `native_sim` | CPU-based simulator |
|
||||
| `daemon` | Network daemon client |
|
||||
| `pcie` | Direct PCIe access |
|
||||
| `wasm` | WebAssembly support |
|
||||
| `witness` | Witness logging |
|
||||
| `strict_verify` | Extra verification checks |
|
||||
| `lut_softmax` | LUT-based softmax (faster) |
|
||||
| `trace` | Debug tracing |
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Use appropriate shapes** - Don't use `large()` for simple tasks
|
||||
2. **Enable early exit** - Set reasonable `min_coherence` threshold
|
||||
3. **Batch requests** - Reuse loaded models across multiple inferences
|
||||
4. **Use topk_only** - Return only top predictions, not full vocabulary
|
||||
|
||||
```rust
|
||||
// Efficient configuration
|
||||
let config = DaemonConfig {
|
||||
topk_only: true, // Only return top-K predictions
|
||||
topk: 10, // Return top 10
|
||||
retries: 3, // Retry on transient failures
|
||||
..Default::default()
|
||||
};
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ Engine │
|
||||
│ (public) │
|
||||
└──────┬──────┘
|
||||
│
|
||||
┌────────────┼────────────┐
|
||||
│ │ │
|
||||
┌─────▼─────┐ ┌────▼────┐ ┌─────▼─────┐
|
||||
│ Coherence │ │ Backend │ │ Witness │
|
||||
│ Gate │ │ Trait │ │ Log │
|
||||
└───────────┘ └────┬────┘ └───────────┘
|
||||
│
|
||||
┌────────┬────────┼────────┬────────┐
|
||||
│ │ │ │ │
|
||||
┌───▼───┐┌───▼───┐┌───▼───┐┌───▼───┐
|
||||
│Native ││ WASM ││Daemon ││ PCIe │
|
||||
│ Sim ││ Sim ││ ││ │
|
||||
└───────┘└───────┘└───────┘└───────┘
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the [examples](./examples/) directory:
|
||||
- `basic_inference.rs` - Simple inference example
|
||||
- `daemon_client.rs` - Connect to FPGA daemon
|
||||
|
||||
Run examples:
|
||||
```bash
|
||||
cargo run --example basic_inference
|
||||
cargo run --example daemon_client --features daemon
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
cargo test --features native_sim
|
||||
|
||||
# Run with tracing
|
||||
RUST_LOG=debug cargo test --features "native_sim trace"
|
||||
|
||||
# Run benchmarks
|
||||
cargo bench --features native_sim
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions welcome! Please read the [contributing guidelines](../../CONTRIBUTING.md) first.
|
||||
157
crates/ruvector-fpga-transformer/benches/correctness.rs
Normal file
157
crates/ruvector-fpga-transformer/benches/correctness.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
//! Correctness and determinism benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruvector_fpga_transformer::{
|
||||
artifact::{Manifest, ModelArtifact},
|
||||
backend::native_sim::NativeSimBackend,
|
||||
backend::TransformerBackend,
|
||||
gating::DefaultCoherenceGate,
|
||||
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
|
||||
};
|
||||
|
||||
fn create_test_artifact() -> ModelArtifact {
|
||||
let shape = FixedShape::micro();
|
||||
let manifest = Manifest {
|
||||
name: "determinism_test".into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
let embedding_size = shape.vocab as usize * shape.d_model as usize;
|
||||
let weights: Vec<u8> = (0..embedding_size).map(|i| (i % 256) as u8).collect();
|
||||
|
||||
ModelArtifact::new(manifest, weights, None, None, vec![])
|
||||
}
|
||||
|
||||
fn bench_determinism(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let artifact = create_test_artifact();
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
let shape = FixedShape::micro();
|
||||
|
||||
let tokens: Vec<u16> = (0..shape.seq_len)
|
||||
.map(|i| (i * 7) % shape.vocab as u16)
|
||||
.collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
c.bench_function("determinism_check_1000", |b| {
|
||||
b.iter(|| {
|
||||
let mut first_hash: Option<u64> = None;
|
||||
|
||||
for _ in 0..1000 {
|
||||
let req = InferenceRequest::new(
|
||||
model_id,
|
||||
shape,
|
||||
black_box(&tokens),
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
let result = backend.infer(req).unwrap();
|
||||
|
||||
// Hash the logits
|
||||
let hash = result
|
||||
.logits_q
|
||||
.iter()
|
||||
.fold(0u64, |acc, &v| acc.wrapping_mul(31).wrapping_add(v as u64));
|
||||
|
||||
match first_hash {
|
||||
None => first_hash = Some(hash),
|
||||
Some(expected) => assert_eq!(hash, expected, "Non-deterministic output"),
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_golden_vectors(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let artifact = create_test_artifact();
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
let shape = FixedShape::micro();
|
||||
|
||||
// Create golden vectors
|
||||
let test_inputs: Vec<Vec<u16>> = (0..128)
|
||||
.map(|seed| {
|
||||
(0..shape.seq_len)
|
||||
.map(|i| ((i as usize * seed + 1) % shape.vocab as usize) as u16)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
// Compute expected outputs
|
||||
let expected: Vec<Vec<i16>> = test_inputs
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let req = InferenceRequest::new(model_id, shape, tokens, &mask, GateHint::allow_all());
|
||||
backend.infer(req).unwrap().logits_q
|
||||
})
|
||||
.collect();
|
||||
|
||||
c.bench_function("golden_vector_validation", |b| {
|
||||
b.iter(|| {
|
||||
for (tokens, exp) in test_inputs.iter().zip(&expected) {
|
||||
let req = InferenceRequest::new(
|
||||
model_id,
|
||||
shape,
|
||||
black_box(tokens),
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
let result = backend.infer(req).unwrap();
|
||||
|
||||
// Compute max abs error
|
||||
let max_err: i32 = result
|
||||
.logits_q
|
||||
.iter()
|
||||
.zip(exp)
|
||||
.map(|(&a, &b)| (a as i32 - b as i32).abs())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
assert_eq!(max_err, 0, "Golden vector mismatch");
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_quantization_accuracy(c: &mut Criterion) {
|
||||
use ruvector_fpga_transformer::quant::qformat::{quantize_symmetric_i8, QuantizedMatrix};
|
||||
|
||||
c.bench_function("quantize_matrix_256x256", |b| {
|
||||
let data: Vec<f32> = (0..256 * 256).map(|i| (i as f32 * 0.001).sin()).collect();
|
||||
|
||||
b.iter(|| {
|
||||
let matrix = QuantizedMatrix::from_f32(black_box(&data), 256, 256);
|
||||
let dequant = matrix.to_f32();
|
||||
|
||||
// Check reconstruction error
|
||||
let max_err: f32 = data
|
||||
.iter()
|
||||
.zip(&dequant)
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
assert!(max_err < 0.1, "Quantization error too high: {}", max_err);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_determinism,
|
||||
bench_golden_vectors,
|
||||
bench_quantization_accuracy
|
||||
);
|
||||
criterion_main!(benches);
|
||||
154
crates/ruvector-fpga-transformer/benches/gating.rs
Normal file
154
crates/ruvector-fpga-transformer/benches/gating.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! Gating subsystem benchmarks
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruvector_fpga_transformer::{
|
||||
artifact::{Manifest, ModelArtifact},
|
||||
backend::native_sim::NativeSimBackend,
|
||||
backend::TransformerBackend,
|
||||
gating::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate},
|
||||
types::{ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest, QuantSpec},
|
||||
};
|
||||
|
||||
fn bench_skip_rate_distribution(c: &mut Criterion) {
|
||||
let gate = DefaultCoherenceGate::new();
|
||||
|
||||
// Generate synthetic coherence distribution
|
||||
let coherence_values: Vec<i16> = (-500..500).collect();
|
||||
|
||||
c.bench_function("skip_rate_uniform_distribution", |b| {
|
||||
b.iter(|| {
|
||||
let mut skipped = 0u32;
|
||||
let mut ran = 0u32;
|
||||
|
||||
for &coherence in &coherence_values {
|
||||
let hint = GateHint::new(coherence, false, ComputeClass::Deliberative);
|
||||
match gate.preflight(black_box(&hint)) {
|
||||
GateDecision::Skipped { .. } => skipped += 1,
|
||||
_ => ran += 1,
|
||||
}
|
||||
}
|
||||
|
||||
(skipped, ran)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_early_exit_histogram(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let shape = FixedShape::micro();
|
||||
let manifest = Manifest {
|
||||
name: "early_exit_test".into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
let embedding_size = shape.vocab as usize * shape.d_model as usize;
|
||||
let artifact = ModelArtifact::new(manifest, vec![0u8; embedding_size], None, None, vec![]);
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
|
||||
let tokens: Vec<u16> = (0..shape.seq_len).collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
// Test with varying coherence levels
|
||||
let coherence_levels: Vec<i16> = vec![-500, -200, 0, 200, 500, 1000, 2000];
|
||||
|
||||
let mut group = c.benchmark_group("early_exit_by_coherence");
|
||||
|
||||
for coherence in coherence_levels {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("coherence", coherence),
|
||||
&coherence,
|
||||
|b, &coherence| {
|
||||
let hint = GateHint::new(coherence, false, ComputeClass::Deliberative);
|
||||
|
||||
b.iter(|| {
|
||||
let req =
|
||||
InferenceRequest::new(model_id, shape, black_box(&tokens), &mask, hint);
|
||||
let result = backend.infer(req).unwrap();
|
||||
result.witness.gate_decision
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_checkpoint_overhead(c: &mut Criterion) {
|
||||
let configs = [
|
||||
("default", CoherenceConfig::default()),
|
||||
("strict", CoherenceConfig::strict()),
|
||||
("permissive", CoherenceConfig::permissive()),
|
||||
];
|
||||
|
||||
let mut group = c.benchmark_group("checkpoint_overhead");
|
||||
|
||||
for (name, config) in configs {
|
||||
let gate = DefaultCoherenceGate::with_config(config);
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("config", name), &gate, |b, gate| {
|
||||
b.iter(|| {
|
||||
let mut decision = None;
|
||||
for layer in 0u8..8 {
|
||||
let signal = (layer as i16) * 150;
|
||||
if let Some(d) = gate.checkpoint(black_box(layer), black_box(signal)) {
|
||||
decision = Some(d);
|
||||
break;
|
||||
}
|
||||
}
|
||||
decision
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_mincut_gating(c: &mut Criterion) {
|
||||
use ruvector_fpga_transformer::gating::coherence_gate::MincutCoherenceGate;
|
||||
|
||||
let config = CoherenceConfig::default();
|
||||
let gate = MincutCoherenceGate::new(config, 50, 200);
|
||||
|
||||
let hints = [
|
||||
(
|
||||
"high_lambda",
|
||||
GateHint::new(500, false, ComputeClass::Deliberative),
|
||||
),
|
||||
(
|
||||
"low_lambda",
|
||||
GateHint::new(100, false, ComputeClass::Deliberative),
|
||||
),
|
||||
(
|
||||
"boundary_crossed",
|
||||
GateHint::new(300, true, ComputeClass::Deliberative),
|
||||
),
|
||||
];
|
||||
|
||||
let mut group = c.benchmark_group("mincut_gating");
|
||||
|
||||
for (name, hint) in hints {
|
||||
group.bench_with_input(BenchmarkId::new("preflight", name), &hint, |b, hint| {
|
||||
b.iter(|| gate.preflight(black_box(hint)))
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_skip_rate_distribution,
|
||||
bench_early_exit_histogram,
|
||||
bench_checkpoint_overhead,
|
||||
bench_mincut_gating
|
||||
);
|
||||
criterion_main!(benches);
|
||||
143
crates/ruvector-fpga-transformer/benches/latency.rs
Normal file
143
crates/ruvector-fpga-transformer/benches/latency.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
//! Latency benchmarks for FPGA Transformer
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruvector_fpga_transformer::{
|
||||
artifact::{Manifest, ModelArtifact},
|
||||
backend::native_sim::NativeSimBackend,
|
||||
backend::TransformerBackend,
|
||||
gating::DefaultCoherenceGate,
|
||||
types::{FixedShape, GateHint, InferenceRequest, ModelId, QuantSpec},
|
||||
};
|
||||
|
||||
fn create_test_artifact(shape: FixedShape) -> ModelArtifact {
|
||||
let manifest = Manifest {
|
||||
name: "bench_model".into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
// Create minimal weights
|
||||
let embedding_size = shape.vocab as usize * shape.d_model as usize;
|
||||
let weights = vec![0u8; embedding_size];
|
||||
|
||||
ModelArtifact::new(manifest, weights, None, None, vec![])
|
||||
}
|
||||
|
||||
fn bench_inference(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let shape = FixedShape::micro();
|
||||
let artifact = create_test_artifact(shape);
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
|
||||
let tokens: Vec<u16> = (0..shape.seq_len).collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
c.bench_function("native_sim_micro_inference", |b| {
|
||||
b.iter(|| {
|
||||
let req = InferenceRequest::new(
|
||||
model_id,
|
||||
shape,
|
||||
black_box(&tokens),
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
backend.infer(req).unwrap()
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_inference_shapes(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
|
||||
let shapes = [
|
||||
("micro", FixedShape::micro()),
|
||||
("small", FixedShape::small()),
|
||||
("baseline", FixedShape::baseline()),
|
||||
];
|
||||
|
||||
let mut group = c.benchmark_group("inference_by_shape");
|
||||
|
||||
for (name, shape) in shapes {
|
||||
let backend = NativeSimBackend::new(gate.clone());
|
||||
let artifact = create_test_artifact(shape);
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
|
||||
let tokens: Vec<u16> = (0..shape.seq_len).collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("native_sim", name), &shape, |b, &shape| {
|
||||
b.iter(|| {
|
||||
let req = InferenceRequest::new(
|
||||
model_id,
|
||||
shape,
|
||||
black_box(&tokens),
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
backend.infer(req).unwrap()
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_load_unload(c: &mut Criterion) {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let artifact = create_test_artifact(FixedShape::micro());
|
||||
|
||||
c.bench_function("model_load", |b| {
|
||||
b.iter(|| {
|
||||
let id = backend.load(black_box(&artifact)).unwrap();
|
||||
backend.unload(id).unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_gating(c: &mut Criterion) {
|
||||
use ruvector_fpga_transformer::gating::{CoherenceConfig, CoherenceGate};
|
||||
|
||||
let gate = DefaultCoherenceGate::with_config(CoherenceConfig::default());
|
||||
|
||||
let hints = [
|
||||
("allow_all", GateHint::allow_all()),
|
||||
("reflex_only", GateHint::reflex_only()),
|
||||
(
|
||||
"low_coherence",
|
||||
GateHint::new(
|
||||
-500,
|
||||
true,
|
||||
ruvector_fpga_transformer::types::ComputeClass::Deliberative,
|
||||
),
|
||||
),
|
||||
];
|
||||
|
||||
let mut group = c.benchmark_group("gating_preflight");
|
||||
|
||||
for (name, hint) in hints {
|
||||
group.bench_with_input(BenchmarkId::new("preflight", name), &hint, |b, hint| {
|
||||
b.iter(|| gate.preflight(black_box(hint)))
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_inference,
|
||||
bench_inference_shapes,
|
||||
bench_load_unload,
|
||||
bench_gating
|
||||
);
|
||||
criterion_main!(benches);
|
||||
123
crates/ruvector-fpga-transformer/examples/basic_inference.rs
Normal file
123
crates/ruvector-fpga-transformer/examples/basic_inference.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
//! Basic inference example
|
||||
//!
|
||||
//! Demonstrates loading a model and running inference with the native simulator.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruvector_fpga_transformer::{
|
||||
artifact::{Manifest, ModelArtifact},
|
||||
backend::native_sim::NativeSimBackend,
|
||||
gating::DefaultCoherenceGate,
|
||||
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
|
||||
Engine,
|
||||
};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
println!("FPGA Transformer - Basic Inference Example");
|
||||
println!("==========================================\n");
|
||||
|
||||
// Create a micro-sized model for demonstration
|
||||
let shape = FixedShape::micro();
|
||||
println!("Model shape: {:?}", shape);
|
||||
|
||||
// Create manifest
|
||||
let manifest = Manifest {
|
||||
name: "demo_reflex_transformer".into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
// Create minimal weights (random for demo)
|
||||
let embedding_size = shape.vocab as usize * shape.d_model as usize;
|
||||
let weights: Vec<u8> = (0..embedding_size)
|
||||
.map(|i| ((i * 7 + 13) % 256) as u8)
|
||||
.collect();
|
||||
|
||||
println!("Weight size: {} bytes", weights.len());
|
||||
|
||||
// Create artifact
|
||||
let artifact = ModelArtifact::new(manifest, weights, None, None, vec![]);
|
||||
println!("Artifact created, model ID: {}", artifact.model_id());
|
||||
|
||||
// Create backend and engine
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = Box::new(NativeSimBackend::new(gate.clone()));
|
||||
let mut engine = Engine::new(backend, gate);
|
||||
|
||||
// Load model
|
||||
let model_id = engine.load(&artifact)?;
|
||||
println!("Model loaded successfully\n");
|
||||
|
||||
// Prepare input
|
||||
let tokens: Vec<u16> = (0..shape.seq_len).collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
println!("Running inference...");
|
||||
println!(" Input tokens: {:?}...", &tokens[..4.min(tokens.len())]);
|
||||
|
||||
// Run inference with different coherence levels
|
||||
let coherence_levels = [
|
||||
(
|
||||
"High coherence",
|
||||
GateHint::new(
|
||||
500,
|
||||
false,
|
||||
ruvector_fpga_transformer::ComputeClass::Deliberative,
|
||||
),
|
||||
),
|
||||
(
|
||||
"Medium coherence",
|
||||
GateHint::new(
|
||||
100,
|
||||
false,
|
||||
ruvector_fpga_transformer::ComputeClass::Associative,
|
||||
),
|
||||
),
|
||||
(
|
||||
"Low coherence",
|
||||
GateHint::new(-100, true, ruvector_fpga_transformer::ComputeClass::Reflex),
|
||||
),
|
||||
];
|
||||
|
||||
for (name, hint) in coherence_levels {
|
||||
let req = InferenceRequest::new(model_id, shape, &tokens, &mask, hint);
|
||||
|
||||
match engine.infer(req) {
|
||||
Ok(result) => {
|
||||
println!("\n{}", name);
|
||||
println!(" Gate decision: {:?}", result.witness.gate_decision);
|
||||
println!(
|
||||
" Latency: {:.2}ms",
|
||||
result.witness.latency_ns as f64 / 1_000_000.0
|
||||
);
|
||||
|
||||
if let Some(topk) = &result.topk {
|
||||
println!(" Top-3 predictions:");
|
||||
for (i, (token, score)) in topk.iter().take(3).enumerate() {
|
||||
println!(" {}. Token {} (score: {})", i + 1, token, score);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("\n{}: Skipped - {:?}", name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Print statistics
|
||||
println!("\n==========================================");
|
||||
println!("Engine Statistics:");
|
||||
let stats = engine.stats();
|
||||
println!(" Total inferences: {}", stats.total_inferences);
|
||||
println!(" Successful: {}", stats.successful);
|
||||
println!(" Skipped: {}", stats.skipped);
|
||||
println!(" Early exits: {}", stats.early_exits);
|
||||
println!(" Success rate: {:.1}%", stats.success_rate() * 100.0);
|
||||
println!(" Avg latency: {:.2}ms", stats.avg_latency_ms());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
114
crates/ruvector-fpga-transformer/examples/daemon_client.rs
Normal file
114
crates/ruvector-fpga-transformer/examples/daemon_client.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
//! FPGA Daemon client example
|
||||
//!
|
||||
//! Demonstrates connecting to an FPGA daemon and running inference.
|
||||
//! This example requires the `daemon` feature and a running FPGA daemon.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruvector_fpga_transformer::{
|
||||
artifact::{Manifest, ModelArtifact},
|
||||
backend::fpga_daemon::{DaemonConfig, DaemonConnection, FpgaDaemonBackend},
|
||||
gating::DefaultCoherenceGate,
|
||||
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
|
||||
Engine,
|
||||
};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
println!("FPGA Transformer - Daemon Client Example");
|
||||
println!("=========================================\n");
|
||||
|
||||
// Configure daemon connection
|
||||
let socket_path = std::env::var("RUVECTOR_FPGA_SOCKET")
|
||||
.unwrap_or_else(|_| "/var/run/ruvector_fpga.sock".into());
|
||||
|
||||
println!("Connecting to daemon at: {}", socket_path);
|
||||
|
||||
let connection = DaemonConnection::unix(&socket_path);
|
||||
let config = DaemonConfig {
|
||||
connect_timeout_ms: 5000,
|
||||
read_timeout_ms: 10000,
|
||||
write_timeout_ms: 5000,
|
||||
retries: 3,
|
||||
backoff_multiplier: 2.0,
|
||||
topk_only: true,
|
||||
topk: 16,
|
||||
};
|
||||
|
||||
// Create backend
|
||||
let backend = Box::new(FpgaDaemonBackend::with_connection(connection, config));
|
||||
|
||||
// Create gate and engine
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let mut engine = Engine::new(backend, gate);
|
||||
|
||||
// Create a test model
|
||||
let shape = FixedShape::micro();
|
||||
let manifest = Manifest {
|
||||
name: "fpga_test_model".into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant: QuantSpec::int4_int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
let embedding_size = shape.vocab as usize * shape.d_model as usize / 2; // INT4 packed
|
||||
let weights: Vec<u8> = (0..embedding_size)
|
||||
.map(|i| ((i * 11 + 7) % 256) as u8)
|
||||
.collect();
|
||||
|
||||
let artifact = ModelArtifact::new(manifest, weights, None, None, vec![]);
|
||||
|
||||
// Try to load the model
|
||||
println!("Loading model...");
|
||||
match engine.load(&artifact) {
|
||||
Ok(model_id) => {
|
||||
println!("Model loaded: {}", model_id);
|
||||
|
||||
// Prepare input
|
||||
let tokens: Vec<u16> = (0..shape.seq_len).map(|i| i * 2).collect();
|
||||
let mask = vec![1u8; shape.seq_len as usize];
|
||||
|
||||
// Run inference
|
||||
println!("\nRunning FPGA inference...");
|
||||
let req = InferenceRequest::new(model_id, shape, &tokens, &mask, GateHint::allow_all());
|
||||
|
||||
match engine.infer(req) {
|
||||
Ok(result) => {
|
||||
println!("Inference successful!");
|
||||
println!(" Backend: {:?}", result.witness.backend);
|
||||
println!(" Cycles: {}", result.witness.cycles);
|
||||
println!(
|
||||
" Latency: {}ns ({:.3}ms)",
|
||||
result.witness.latency_ns,
|
||||
result.witness.latency_ns as f64 / 1_000_000.0
|
||||
);
|
||||
println!(" Gate decision: {:?}", result.witness.gate_decision);
|
||||
|
||||
if let Some(topk) = &result.topk {
|
||||
println!("\n Top-5 predictions:");
|
||||
for (i, (token, score)) in topk.iter().take(5).enumerate() {
|
||||
println!(" {}. Token {} (score: {})", i + 1, token, score);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Inference failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Unload model
|
||||
engine.unload(model_id)?;
|
||||
println!("\nModel unloaded");
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load model: {}", e);
|
||||
eprintln!("\nMake sure the FPGA daemon is running:");
|
||||
eprintln!(" ruvector-fpga-daemon --socket {}", socket_path);
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
266
crates/ruvector-fpga-transformer/src/artifact/manifest.rs
Normal file
266
crates/ruvector-fpga-transformer/src/artifact/manifest.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! Manifest schema for model artifacts
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{FixedShape, Layout, QuantSpec};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Model manifest containing all metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Manifest {
|
||||
/// Model name
|
||||
pub name: String,
|
||||
/// SHA-256 hash of model (hex string)
|
||||
pub model_hash: String,
|
||||
/// Fixed shape specification
|
||||
pub shape: FixedShape,
|
||||
/// Quantization specification
|
||||
pub quant: QuantSpec,
|
||||
/// I/O configuration
|
||||
pub io: IoSpec,
|
||||
/// Backend configuration
|
||||
pub backend: BackendSpec,
|
||||
/// Test vector specification
|
||||
pub tests: TestSpec,
|
||||
}
|
||||
|
||||
impl Manifest {
|
||||
/// Create a new manifest
|
||||
pub fn new(name: impl Into<String>, shape: FixedShape, quant: QuantSpec) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
model_hash: String::new(),
|
||||
shape,
|
||||
quant,
|
||||
io: IoSpec::default(),
|
||||
backend: BackendSpec::default(),
|
||||
tests: TestSpec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate manifest consistency
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.name.is_empty() {
|
||||
return Err(Error::InvalidArtifact("Model name is empty".into()));
|
||||
}
|
||||
|
||||
// Validate shape
|
||||
self.shape
|
||||
.validate()
|
||||
.map_err(|e| Error::InvalidArtifact(e))?;
|
||||
|
||||
// Validate quantization bits
|
||||
if !matches!(self.quant.w_bits, 1 | 2 | 4 | 8 | 16) {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Invalid weight bits: {}",
|
||||
self.quant.w_bits
|
||||
)));
|
||||
}
|
||||
if !matches!(self.quant.a_bits, 4 | 8 | 16 | 32) {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Invalid activation bits: {}",
|
||||
self.quant.a_bits
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert to JSON string
|
||||
pub fn to_json(&self) -> Result<String> {
|
||||
serde_json::to_string_pretty(self).map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Parse from JSON string
|
||||
pub fn from_json(json: &str) -> Result<Self> {
|
||||
serde_json::from_str(json).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// I/O type specification
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IoSpec {
|
||||
/// Input token type (typically "u16")
|
||||
pub tokens: String,
|
||||
/// Output logit type (typically "i16" or "i32")
|
||||
pub logits: String,
|
||||
/// Top-K count (0 for full logits)
|
||||
pub topk: u16,
|
||||
}
|
||||
|
||||
impl Default for IoSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tokens: "u16".into(),
|
||||
logits: "i16".into(),
|
||||
topk: 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend-specific configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BackendSpec {
|
||||
/// Backend kind ("fpga_pcie", "fpga_daemon", "native_sim", "wasm_sim")
|
||||
pub kind: String,
|
||||
/// Protocol version
|
||||
pub protocol: u16,
|
||||
/// Backend-specific options
|
||||
#[serde(default)]
|
||||
pub options: BackendOptions,
|
||||
}
|
||||
|
||||
impl Default for BackendSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
kind: "native_sim".into(),
|
||||
protocol: 1,
|
||||
options: BackendOptions::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend-specific options
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct BackendOptions {
|
||||
/// Enable batch processing
|
||||
#[serde(default)]
|
||||
pub batch_enabled: bool,
|
||||
/// Maximum batch size
|
||||
#[serde(default)]
|
||||
pub max_batch: u16,
|
||||
/// Enable early exit
|
||||
#[serde(default)]
|
||||
pub early_exit: bool,
|
||||
/// Minimum coherence threshold for early exit
|
||||
#[serde(default)]
|
||||
pub early_exit_threshold: i16,
|
||||
/// FPGA clock frequency in MHz (for cycle estimation)
|
||||
#[serde(default)]
|
||||
pub clock_mhz: u16,
|
||||
}
|
||||
|
||||
/// Test vector specification
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TestSpec {
|
||||
/// Number of test vectors
|
||||
pub vectors: u32,
|
||||
/// Maximum absolute error allowed
|
||||
pub max_abs_err: i32,
|
||||
/// Whether test vectors must pass before activation
|
||||
#[serde(default = "default_true")]
|
||||
pub require_pass: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Default for TestSpec {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vectors: 0,
|
||||
max_abs_err: 2,
|
||||
require_pass: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manifest builder for convenient construction
|
||||
pub struct ManifestBuilder {
|
||||
manifest: Manifest,
|
||||
}
|
||||
|
||||
impl ManifestBuilder {
|
||||
/// Create a new builder with name and shape
|
||||
pub fn new(name: impl Into<String>, shape: FixedShape) -> Self {
|
||||
Self {
|
||||
manifest: Manifest::new(name, shape, QuantSpec::int8()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set quantization spec
|
||||
pub fn quant(mut self, quant: QuantSpec) -> Self {
|
||||
self.manifest.quant = quant;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set model hash
|
||||
pub fn model_hash(mut self, hash: impl Into<String>) -> Self {
|
||||
self.manifest.model_hash = hash.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set I/O spec
|
||||
pub fn io(mut self, io: IoSpec) -> Self {
|
||||
self.manifest.io = io;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set backend spec
|
||||
pub fn backend(mut self, backend: BackendSpec) -> Self {
|
||||
self.manifest.backend = backend;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set test spec
|
||||
pub fn tests(mut self, tests: TestSpec) -> Self {
|
||||
self.manifest.tests = tests;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable top-K only output
|
||||
pub fn topk_only(mut self, k: u16) -> Self {
|
||||
self.manifest.io.topk = k;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable early exit
|
||||
pub fn early_exit(mut self, threshold: i16) -> Self {
|
||||
self.manifest.backend.options.early_exit = true;
|
||||
self.manifest.backend.options.early_exit_threshold = threshold;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the manifest
|
||||
pub fn build(self) -> Result<Manifest> {
|
||||
self.manifest.validate()?;
|
||||
Ok(self.manifest)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_manifest_builder() {
|
||||
let manifest = ManifestBuilder::new("test", FixedShape::micro())
|
||||
.quant(QuantSpec::int4_int8())
|
||||
.topk_only(16)
|
||||
.early_exit(100)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manifest.name, "test");
|
||||
assert_eq!(manifest.quant.w_bits, 4);
|
||||
assert_eq!(manifest.io.topk, 16);
|
||||
assert!(manifest.backend.options.early_exit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_manifest_json_roundtrip() {
|
||||
let manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
|
||||
let json = manifest.to_json().unwrap();
|
||||
let parsed = Manifest::from_json(&json).unwrap();
|
||||
assert_eq!(manifest.name, parsed.name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_manifest_validation() {
|
||||
let mut manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
|
||||
assert!(manifest.validate().is_ok());
|
||||
|
||||
manifest.name = String::new();
|
||||
assert!(manifest.validate().is_err());
|
||||
}
|
||||
}
|
||||
242
crates/ruvector-fpga-transformer/src/artifact/mod.rs
Normal file
242
crates/ruvector-fpga-transformer/src/artifact/mod.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
//! Model artifact format and handling
|
||||
//!
|
||||
//! Signed bundles with metadata, weights, and test vectors.
|
||||
|
||||
pub mod manifest;
|
||||
pub mod pack;
|
||||
pub mod verify;
|
||||
|
||||
pub use manifest::{BackendSpec, IoSpec, Manifest, TestSpec};
|
||||
pub use pack::{pack_artifact, unpack_artifact};
|
||||
pub use verify::{verify_artifact, verify_signature};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{FixedShape, ModelId, QuantSpec};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Complete model artifact
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelArtifact {
|
||||
/// Manifest with metadata
|
||||
pub manifest: Manifest,
|
||||
/// Quantized weights (binary blob)
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub weights: Vec<u8>,
|
||||
/// Optional FPGA bitstream
|
||||
#[serde(with = "serde_bytes_option")]
|
||||
pub bitstream: Option<Vec<u8>>,
|
||||
/// Optional calibration data
|
||||
#[serde(with = "serde_bytes_option")]
|
||||
pub calibration: Option<Vec<u8>>,
|
||||
/// Test vectors for validation
|
||||
pub test_vectors: Vec<TestVector>,
|
||||
/// Ed25519 signature over manifest + file hashes
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub signature: [u8; 64],
|
||||
/// Ed25519 public key
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub pubkey: [u8; 32],
|
||||
}
|
||||
|
||||
/// Serde helper for Option<Vec<u8>>
|
||||
mod serde_bytes_option {
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
pub fn serialize<S: Serializer>(data: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
|
||||
match data {
|
||||
Some(bytes) => s.serialize_some(&serde_bytes::Bytes::new(bytes)),
|
||||
None => s.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Vec<u8>>, D::Error> {
|
||||
let opt: Option<serde_bytes::ByteBuf> = Option::deserialize(d)?;
|
||||
Ok(opt.map(|b| b.into_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector for model validation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TestVector {
|
||||
/// Input tokens
|
||||
pub tokens: Vec<u16>,
|
||||
/// Expected output logits (top-K or full)
|
||||
pub expected: Vec<i16>,
|
||||
/// Maximum absolute error allowed
|
||||
pub max_abs_err: i32,
|
||||
}
|
||||
|
||||
impl ModelArtifact {
|
||||
/// Create a new artifact (for building/packing)
|
||||
pub fn new(
|
||||
manifest: Manifest,
|
||||
weights: Vec<u8>,
|
||||
bitstream: Option<Vec<u8>>,
|
||||
calibration: Option<Vec<u8>>,
|
||||
test_vectors: Vec<TestVector>,
|
||||
) -> Self {
|
||||
Self {
|
||||
manifest,
|
||||
weights,
|
||||
bitstream,
|
||||
calibration,
|
||||
test_vectors,
|
||||
signature: [0u8; 64],
|
||||
pubkey: [0u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute model ID (SHA-256 of manifest + weights hash)
|
||||
pub fn model_id(&self) -> ModelId {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(self.manifest.name.as_bytes());
|
||||
hasher.update(&self.model_hash());
|
||||
hasher.update(&self.quant_hash());
|
||||
ModelId::new(hasher.finalize().into())
|
||||
}
|
||||
|
||||
/// Compute hash of model weights
|
||||
pub fn model_hash(&self) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&self.weights);
|
||||
if let Some(ref bitstream) = self.bitstream {
|
||||
hasher.update(bitstream);
|
||||
}
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Compute hash of quantization parameters
|
||||
pub fn quant_hash(&self) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
let quant_json = serde_json::to_string(&self.manifest.quant).unwrap_or_default();
|
||||
hasher.update(quant_json.as_bytes());
|
||||
if let Some(ref calib) = self.calibration {
|
||||
hasher.update(calib);
|
||||
}
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Validate artifact integrity
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Validate manifest
|
||||
self.manifest.validate()?;
|
||||
|
||||
// Validate shape
|
||||
self.manifest
|
||||
.shape
|
||||
.validate()
|
||||
.map_err(|e| Error::InvalidArtifact(e))?;
|
||||
|
||||
// Check weights size is reasonable
|
||||
let min_weight_size =
|
||||
self.manifest.shape.embedding_params() / self.manifest.quant.weights_per_byte();
|
||||
if self.weights.len() < min_weight_size {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Weights too small: {} bytes, expected at least {} for embeddings",
|
||||
self.weights.len(),
|
||||
min_weight_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate test vectors if strict mode
|
||||
#[cfg(feature = "strict_verify")]
|
||||
self.run_test_vectors()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run test vectors for validation
|
||||
#[cfg(feature = "strict_verify")]
|
||||
pub fn run_test_vectors(&self) -> Result<()> {
|
||||
// This would require running inference, which creates a circular dependency
|
||||
// In practice, this is done by the backend after loading
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the fixed shape
|
||||
pub fn shape(&self) -> &FixedShape {
|
||||
&self.manifest.shape
|
||||
}
|
||||
|
||||
/// Get quantization spec
|
||||
pub fn quant(&self) -> &QuantSpec {
|
||||
&self.manifest.quant
|
||||
}
|
||||
|
||||
/// Check if artifact has FPGA bitstream
|
||||
pub fn has_bitstream(&self) -> bool {
|
||||
self.bitstream.is_some()
|
||||
}
|
||||
|
||||
/// Estimated memory footprint in bytes
|
||||
pub fn memory_footprint(&self) -> usize {
|
||||
self.weights.len()
|
||||
+ self.bitstream.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||
+ self.calibration.as_ref().map(|c| c.len()).unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_manifest() -> Manifest {
|
||||
Manifest {
|
||||
name: "test_model".into(),
|
||||
model_hash: "0".repeat(64),
|
||||
shape: FixedShape::micro(),
|
||||
quant: QuantSpec::int8(),
|
||||
io: IoSpec::default(),
|
||||
backend: BackendSpec::default(),
|
||||
tests: TestSpec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_id_computation() {
|
||||
let manifest = create_test_manifest();
|
||||
let artifact = ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![]);
|
||||
|
||||
let id1 = artifact.model_id();
|
||||
let id2 = artifact.model_id();
|
||||
assert_eq!(id1, id2); // Deterministic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_hash() {
|
||||
let manifest = create_test_manifest();
|
||||
let artifact = ModelArtifact::new(manifest, vec![42u8; 4096 * 64], None, None, vec![]);
|
||||
|
||||
let hash = artifact.model_hash();
|
||||
assert_ne!(hash, [0u8; 32]); // Non-zero hash
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_artifact_validation() {
|
||||
let manifest = create_test_manifest();
|
||||
let artifact = ModelArtifact::new(
|
||||
manifest,
|
||||
vec![0u8; 4096 * 64], // Enough for micro embeddings
|
||||
None,
|
||||
None,
|
||||
vec![],
|
||||
);
|
||||
|
||||
assert!(artifact.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_artifact_too_small_weights() {
|
||||
let manifest = create_test_manifest();
|
||||
let artifact = ModelArtifact::new(
|
||||
manifest,
|
||||
vec![0u8; 100], // Too small
|
||||
None,
|
||||
None,
|
||||
vec![],
|
||||
);
|
||||
|
||||
assert!(artifact.validate().is_err());
|
||||
}
|
||||
}
|
||||
304
crates/ruvector-fpga-transformer/src/artifact/pack.rs
Normal file
304
crates/ruvector-fpga-transformer/src/artifact/pack.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
//! Artifact packing and unpacking
|
||||
|
||||
use std::io::{Read, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use crate::artifact::{ModelArtifact, TestVector};
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Magic bytes for artifact file format
|
||||
const ARTIFACT_MAGIC: &[u8; 4] = b"RVAT"; // RuVector ArTifact
|
||||
const ARTIFACT_VERSION: u16 = 1;
|
||||
|
||||
// Security: Maximum size limits to prevent DoS via unbounded allocations
|
||||
/// Maximum manifest size (1 MB)
|
||||
const MAX_MANIFEST_SIZE: usize = 1024 * 1024;
|
||||
/// Maximum weights size (1 GB)
|
||||
const MAX_WEIGHTS_SIZE: usize = 1024 * 1024 * 1024;
|
||||
/// Maximum bitstream/calibration size (100 MB)
|
||||
const MAX_BLOB_SIZE: usize = 100 * 1024 * 1024;
|
||||
/// Maximum number of test vectors
|
||||
const MAX_TEST_VECTORS: usize = 10_000;
|
||||
/// Maximum tokens per test vector
|
||||
const MAX_TOKENS_PER_VECTOR: usize = 65_536;
|
||||
/// Maximum expected values per test vector
|
||||
const MAX_EXPECTED_PER_VECTOR: usize = 1_000_000;
|
||||
|
||||
/// Pack an artifact to bytes
|
||||
pub fn pack_artifact(artifact: &ModelArtifact) -> Result<Vec<u8>> {
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Write magic and version
|
||||
buffer.extend_from_slice(ARTIFACT_MAGIC);
|
||||
buffer.extend_from_slice(&ARTIFACT_VERSION.to_le_bytes());
|
||||
|
||||
// Write manifest as JSON with length prefix
|
||||
let manifest_json = serde_json::to_string(&artifact.manifest)?;
|
||||
let manifest_bytes = manifest_json.as_bytes();
|
||||
buffer.extend_from_slice(&(manifest_bytes.len() as u32).to_le_bytes());
|
||||
buffer.extend_from_slice(manifest_bytes);
|
||||
|
||||
// Write weights with length prefix
|
||||
buffer.extend_from_slice(&(artifact.weights.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(&artifact.weights);
|
||||
|
||||
// Write optional bitstream
|
||||
if let Some(ref bitstream) = artifact.bitstream {
|
||||
buffer.push(1); // Present flag
|
||||
buffer.extend_from_slice(&(bitstream.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(bitstream);
|
||||
} else {
|
||||
buffer.push(0); // Not present
|
||||
}
|
||||
|
||||
// Write optional calibration
|
||||
if let Some(ref calibration) = artifact.calibration {
|
||||
buffer.push(1);
|
||||
buffer.extend_from_slice(&(calibration.len() as u64).to_le_bytes());
|
||||
buffer.extend_from_slice(calibration);
|
||||
} else {
|
||||
buffer.push(0);
|
||||
}
|
||||
|
||||
// Write test vectors
|
||||
buffer.extend_from_slice(&(artifact.test_vectors.len() as u32).to_le_bytes());
|
||||
for vector in &artifact.test_vectors {
|
||||
// Write tokens
|
||||
buffer.extend_from_slice(&(vector.tokens.len() as u16).to_le_bytes());
|
||||
for &token in &vector.tokens {
|
||||
buffer.extend_from_slice(&token.to_le_bytes());
|
||||
}
|
||||
// Write expected
|
||||
buffer.extend_from_slice(&(vector.expected.len() as u32).to_le_bytes());
|
||||
for &exp in &vector.expected {
|
||||
buffer.extend_from_slice(&exp.to_le_bytes());
|
||||
}
|
||||
// Write max_abs_err
|
||||
buffer.extend_from_slice(&vector.max_abs_err.to_le_bytes());
|
||||
}
|
||||
|
||||
// Write signature and pubkey
|
||||
buffer.extend_from_slice(&artifact.signature);
|
||||
buffer.extend_from_slice(&artifact.pubkey);
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Unpack an artifact from bytes
|
||||
pub fn unpack_artifact(data: &[u8]) -> Result<ModelArtifact> {
|
||||
let mut cursor = std::io::Cursor::new(data);
|
||||
let mut read_buf = [0u8; 8];
|
||||
|
||||
// Read and verify magic
|
||||
cursor.read_exact(&mut read_buf[..4])?;
|
||||
if &read_buf[..4] != ARTIFACT_MAGIC {
|
||||
return Err(Error::InvalidArtifact("Invalid magic bytes".into()));
|
||||
}
|
||||
|
||||
// Read version
|
||||
cursor.read_exact(&mut read_buf[..2])?;
|
||||
let version = u16::from_le_bytes([read_buf[0], read_buf[1]]);
|
||||
if version != ARTIFACT_VERSION {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Unsupported version: {}",
|
||||
version
|
||||
)));
|
||||
}
|
||||
|
||||
// Read manifest
|
||||
cursor.read_exact(&mut read_buf[..4])?;
|
||||
let manifest_len = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
|
||||
if manifest_len > MAX_MANIFEST_SIZE {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Manifest size {} exceeds maximum {}",
|
||||
manifest_len, MAX_MANIFEST_SIZE
|
||||
)));
|
||||
}
|
||||
let mut manifest_bytes = vec![0u8; manifest_len];
|
||||
cursor.read_exact(&mut manifest_bytes)?;
|
||||
let manifest = serde_json::from_slice(&manifest_bytes)?;
|
||||
|
||||
// Read weights
|
||||
cursor.read_exact(&mut read_buf)?;
|
||||
let weights_len = u64::from_le_bytes(read_buf) as usize;
|
||||
if weights_len > MAX_WEIGHTS_SIZE {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Weights size {} exceeds maximum {}",
|
||||
weights_len, MAX_WEIGHTS_SIZE
|
||||
)));
|
||||
}
|
||||
let mut weights = vec![0u8; weights_len];
|
||||
cursor.read_exact(&mut weights)?;
|
||||
|
||||
// Read optional bitstream
|
||||
cursor.read_exact(&mut read_buf[..1])?;
|
||||
let bitstream = if read_buf[0] == 1 {
|
||||
cursor.read_exact(&mut read_buf)?;
|
||||
let len = u64::from_le_bytes(read_buf) as usize;
|
||||
if len > MAX_BLOB_SIZE {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Bitstream size {} exceeds maximum {}",
|
||||
len, MAX_BLOB_SIZE
|
||||
)));
|
||||
}
|
||||
let mut data = vec![0u8; len];
|
||||
cursor.read_exact(&mut data)?;
|
||||
Some(data)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Read optional calibration
|
||||
cursor.read_exact(&mut read_buf[..1])?;
|
||||
let calibration = if read_buf[0] == 1 {
|
||||
cursor.read_exact(&mut read_buf)?;
|
||||
let len = u64::from_le_bytes(read_buf) as usize;
|
||||
if len > MAX_BLOB_SIZE {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Calibration size {} exceeds maximum {}",
|
||||
len, MAX_BLOB_SIZE
|
||||
)));
|
||||
}
|
||||
let mut data = vec![0u8; len];
|
||||
cursor.read_exact(&mut data)?;
|
||||
Some(data)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Read test vectors
|
||||
cursor.read_exact(&mut read_buf[..4])?;
|
||||
let num_vectors = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
|
||||
if num_vectors > MAX_TEST_VECTORS {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Test vector count {} exceeds maximum {}",
|
||||
num_vectors, MAX_TEST_VECTORS
|
||||
)));
|
||||
}
|
||||
let mut test_vectors = Vec::with_capacity(num_vectors);
|
||||
|
||||
for _ in 0..num_vectors {
|
||||
// Read tokens
|
||||
cursor.read_exact(&mut read_buf[..2])?;
|
||||
let num_tokens = u16::from_le_bytes([read_buf[0], read_buf[1]]) as usize;
|
||||
if num_tokens > MAX_TOKENS_PER_VECTOR {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Token count {} exceeds maximum {}",
|
||||
num_tokens, MAX_TOKENS_PER_VECTOR
|
||||
)));
|
||||
}
|
||||
let mut tokens = Vec::with_capacity(num_tokens);
|
||||
for _ in 0..num_tokens {
|
||||
cursor.read_exact(&mut read_buf[..2])?;
|
||||
tokens.push(u16::from_le_bytes([read_buf[0], read_buf[1]]));
|
||||
}
|
||||
|
||||
// Read expected
|
||||
cursor.read_exact(&mut read_buf[..4])?;
|
||||
let num_expected = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
|
||||
if num_expected > MAX_EXPECTED_PER_VECTOR {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Expected values count {} exceeds maximum {}",
|
||||
num_expected, MAX_EXPECTED_PER_VECTOR
|
||||
)));
|
||||
}
|
||||
let mut expected = Vec::with_capacity(num_expected);
|
||||
for _ in 0..num_expected {
|
||||
cursor.read_exact(&mut read_buf[..2])?;
|
||||
expected.push(i16::from_le_bytes([read_buf[0], read_buf[1]]));
|
||||
}
|
||||
|
||||
// Read max_abs_err
|
||||
cursor.read_exact(&mut read_buf[..4])?;
|
||||
let max_abs_err = i32::from_le_bytes(read_buf[..4].try_into().unwrap());
|
||||
|
||||
test_vectors.push(TestVector {
|
||||
tokens,
|
||||
expected,
|
||||
max_abs_err,
|
||||
});
|
||||
}
|
||||
|
||||
// Read signature and pubkey
|
||||
let mut signature = [0u8; 64];
|
||||
cursor.read_exact(&mut signature)?;
|
||||
let mut pubkey = [0u8; 32];
|
||||
cursor.read_exact(&mut pubkey)?;
|
||||
|
||||
Ok(ModelArtifact {
|
||||
manifest,
|
||||
weights,
|
||||
bitstream,
|
||||
calibration,
|
||||
test_vectors,
|
||||
signature,
|
||||
pubkey,
|
||||
})
|
||||
}
|
||||
|
||||
/// Save artifact to file
|
||||
pub fn save_artifact(artifact: &ModelArtifact, path: impl AsRef<Path>) -> Result<()> {
|
||||
let data = pack_artifact(artifact)?;
|
||||
std::fs::write(path, data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load artifact from file
|
||||
pub fn load_artifact(path: impl AsRef<Path>) -> Result<ModelArtifact> {
|
||||
let data = std::fs::read(path)?;
|
||||
unpack_artifact(&data)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::artifact::Manifest;
|
||||
use crate::types::{FixedShape, QuantSpec};
|
||||
|
||||
fn create_test_artifact() -> ModelArtifact {
|
||||
let manifest = Manifest {
|
||||
name: "test_pack".into(),
|
||||
model_hash: "abc123".into(),
|
||||
shape: FixedShape::micro(),
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
ModelArtifact {
|
||||
manifest,
|
||||
weights: (0..5000).map(|i| (i % 256) as u8).collect(),
|
||||
bitstream: Some(vec![0xFF; 100]),
|
||||
calibration: None,
|
||||
test_vectors: vec![TestVector {
|
||||
tokens: vec![1, 2, 3],
|
||||
expected: vec![100, 200, 300],
|
||||
max_abs_err: 5,
|
||||
}],
|
||||
signature: [0x42u8; 64],
|
||||
pubkey: [0x24u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_unpack_roundtrip() {
|
||||
let original = create_test_artifact();
|
||||
let packed = pack_artifact(&original).unwrap();
|
||||
let unpacked = unpack_artifact(&packed).unwrap();
|
||||
|
||||
assert_eq!(original.manifest.name, unpacked.manifest.name);
|
||||
assert_eq!(original.weights, unpacked.weights);
|
||||
assert_eq!(original.bitstream, unpacked.bitstream);
|
||||
assert_eq!(original.calibration, unpacked.calibration);
|
||||
assert_eq!(original.test_vectors.len(), unpacked.test_vectors.len());
|
||||
assert_eq!(original.signature, unpacked.signature);
|
||||
assert_eq!(original.pubkey, unpacked.pubkey);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_magic() {
|
||||
let data = b"XXXX0000";
|
||||
assert!(unpack_artifact(data).is_err());
|
||||
}
|
||||
}
|
||||
203
crates/ruvector-fpga-transformer/src/artifact/verify.rs
Normal file
203
crates/ruvector-fpga-transformer/src/artifact/verify.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
//! Artifact verification and signature validation
|
||||
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Verify artifact signature
|
||||
pub fn verify_signature(artifact: &ModelArtifact) -> Result<bool> {
|
||||
// Compute the message to verify (manifest hash + file hashes)
|
||||
let message = compute_signing_message(artifact);
|
||||
|
||||
// Load public key
|
||||
let pubkey = VerifyingKey::from_bytes(&artifact.pubkey)
|
||||
.map_err(|e| Error::SignatureError(format!("Invalid public key: {}", e)))?;
|
||||
|
||||
// Load signature
|
||||
let signature = Signature::from_bytes(&artifact.signature);
|
||||
|
||||
// Verify
|
||||
pubkey
|
||||
.verify(&message, &signature)
|
||||
.map(|_| true)
|
||||
.map_err(|e| Error::SignatureError(format!("Verification failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Verify complete artifact integrity
|
||||
pub fn verify_artifact(artifact: &ModelArtifact) -> Result<()> {
|
||||
// 1. Validate manifest
|
||||
artifact.manifest.validate()?;
|
||||
|
||||
// 2. Verify model hash matches manifest
|
||||
let computed_hash = hex::encode(artifact.model_hash());
|
||||
if !artifact.manifest.model_hash.is_empty() && computed_hash != artifact.manifest.model_hash {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Model hash mismatch: expected {}, got {}",
|
||||
artifact.manifest.model_hash, computed_hash
|
||||
)));
|
||||
}
|
||||
|
||||
// 3. Verify signature (if present)
|
||||
if artifact.pubkey != [0u8; 32] {
|
||||
verify_signature(artifact)?;
|
||||
}
|
||||
|
||||
// 4. Verify weights size
|
||||
let expected_min =
|
||||
artifact.manifest.shape.embedding_params() / artifact.manifest.quant.weights_per_byte();
|
||||
if artifact.weights.len() < expected_min {
|
||||
return Err(Error::InvalidArtifact(format!(
|
||||
"Weights too small: {} < {}",
|
||||
artifact.weights.len(),
|
||||
expected_min
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute the message that was signed
|
||||
fn compute_signing_message(artifact: &ModelArtifact) -> Vec<u8> {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
// Hash manifest
|
||||
let manifest_json = serde_json::to_string(&artifact.manifest).unwrap_or_default();
|
||||
hasher.update(manifest_json.as_bytes());
|
||||
|
||||
// Hash weights
|
||||
let weights_hash = artifact.model_hash();
|
||||
hasher.update(&weights_hash);
|
||||
|
||||
// Hash quant params
|
||||
let quant_hash = artifact.quant_hash();
|
||||
hasher.update(&quant_hash);
|
||||
|
||||
// Hash bitstream if present
|
||||
if let Some(ref bitstream) = artifact.bitstream {
|
||||
let mut h = Sha256::new();
|
||||
h.update(bitstream);
|
||||
hasher.update(&h.finalize());
|
||||
}
|
||||
|
||||
// Hash calibration if present
|
||||
if let Some(ref calib) = artifact.calibration {
|
||||
let mut h = Sha256::new();
|
||||
h.update(calib);
|
||||
hasher.update(&h.finalize());
|
||||
}
|
||||
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
/// Sign an artifact with Ed25519 private key
|
||||
#[cfg(feature = "sign")]
|
||||
pub fn sign_artifact(artifact: &mut ModelArtifact, secret_key: &[u8; 32]) -> Result<()> {
|
||||
use ed25519_dalek::{Signer, SigningKey};
|
||||
|
||||
let signing_key = SigningKey::from_bytes(secret_key);
|
||||
let message = compute_signing_message(artifact);
|
||||
|
||||
let signature = signing_key.sign(&message);
|
||||
|
||||
artifact.signature = signature.to_bytes();
|
||||
artifact.pubkey = signing_key.verifying_key().to_bytes();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify test vectors against model output
|
||||
pub fn verify_test_vectors(
|
||||
artifact: &ModelArtifact,
|
||||
infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
|
||||
) -> Result<()> {
|
||||
let max_err = artifact.manifest.tests.max_abs_err;
|
||||
|
||||
for (i, vector) in artifact.test_vectors.iter().enumerate() {
|
||||
let output = infer_fn(&vector.tokens)?;
|
||||
|
||||
// Compare outputs
|
||||
let actual_max_err = output
|
||||
.iter()
|
||||
.zip(&vector.expected)
|
||||
.map(|(&a, &b)| (a as i32 - b as i32).abs())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
if actual_max_err > max_err {
|
||||
return Err(Error::TestVectorError {
|
||||
expected: max_err,
|
||||
actual: actual_max_err,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate test vectors for an artifact
|
||||
pub fn generate_test_vectors(
|
||||
artifact: &mut ModelArtifact,
|
||||
infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
|
||||
count: usize,
|
||||
) -> Result<()> {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let seq_len = artifact.manifest.shape.seq_len as usize;
|
||||
let vocab = artifact.manifest.shape.vocab as u16;
|
||||
|
||||
artifact.test_vectors.clear();
|
||||
|
||||
for _ in 0..count {
|
||||
// Generate random input
|
||||
let tokens: Vec<u16> = (0..seq_len).map(|_| rng.gen_range(0..vocab)).collect();
|
||||
|
||||
// Run inference
|
||||
let expected = infer_fn(&tokens)?;
|
||||
|
||||
artifact.test_vectors.push(crate::artifact::TestVector {
|
||||
tokens,
|
||||
expected,
|
||||
max_abs_err: artifact.manifest.tests.max_abs_err,
|
||||
});
|
||||
}
|
||||
|
||||
artifact.manifest.tests.vectors = count as u32;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::artifact::Manifest;
|
||||
use crate::types::{FixedShape, QuantSpec};
|
||||
|
||||
fn create_test_artifact() -> ModelArtifact {
|
||||
let manifest = Manifest {
|
||||
name: "test".into(),
|
||||
model_hash: String::new(),
|
||||
shape: FixedShape::micro(),
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_artifact() {
|
||||
let artifact = create_test_artifact();
|
||||
assert!(verify_artifact(&artifact).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_signing_message() {
|
||||
let artifact = create_test_artifact();
|
||||
let msg = compute_signing_message(&artifact);
|
||||
assert_eq!(msg.len(), 32); // SHA-256 output
|
||||
}
|
||||
}
|
||||
566
crates/ruvector-fpga-transformer/src/backend/fpga_daemon.rs
Normal file
566
crates/ruvector-fpga-transformer/src/backend/fpga_daemon.rs
Normal file
@@ -0,0 +1,566 @@
|
||||
//! FPGA Daemon backend
|
||||
//!
|
||||
//! Communicates with a local daemon over Unix socket or TCP
|
||||
//! to send inference requests to an FPGA accelerator.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{Read, Write};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
commands, compute_topk, crc32, protocol, read_lock, validate_tokens, write_lock, BackendStats,
|
||||
RequestFrame, ResponseFrame, TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{
|
||||
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
|
||||
};
|
||||
|
||||
/// Connection type for daemon communication
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DaemonConnection {
|
||||
/// Unix domain socket path
|
||||
Unix(String),
|
||||
/// TCP address (host:port)
|
||||
Tcp(String),
|
||||
}
|
||||
|
||||
impl DaemonConnection {
|
||||
/// Create a Unix socket connection
|
||||
pub fn unix(path: impl Into<String>) -> Self {
|
||||
Self::Unix(path.into())
|
||||
}
|
||||
|
||||
/// Create a TCP connection
|
||||
pub fn tcp(addr: impl Into<String>) -> Self {
|
||||
Self::Tcp(addr.into())
|
||||
}
|
||||
|
||||
/// Default socket path
|
||||
pub fn default_socket() -> Self {
|
||||
Self::Unix("/var/run/ruvector_fpga.sock".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// FPGA Daemon backend
|
||||
pub struct FpgaDaemonBackend {
|
||||
/// Connection configuration
|
||||
connection: DaemonConnection,
|
||||
/// Loaded models (cached metadata)
|
||||
models: RwLock<HashMap<ModelId, ModelMetadata>>,
|
||||
/// Statistics
|
||||
stats: RwLock<BackendStats>,
|
||||
/// Configuration
|
||||
config: DaemonConfig,
|
||||
}
|
||||
|
||||
/// Cached model metadata
|
||||
struct ModelMetadata {
|
||||
artifact: ModelArtifact,
|
||||
loaded_at: Instant,
|
||||
}
|
||||
|
||||
/// Configuration for daemon backend
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DaemonConfig {
|
||||
/// Connection timeout in milliseconds
|
||||
pub connect_timeout_ms: u64,
|
||||
/// Read timeout in milliseconds
|
||||
pub read_timeout_ms: u64,
|
||||
/// Write timeout in milliseconds
|
||||
pub write_timeout_ms: u64,
|
||||
/// Number of retry attempts
|
||||
pub retries: usize,
|
||||
/// Retry backoff multiplier
|
||||
pub backoff_multiplier: f64,
|
||||
/// Return only top-K results
|
||||
pub topk_only: bool,
|
||||
/// Top-K count
|
||||
pub topk: u16,
|
||||
}
|
||||
|
||||
impl Default for DaemonConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
connect_timeout_ms: 5000,
|
||||
read_timeout_ms: 10000,
|
||||
write_timeout_ms: 5000,
|
||||
retries: 3,
|
||||
backoff_multiplier: 2.0,
|
||||
topk_only: true,
|
||||
topk: 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FpgaDaemonBackend {
|
||||
/// Create a new daemon backend with Unix socket
|
||||
pub fn new(socket_path: impl AsRef<Path>) -> Self {
|
||||
Self::with_connection(
|
||||
DaemonConnection::unix(socket_path.as_ref().to_string_lossy()),
|
||||
DaemonConfig::default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create with custom connection and config
|
||||
pub fn with_connection(connection: DaemonConnection, config: DaemonConfig) -> Self {
|
||||
Self {
|
||||
connection,
|
||||
models: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(BackendStats::default()),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to the daemon
|
||||
fn connect(&self) -> Result<Box<dyn ReadWrite>> {
|
||||
let timeout = Duration::from_millis(self.config.connect_timeout_ms);
|
||||
|
||||
match &self.connection {
|
||||
DaemonConnection::Unix(path) => {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::net::UnixStream;
|
||||
let stream = UnixStream::connect(path)
|
||||
.map_err(|e| Error::daemon_connection(format!("Unix socket: {}", e)))?;
|
||||
stream
|
||||
.set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
|
||||
.ok();
|
||||
stream
|
||||
.set_write_timeout(Some(Duration::from_millis(
|
||||
self.config.write_timeout_ms,
|
||||
)))
|
||||
.ok();
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = (path, timeout);
|
||||
Err(Error::FeatureNotAvailable(
|
||||
"Unix sockets not available on this platform".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
DaemonConnection::Tcp(addr) => {
|
||||
use std::net::TcpStream;
|
||||
let stream = TcpStream::connect_timeout(
|
||||
&addr
|
||||
.parse()
|
||||
.map_err(|e| Error::daemon_connection(format!("Invalid address: {}", e)))?,
|
||||
timeout,
|
||||
)
|
||||
.map_err(|e| Error::daemon_connection(format!("TCP: {}", e)))?;
|
||||
stream
|
||||
.set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
|
||||
.ok();
|
||||
stream
|
||||
.set_write_timeout(Some(Duration::from_millis(self.config.write_timeout_ms)))
|
||||
.ok();
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send inference request to daemon
|
||||
fn send_request(
|
||||
&self,
|
||||
stream: &mut dyn ReadWrite,
|
||||
req: &InferenceRequest,
|
||||
) -> Result<(Vec<i16>, ResponseFrame)> {
|
||||
let shape = &req.shape;
|
||||
|
||||
// Build request flags
|
||||
let mut flags = 0u16;
|
||||
if self.config.topk_only {
|
||||
flags |= protocol::flags::TOPK_ONLY;
|
||||
}
|
||||
|
||||
// Create request frame
|
||||
let frame = RequestFrame::new(
|
||||
shape.seq_len,
|
||||
shape.d_model,
|
||||
shape.vocab,
|
||||
&req.model,
|
||||
flags,
|
||||
self.config.topk,
|
||||
);
|
||||
|
||||
// Build payload
|
||||
let mut payload = Vec::with_capacity(
|
||||
protocol::HEADER_SIZE + req.tokens.len() * 2 + req.attn_mask.len() + 8,
|
||||
);
|
||||
|
||||
// Write header
|
||||
payload.extend_from_slice(&frame.to_bytes());
|
||||
|
||||
// Write tokens (u16 little-endian)
|
||||
for &token in req.tokens {
|
||||
payload.extend_from_slice(&token.to_le_bytes());
|
||||
}
|
||||
|
||||
// Write mask
|
||||
payload.extend_from_slice(req.attn_mask);
|
||||
|
||||
// Write gate hint (packed)
|
||||
payload.extend_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
|
||||
payload.push(req.gate_hint.boundary_crossed as u8);
|
||||
payload.push(req.gate_hint.max_compute_class as u8);
|
||||
|
||||
// Calculate and append checksum
|
||||
let checksum = crc32(&payload);
|
||||
payload.extend_from_slice(&checksum.to_le_bytes());
|
||||
|
||||
// Send payload
|
||||
stream
|
||||
.write_all(&payload)
|
||||
.map_err(|e| Error::backend(format!("Write failed: {}", e)))?;
|
||||
stream
|
||||
.flush()
|
||||
.map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
|
||||
|
||||
// Read response header
|
||||
let mut response_header = [0u8; 14];
|
||||
stream
|
||||
.read_exact(&mut response_header)
|
||||
.map_err(|e| Error::backend(format!("Read header failed: {}", e)))?;
|
||||
|
||||
let response = ResponseFrame::from_bytes(&response_header);
|
||||
|
||||
// Copy packed fields to avoid alignment issues
|
||||
let status = { response.status };
|
||||
|
||||
// Check status
|
||||
match status {
|
||||
protocol::status::OK => {}
|
||||
protocol::status::MODEL_NOT_FOUND => {
|
||||
return Err(Error::ModelNotFound(req.model));
|
||||
}
|
||||
protocol::status::SHAPE_MISMATCH => {
|
||||
return Err(Error::ShapeMismatch {
|
||||
expected: req.shape,
|
||||
actual: req.shape, // Daemon should provide actual shape
|
||||
});
|
||||
}
|
||||
protocol::status::GATE_BLOCKED => {
|
||||
return Err(Error::GateBlocked {
|
||||
reason: crate::types::SkipReason::PolicyDenied,
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::backend(format!("Daemon error: status {}", status)));
|
||||
}
|
||||
}
|
||||
|
||||
// Read logits
|
||||
let logits_count = if self.config.topk_only {
|
||||
self.config.topk as usize * 2 // (token_id, logit) pairs
|
||||
} else {
|
||||
shape.vocab as usize
|
||||
};
|
||||
|
||||
let mut logits_bytes = vec![0u8; logits_count * 2];
|
||||
stream
|
||||
.read_exact(&mut logits_bytes)
|
||||
.map_err(|e| Error::backend(format!("Read logits failed: {}", e)))?;
|
||||
|
||||
// Parse logits
|
||||
let logits: Vec<i16> = logits_bytes
|
||||
.chunks(2)
|
||||
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
|
||||
.collect();
|
||||
|
||||
// Read and verify checksum
|
||||
let mut checksum_bytes = [0u8; 4];
|
||||
stream.read_exact(&mut checksum_bytes).ok(); // Checksum is optional
|
||||
|
||||
Ok((logits, response))
|
||||
}
|
||||
|
||||
/// Send load model command to daemon
|
||||
fn send_load_command(
|
||||
&self,
|
||||
stream: &mut dyn ReadWrite,
|
||||
artifact: &ModelArtifact,
|
||||
) -> Result<()> {
|
||||
// Pack artifact
|
||||
let artifact_bytes = crate::artifact::pack::pack_artifact(artifact)?;
|
||||
|
||||
// Build command packet:
|
||||
// [command: 1] [model_id: 32] [artifact_len: 4] [artifact_data: N] [checksum: 4]
|
||||
let mut payload = Vec::with_capacity(1 + 32 + 4 + artifact_bytes.len() + 4);
|
||||
|
||||
// Command byte
|
||||
payload.push(commands::LOAD_MODEL);
|
||||
|
||||
// Model ID (32 bytes)
|
||||
payload.extend_from_slice(artifact.model_id().as_bytes());
|
||||
|
||||
// Artifact length (u32 LE)
|
||||
payload.extend_from_slice(&(artifact_bytes.len() as u32).to_le_bytes());
|
||||
|
||||
// Artifact data
|
||||
payload.extend_from_slice(&artifact_bytes);
|
||||
|
||||
// Checksum
|
||||
let checksum = crc32(&payload);
|
||||
payload.extend_from_slice(&checksum.to_le_bytes());
|
||||
|
||||
// Send
|
||||
stream
|
||||
.write_all(&payload)
|
||||
.map_err(|e| Error::backend(format!("Write load command failed: {}", e)))?;
|
||||
stream
|
||||
.flush()
|
||||
.map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
|
||||
|
||||
// Read response: [status: 1] [message_len: 2] [message: N]
|
||||
let mut status = [0u8; 1];
|
||||
stream
|
||||
.read_exact(&mut status)
|
||||
.map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
|
||||
|
||||
if status[0] != 0 {
|
||||
// Read error message
|
||||
let mut msg_len = [0u8; 2];
|
||||
stream.read_exact(&mut msg_len).ok();
|
||||
let len = u16::from_le_bytes(msg_len) as usize;
|
||||
let mut msg = vec![0u8; len.min(256)];
|
||||
stream.read_exact(&mut msg).ok();
|
||||
let error_msg = String::from_utf8_lossy(&msg);
|
||||
return Err(Error::backend(format!(
|
||||
"Daemon rejected load: {}",
|
||||
error_msg
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send unload model command to daemon
|
||||
fn send_unload_command(&self, stream: &mut dyn ReadWrite, model_id: ModelId) -> Result<()> {
|
||||
// Build command packet: [command: 1] [model_id: 32] [checksum: 4]
|
||||
let mut payload = Vec::with_capacity(1 + 32 + 4);
|
||||
payload.push(commands::UNLOAD_MODEL);
|
||||
payload.extend_from_slice(model_id.as_bytes());
|
||||
let checksum = crc32(&payload);
|
||||
payload.extend_from_slice(&checksum.to_le_bytes());
|
||||
|
||||
// Send
|
||||
stream
|
||||
.write_all(&payload)
|
||||
.map_err(|e| Error::backend(format!("Write unload command failed: {}", e)))?;
|
||||
stream
|
||||
.flush()
|
||||
.map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
|
||||
|
||||
// Read response status
|
||||
let mut status = [0u8; 1];
|
||||
stream
|
||||
.read_exact(&mut status)
|
||||
.map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
|
||||
|
||||
if status[0] != 0 {
|
||||
return Err(Error::backend("Daemon rejected unload"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute with retries
|
||||
fn with_retries<T, F>(&self, mut f: F) -> Result<T>
|
||||
where
|
||||
F: FnMut() -> Result<T>,
|
||||
{
|
||||
let mut last_error = None;
|
||||
let mut delay = Duration::from_millis(100);
|
||||
|
||||
for attempt in 0..=self.config.retries {
|
||||
match f() {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) if e.is_recoverable() => {
|
||||
last_error = Some(e);
|
||||
if attempt < self.config.retries {
|
||||
std::thread::sleep(delay);
|
||||
delay = Duration::from_secs_f64(
|
||||
delay.as_secs_f64() * self.config.backoff_multiplier,
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| Error::backend("Unknown error")))
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBackend for FpgaDaemonBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
let model_id = artifact.model_id();
|
||||
|
||||
// Send load command to daemon to preload the model
|
||||
self.with_retries(|| {
|
||||
let mut stream = self.connect()?;
|
||||
self.send_load_command(stream.as_mut(), artifact)
|
||||
})?;
|
||||
|
||||
// Cache metadata locally
|
||||
{
|
||||
let mut models = write_lock(&self.models, |m| {
|
||||
m.insert(
|
||||
model_id,
|
||||
ModelMetadata {
|
||||
artifact: artifact.clone(),
|
||||
loaded_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
write_lock(&self.stats, |s| {
|
||||
s.models_loaded += 1;
|
||||
})?;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate request
|
||||
req.validate()?;
|
||||
|
||||
// Check model is loaded locally and validate tokens
|
||||
let model_metadata = read_lock(&self.models, |models| {
|
||||
models.get(&req.model).map(|m| m.artifact.clone())
|
||||
})?
|
||||
.ok_or_else(|| Error::ModelNotFound(req.model))?;
|
||||
|
||||
// Validate tokens against vocabulary
|
||||
validate_tokens(req.tokens, model_metadata.manifest.shape.vocab)?;
|
||||
|
||||
// Execute with retries
|
||||
let (logits, response) = self.with_retries(|| {
|
||||
let mut stream = self.connect()?;
|
||||
self.send_request(stream.as_mut(), &req)
|
||||
})?;
|
||||
|
||||
let latency_ns = start.elapsed().as_nanos() as u32;
|
||||
|
||||
// Parse response
|
||||
let gate_decision = response.to_gate_decision();
|
||||
|
||||
// Build top-K if we got pairs
|
||||
let (logits_q, topk) = if self.config.topk_only {
|
||||
// logits contains (token_id, logit) pairs
|
||||
let pairs: Vec<(u16, i16)> = logits
|
||||
.chunks(2)
|
||||
.filter_map(|chunk| {
|
||||
if chunk.len() == 2 {
|
||||
Some((chunk[0] as u16, chunk[1]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
(vec![], Some(pairs))
|
||||
} else {
|
||||
// Full logits - use common compute_topk
|
||||
let topk = compute_topk(&logits, 16);
|
||||
(logits, Some(topk))
|
||||
};
|
||||
|
||||
// Copy packed fields to avoid alignment issues
|
||||
let resp_cycles = { response.cycles };
|
||||
let resp_latency_ns = { response.latency_ns };
|
||||
|
||||
// Create witness
|
||||
let witness = WitnessLog::new(
|
||||
model_metadata.model_hash(),
|
||||
model_metadata.quant_hash(),
|
||||
BackendKind::FpgaDaemon,
|
||||
resp_cycles,
|
||||
latency_ns.min(resp_latency_ns.max(latency_ns)),
|
||||
gate_decision,
|
||||
);
|
||||
|
||||
// Update stats (with poison handling)
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.total_inferences += 1;
|
||||
stats.total_cycles += resp_cycles as u64;
|
||||
let n = stats.total_inferences;
|
||||
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
|
||||
match gate_decision {
|
||||
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
|
||||
GateDecision::Skipped { .. } => stats.skipped += 1,
|
||||
_ => {}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(InferenceResult::new(logits_q, topk, witness))
|
||||
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
// Send unload command to daemon
|
||||
self.with_retries(|| {
|
||||
let mut stream = self.connect()?;
|
||||
self.send_unload_command(stream.as_mut(), model)
|
||||
})?;
|
||||
|
||||
// Remove from local cache
|
||||
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
|
||||
|
||||
if removed {
|
||||
write_lock(&self.stats, |s| {
|
||||
s.models_loaded = s.models_loaded.saturating_sub(1);
|
||||
})?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn kind(&self) -> BackendKind {
|
||||
BackendKind::FpgaDaemon
|
||||
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
self.stats.read().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait combining Read and Write for stream abstraction
|
||||
trait ReadWrite: Read + Write + Send {}
|
||||
impl<T: Read + Write + Send> ReadWrite for T {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_daemon_connection_types() {
|
||||
let unix = DaemonConnection::unix("/tmp/test.sock");
|
||||
assert!(matches!(unix, DaemonConnection::Unix(_)));
|
||||
|
||||
let tcp = DaemonConnection::tcp("127.0.0.1:8080");
|
||||
assert!(matches!(tcp, DaemonConnection::Tcp(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_defaults() {
|
||||
let config = DaemonConfig::default();
|
||||
assert_eq!(config.connect_timeout_ms, 5000);
|
||||
assert_eq!(config.retries, 3);
|
||||
assert!(config.topk_only);
|
||||
}
|
||||
}
|
||||
645
crates/ruvector-fpga-transformer/src/backend/fpga_pcie.rs
Normal file
645
crates/ruvector-fpga-transformer/src/backend/fpga_pcie.rs
Normal file
@@ -0,0 +1,645 @@
|
||||
//! FPGA PCIe backend
|
||||
//!
|
||||
//! Direct memory-mapped access to FPGA accelerator via PCIe.
|
||||
//! Uses DMA ring buffers for zero-copy, lock-free operation.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
use memmap2::{MmapMut, MmapOptions};
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
compute_topk, protocol, read_lock, validate_tokens, write_lock, BackendStats,
|
||||
TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{
|
||||
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
|
||||
};
|
||||
|
||||
/// PCIe device configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PcieConfig {
|
||||
/// Device path (e.g., /dev/ruvector0)
|
||||
pub device_path: String,
|
||||
/// BAR0 offset for control registers
|
||||
pub bar0_offset: usize,
|
||||
/// BAR1 offset for DMA buffers
|
||||
pub bar1_offset: usize,
|
||||
/// Number of request slots in ring buffer
|
||||
pub ring_slots: usize,
|
||||
/// Size of each request slot in bytes
|
||||
pub slot_size: usize,
|
||||
/// DMA timeout in milliseconds
|
||||
pub dma_timeout_ms: u64,
|
||||
/// Enable batch mode (multiple requests per DMA burst)
|
||||
pub batch_mode: bool,
|
||||
/// Maximum requests per batch
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
impl Default for PcieConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_path: "/dev/ruvector0".into(),
|
||||
bar0_offset: 0,
|
||||
bar1_offset: 0x10000,
|
||||
ring_slots: 16,
|
||||
slot_size: 64 * 1024, // 64KB per slot
|
||||
dma_timeout_ms: 100,
|
||||
batch_mode: false,
|
||||
batch_size: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ring buffer slot state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
enum SlotState {
|
||||
Free = 0,
|
||||
Pending = 1,
|
||||
Complete = 2,
|
||||
Error = 3,
|
||||
}
|
||||
|
||||
/// DMA ring buffer for lock-free request/response handling
|
||||
struct DmaRingBuffer {
|
||||
/// Memory-mapped request buffer
|
||||
#[cfg(feature = "pcie")]
|
||||
request_mmap: MmapMut,
|
||||
/// Memory-mapped response buffer
|
||||
#[cfg(feature = "pcie")]
|
||||
response_mmap: MmapMut,
|
||||
/// Slot states
|
||||
slot_states: Vec<AtomicU32>,
|
||||
/// Producer index (next slot to write)
|
||||
producer_idx: AtomicU32,
|
||||
/// Consumer index (next slot to read)
|
||||
consumer_idx: AtomicU32,
|
||||
/// Number of slots
|
||||
num_slots: usize,
|
||||
/// Size per slot
|
||||
slot_size: usize,
|
||||
}
|
||||
|
||||
impl DmaRingBuffer {
|
||||
/// Create a new DMA ring buffer (mock for non-PCIe builds)
|
||||
#[cfg(not(feature = "pcie"))]
|
||||
fn new(_config: &PcieConfig) -> Result<Self> {
|
||||
Err(Error::FeatureNotAvailable(
|
||||
"PCIe support not compiled".into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Create a new DMA ring buffer
|
||||
#[cfg(feature = "pcie")]
|
||||
fn new(config: &PcieConfig) -> Result<Self> {
|
||||
use std::fs::OpenOptions;
|
||||
|
||||
// Open device
|
||||
let file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(&config.device_path)
|
||||
.map_err(|e| Error::PcieError(format!("Failed to open device: {}", e)))?;
|
||||
|
||||
let total_size = config.ring_slots * config.slot_size;
|
||||
|
||||
// Map request buffer (BAR1)
|
||||
let request_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.offset(config.bar1_offset as u64)
|
||||
.len(total_size)
|
||||
.map_mut(&file)
|
||||
.map_err(|e| Error::PcieError(format!("Failed to map request buffer: {}", e)))?
|
||||
};
|
||||
|
||||
// Map response buffer (BAR1 + offset)
|
||||
let response_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.offset((config.bar1_offset + total_size) as u64)
|
||||
.len(total_size)
|
||||
.map_mut(&file)
|
||||
.map_err(|e| Error::PcieError(format!("Failed to map response buffer: {}", e)))?
|
||||
};
|
||||
|
||||
// Initialize slot states
|
||||
let slot_states: Vec<AtomicU32> = (0..config.ring_slots)
|
||||
.map(|_| AtomicU32::new(SlotState::Free as u32))
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
request_mmap,
|
||||
response_mmap,
|
||||
slot_states,
|
||||
producer_idx: AtomicU32::new(0),
|
||||
consumer_idx: AtomicU32::new(0),
|
||||
num_slots: config.ring_slots,
|
||||
slot_size: config.slot_size,
|
||||
})
|
||||
}
|
||||
|
||||
/// Acquire a slot for writing
|
||||
fn acquire_slot(&self) -> Option<usize> {
|
||||
let producer = self.producer_idx.load(Ordering::Acquire);
|
||||
let slot = producer as usize % self.num_slots;
|
||||
|
||||
// Check if slot is free
|
||||
if self.slot_states[slot].load(Ordering::Acquire) == SlotState::Free as u32 {
|
||||
// Try to claim it
|
||||
if self.slot_states[slot]
|
||||
.compare_exchange(
|
||||
SlotState::Free as u32,
|
||||
SlotState::Pending as u32,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Relaxed,
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
self.producer_idx
|
||||
.store(producer.wrapping_add(1), Ordering::Release);
|
||||
return Some(slot);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Release a slot after reading response
|
||||
fn release_slot(&self, slot: usize) {
|
||||
self.slot_states[slot].store(SlotState::Free as u32, Ordering::Release);
|
||||
self.consumer_idx.fetch_add(1, Ordering::AcqRel);
|
||||
}
|
||||
|
||||
/// Check if a slot is complete
|
||||
fn is_complete(&self, slot: usize) -> bool {
|
||||
self.slot_states[slot].load(Ordering::Acquire) == SlotState::Complete as u32
|
||||
}
|
||||
|
||||
/// Mark a slot as complete (called by FPGA via doorbell/interrupt)
|
||||
fn mark_complete(&self, slot: usize) {
|
||||
self.slot_states[slot].store(SlotState::Complete as u32, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Get request buffer for a slot
|
||||
#[cfg(feature = "pcie")]
|
||||
fn request_buffer(&mut self, slot: usize) -> &mut [u8] {
|
||||
let start = slot * self.slot_size;
|
||||
let end = start + self.slot_size;
|
||||
&mut self.request_mmap[start..end]
|
||||
}
|
||||
|
||||
/// Get response buffer for a slot
|
||||
#[cfg(feature = "pcie")]
|
||||
fn response_buffer(&self, slot: usize) -> &[u8] {
|
||||
let start = slot * self.slot_size;
|
||||
let end = start + self.slot_size;
|
||||
&self.response_mmap[start..end]
|
||||
}
|
||||
}
|
||||
|
||||
/// FPGA PCIe backend
|
||||
pub struct FpgaPcieBackend {
|
||||
/// Configuration
|
||||
config: PcieConfig,
|
||||
/// DMA ring buffer
|
||||
ring: Option<DmaRingBuffer>,
|
||||
/// Loaded models
|
||||
models: RwLock<HashMap<ModelId, ModelMetadata>>,
|
||||
/// Statistics
|
||||
stats: RwLock<BackendStats>,
|
||||
/// Total cycles counter
|
||||
total_cycles: AtomicU64,
|
||||
/// FPGA memory allocator state (next free offset)
|
||||
fpga_mem_offset: AtomicU64,
|
||||
/// FPGA memory total size (2GB default)
|
||||
fpga_mem_size: u64,
|
||||
}
|
||||
|
||||
/// Cached model metadata
|
||||
struct ModelMetadata {
|
||||
artifact: ModelArtifact,
|
||||
fpga_slot: u32, // Slot in FPGA memory where model is loaded
|
||||
weights_offset: u64, // Offset in FPGA DDR where weights are stored
|
||||
weights_size: usize, // Size of weights in bytes
|
||||
}
|
||||
|
||||
/// FPGA DDR base offset for model weights
|
||||
const FPGA_DDR_MODEL_BASE: u64 = 0x1000_0000; // 256MB offset
|
||||
|
||||
impl FpgaPcieBackend {
|
||||
/// Create a new PCIe backend
|
||||
pub fn new(config: PcieConfig) -> Result<Self> {
|
||||
#[cfg(feature = "pcie")]
|
||||
let ring = Some(DmaRingBuffer::new(&config)?);
|
||||
|
||||
#[cfg(not(feature = "pcie"))]
|
||||
let ring = None;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
ring,
|
||||
models: RwLock::new(HashMap::new()),
|
||||
stats: RwLock::new(BackendStats::default()),
|
||||
total_cycles: AtomicU64::new(0),
|
||||
fpga_mem_offset: AtomicU64::new(FPGA_DDR_MODEL_BASE),
|
||||
fpga_mem_size: 2 * 1024 * 1024 * 1024, // 2GB
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_device() -> Result<Self> {
|
||||
Self::new(PcieConfig::default())
|
||||
}
|
||||
|
||||
/// Write inference request to DMA buffer
|
||||
#[cfg(feature = "pcie")]
|
||||
fn write_request(
|
||||
&self,
|
||||
ring: &mut DmaRingBuffer,
|
||||
slot: usize,
|
||||
req: &InferenceRequest,
|
||||
) -> Result<()> {
|
||||
use crate::backend::{protocol, RequestFrame};
|
||||
|
||||
let buffer = ring.request_buffer(slot);
|
||||
let shape = &req.shape;
|
||||
|
||||
// Write header
|
||||
let frame = RequestFrame::new(shape.seq_len, shape.d_model, shape.vocab, &req.model, 0, 16);
|
||||
let header = frame.to_bytes();
|
||||
buffer[..protocol::HEADER_SIZE].copy_from_slice(&header);
|
||||
|
||||
let mut offset = protocol::HEADER_SIZE;
|
||||
|
||||
// Write tokens
|
||||
for &token in req.tokens {
|
||||
buffer[offset..offset + 2].copy_from_slice(&token.to_le_bytes());
|
||||
offset += 2;
|
||||
}
|
||||
|
||||
// Write mask
|
||||
buffer[offset..offset + req.attn_mask.len()].copy_from_slice(req.attn_mask);
|
||||
offset += req.attn_mask.len();
|
||||
|
||||
// Write gate hint
|
||||
buffer[offset..offset + 2].copy_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
|
||||
offset += 2;
|
||||
buffer[offset] = req.gate_hint.boundary_crossed as u8;
|
||||
offset += 1;
|
||||
buffer[offset] = req.gate_hint.max_compute_class as u8;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read inference response from DMA buffer
|
||||
#[cfg(feature = "pcie")]
|
||||
fn read_response(
|
||||
&self,
|
||||
ring: &DmaRingBuffer,
|
||||
slot: usize,
|
||||
shape: &crate::types::FixedShape,
|
||||
) -> Result<(Vec<i16>, u32, u32, GateDecision)> {
|
||||
use crate::backend::ResponseFrame;
|
||||
|
||||
let buffer = ring.response_buffer(slot);
|
||||
|
||||
// Read response header
|
||||
let response = ResponseFrame::from_bytes(&buffer[..14].try_into().unwrap());
|
||||
|
||||
// Check status
|
||||
if response.status != 0 {
|
||||
return Err(Error::backend(format!(
|
||||
"FPGA error: status {}",
|
||||
response.status
|
||||
)));
|
||||
}
|
||||
|
||||
// Read logits
|
||||
let vocab = shape.vocab as usize;
|
||||
let mut logits = Vec::with_capacity(vocab);
|
||||
let mut offset = 14;
|
||||
|
||||
for _ in 0..vocab {
|
||||
let value = i16::from_le_bytes([buffer[offset], buffer[offset + 1]]);
|
||||
logits.push(value);
|
||||
offset += 2;
|
||||
}
|
||||
|
||||
Ok((
|
||||
logits,
|
||||
response.cycles,
|
||||
response.latency_ns,
|
||||
response.to_gate_decision(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Ring doorbell to notify FPGA of pending request
|
||||
#[cfg(feature = "pcie")]
|
||||
fn ring_doorbell(&self, _slot: usize) {
|
||||
// In a real implementation, this would write to a control register
|
||||
// to notify the FPGA that a new request is available
|
||||
}
|
||||
|
||||
/// Wait for response with polling
|
||||
fn wait_for_response(&self, ring: &DmaRingBuffer, slot: usize, timeout_ms: u64) -> Result<()> {
|
||||
let start = Instant::now();
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
|
||||
while !ring.is_complete(slot) {
|
||||
if start.elapsed() > timeout {
|
||||
return Err(Error::Timeout { ms: timeout_ms });
|
||||
}
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Allocate FPGA DDR memory for model weights
|
||||
fn allocate_fpga_memory(&self, size: usize) -> Result<u64> {
|
||||
// Align to 4KB boundary for DMA efficiency
|
||||
let aligned_size = (size + 0xFFF) & !0xFFF;
|
||||
|
||||
// Atomic allocation (simple bump allocator)
|
||||
let offset = self
|
||||
.fpga_mem_offset
|
||||
.fetch_add(aligned_size as u64, Ordering::SeqCst);
|
||||
|
||||
// Check for overflow
|
||||
if offset + aligned_size as u64 > self.fpga_mem_size {
|
||||
// Roll back allocation
|
||||
self.fpga_mem_offset
|
||||
.fetch_sub(aligned_size as u64, Ordering::SeqCst);
|
||||
return Err(Error::ResourceExhausted("FPGA DDR memory full".into()));
|
||||
}
|
||||
|
||||
Ok(offset)
|
||||
}
|
||||
|
||||
/// Upload model weights to FPGA DDR via DMA
|
||||
#[cfg(feature = "pcie")]
|
||||
fn upload_weights_dma(&self, weights: &[u8], fpga_offset: u64) -> Result<()> {
|
||||
// DMA transfer configuration
|
||||
const DMA_CHUNK_SIZE: usize = 64 * 1024; // 64KB per transfer
|
||||
|
||||
let ring = self
|
||||
.ring
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
|
||||
|
||||
// Transfer weights in chunks
|
||||
let mut transferred = 0usize;
|
||||
while transferred < weights.len() {
|
||||
let chunk_size = DMA_CHUNK_SIZE.min(weights.len() - transferred);
|
||||
|
||||
// Acquire a DMA slot
|
||||
let slot = loop {
|
||||
if let Some(s) = ring.acquire_slot() {
|
||||
break s;
|
||||
}
|
||||
std::hint::spin_loop();
|
||||
};
|
||||
|
||||
// Write DMA command to slot (simplified protocol)
|
||||
// In real hardware:
|
||||
// - Write target FPGA DDR address
|
||||
// - Write source offset in slot
|
||||
// - Write transfer length
|
||||
// - Ring doorbell
|
||||
|
||||
// For now, we simulate the DMA by marking complete
|
||||
ring.mark_complete(slot);
|
||||
|
||||
// Wait for completion
|
||||
self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;
|
||||
|
||||
// Release slot
|
||||
ring.release_slot(slot);
|
||||
|
||||
transferred += chunk_size;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Free FPGA DDR memory (simplified - real impl would use proper allocator)
|
||||
fn free_fpga_memory(&self, _offset: u64, _size: usize) {
|
||||
// In a production system, this would:
|
||||
// 1. Mark the memory region as free in an allocator
|
||||
// 2. Potentially compact memory if fragmentation is high
|
||||
// 3. Update hardware memory management unit
|
||||
//
|
||||
// For this implementation, we use a bump allocator without free.
|
||||
// Memory is reclaimed when all models are unloaded.
|
||||
}
|
||||
|
||||
/// Check if all models are unloaded and reset memory allocator
|
||||
fn maybe_reset_allocator(&self) {
|
||||
let models_empty = read_lock(&self.models, |m| m.is_empty()).unwrap_or(false);
|
||||
if models_empty {
|
||||
self.fpga_mem_offset
|
||||
.store(FPGA_DDR_MODEL_BASE, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBackend for FpgaPcieBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
#[cfg(not(feature = "pcie"))]
|
||||
{
|
||||
let _ = artifact;
|
||||
return Err(Error::FeatureNotAvailable(
|
||||
"PCIe support not compiled".into(),
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
{
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
let model_id = artifact.model_id();
|
||||
let weights_size = artifact.weights.len();
|
||||
|
||||
// Allocate FPGA DDR memory for weights
|
||||
let weights_offset = self.allocate_fpga_memory(weights_size)?;
|
||||
|
||||
// Upload model weights to FPGA DDR via DMA
|
||||
if let Err(e) = self.upload_weights_dma(&artifact.weights, weights_offset) {
|
||||
// Roll back allocation on failure
|
||||
self.free_fpga_memory(weights_offset, weights_size);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Get slot number for this model
|
||||
let fpga_slot = read_lock(&self.models, |m| m.len() as u32)?;
|
||||
|
||||
// Store metadata
|
||||
write_lock(&self.models, |models| {
|
||||
models.insert(
|
||||
model_id,
|
||||
ModelMetadata {
|
||||
artifact: artifact.clone(),
|
||||
fpga_slot,
|
||||
weights_offset,
|
||||
weights_size,
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded += 1;
|
||||
})?;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
}
|
||||
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
#[cfg(not(feature = "pcie"))]
|
||||
{
|
||||
let _ = req;
|
||||
return Err(Error::FeatureNotAvailable(
|
||||
"PCIe support not compiled".into(),
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
{
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate request
|
||||
req.validate()?;
|
||||
|
||||
// Get model metadata
|
||||
let model_artifact = read_lock(&self.models, |models| {
|
||||
models.get(&req.model).map(|m| m.artifact.clone())
|
||||
})?
|
||||
.ok_or_else(|| Error::ModelNotFound(req.model))?;
|
||||
|
||||
// Validate tokens against vocabulary
|
||||
validate_tokens(req.tokens, model_artifact.manifest.shape.vocab)?;
|
||||
|
||||
// Get ring buffer
|
||||
let ring = self
|
||||
.ring
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
|
||||
|
||||
// Acquire slot
|
||||
let slot = ring
|
||||
.acquire_slot()
|
||||
.ok_or_else(|| Error::ResourceExhausted("No DMA slots available".into()))?;
|
||||
|
||||
// Write request (need mutable access - simplified for now)
|
||||
// In production, this would use proper interior mutability
|
||||
// self.write_request(ring, slot, &req)?;
|
||||
|
||||
// Ring doorbell
|
||||
// self.ring_doorbell(slot);
|
||||
|
||||
// Wait for response
|
||||
self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;
|
||||
|
||||
// Read response
|
||||
let (logits, cycles, fpga_latency_ns, gate_decision) =
|
||||
self.read_response(ring, slot, &req.shape)?;
|
||||
|
||||
// Release slot
|
||||
ring.release_slot(slot);
|
||||
|
||||
let latency_ns = start.elapsed().as_nanos() as u32;
|
||||
|
||||
// Compute top-K using common utility
|
||||
let topk = compute_topk(&logits, 16);
|
||||
|
||||
// Create witness
|
||||
let witness = WitnessLog::new(
|
||||
model_artifact.model_hash(),
|
||||
model_artifact.quant_hash(),
|
||||
BackendKind::FpgaPcie,
|
||||
cycles,
|
||||
fpga_latency_ns.min(latency_ns),
|
||||
gate_decision,
|
||||
);
|
||||
|
||||
// Update stats
|
||||
self.total_cycles
|
||||
.fetch_add(cycles as u64, Ordering::Relaxed);
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.total_inferences += 1;
|
||||
stats.total_cycles = self.total_cycles.load(Ordering::Relaxed);
|
||||
let n = stats.total_inferences;
|
||||
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
|
||||
match gate_decision {
|
||||
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
|
||||
GateDecision::Skipped { .. } => stats.skipped += 1,
|
||||
_ => {}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(InferenceResult::new(logits, Some(topk), witness))
|
||||
}
|
||||
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
// Remove from cache and get memory info for deallocation
|
||||
let removed = write_lock(&self.models, |models| {
|
||||
models
|
||||
.remove(&model)
|
||||
.map(|m| (m.weights_offset, m.weights_size))
|
||||
})?;
|
||||
|
||||
if let Some((offset, size)) = removed {
|
||||
// Free FPGA DDR memory
|
||||
self.free_fpga_memory(offset, size);
|
||||
|
||||
// Check if we can reset the allocator
|
||||
self.maybe_reset_allocator();
|
||||
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded = stats.models_loaded.saturating_sub(1);
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn kind(&self) -> BackendKind {
|
||||
BackendKind::FpgaPcie
|
||||
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pcie_config_default() {
|
||||
let config = PcieConfig::default();
|
||||
assert_eq!(config.ring_slots, 16);
|
||||
assert_eq!(config.slot_size, 64 * 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slot_state_values() {
|
||||
assert_eq!(SlotState::Free as u8, 0);
|
||||
assert_eq!(SlotState::Pending as u8, 1);
|
||||
assert_eq!(SlotState::Complete as u8, 2);
|
||||
}
|
||||
}
|
||||
428
crates/ruvector-fpga-transformer/src/backend/mod.rs
Normal file
428
crates/ruvector-fpga-transformer/src/backend/mod.rs
Normal file
@@ -0,0 +1,428 @@
|
||||
//! Backend implementations for FPGA Transformer
|
||||
//!
|
||||
//! All backends implement the `TransformerBackend` trait for uniform API.
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::error::Result;
|
||||
use crate::types::{InferenceRequest, InferenceResult, ModelId};
|
||||
|
||||
#[cfg(feature = "native_sim")]
|
||||
pub mod native_sim;
|
||||
|
||||
#[cfg(feature = "daemon")]
|
||||
pub mod fpga_daemon;
|
||||
|
||||
#[cfg(feature = "pcie")]
|
||||
pub mod fpga_pcie;
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_sim;
|
||||
|
||||
/// Trait for transformer inference backends
|
||||
///
|
||||
/// All backends must be thread-safe and implement the same API.
|
||||
pub trait TransformerBackend: Send + Sync {
|
||||
/// Load a model artifact and return its ID
|
||||
///
|
||||
/// The artifact is validated, test vectors are run, and
|
||||
/// the model is prepared for inference.
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId>;
|
||||
|
||||
/// Run inference on the given request
|
||||
///
|
||||
/// The request must specify a model that has been loaded.
|
||||
/// Returns the inference result with witness log.
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult>;
|
||||
|
||||
/// Unload a model to free resources
|
||||
fn unload(&self, model: ModelId) -> Result<()>;
|
||||
|
||||
/// Check if a model is loaded
|
||||
fn is_loaded(&self, model: ModelId) -> bool;
|
||||
|
||||
/// Get the backend kind
|
||||
fn kind(&self) -> crate::types::BackendKind;
|
||||
|
||||
/// Get backend-specific statistics
|
||||
fn stats(&self) -> BackendStats {
|
||||
BackendStats::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct BackendStats {
|
||||
/// Number of models currently loaded
|
||||
pub models_loaded: usize,
|
||||
/// Total inferences performed
|
||||
pub total_inferences: u64,
|
||||
/// Total cycles consumed (FPGA only)
|
||||
pub total_cycles: u64,
|
||||
/// Average latency in nanoseconds
|
||||
pub avg_latency_ns: u64,
|
||||
/// P99 latency in nanoseconds
|
||||
pub p99_latency_ns: u64,
|
||||
/// Number of early exits
|
||||
pub early_exits: u64,
|
||||
/// Number of skipped inferences
|
||||
pub skipped: u64,
|
||||
}
|
||||
|
||||
/// Protocol constants for daemon/PCIe communication
|
||||
pub mod protocol {
|
||||
/// Magic number for frame validation
|
||||
pub const MAGIC: u32 = 0x5256_5846; // "RVXF" - RuVector FPGA
|
||||
|
||||
/// Current protocol version
|
||||
pub const VERSION: u16 = 1;
|
||||
|
||||
/// Frame header size in bytes
|
||||
pub const HEADER_SIZE: usize = 24;
|
||||
|
||||
/// Maximum payload size
|
||||
pub const MAX_PAYLOAD: usize = 1024 * 1024; // 1MB
|
||||
|
||||
/// Request flags
|
||||
pub mod flags {
|
||||
/// Return only top-K predictions
|
||||
pub const TOPK_ONLY: u16 = 0x0001;
|
||||
/// Use LUT-based softmax
|
||||
pub const LUT_SOFTMAX: u16 = 0x0002;
|
||||
/// Enable early exit
|
||||
pub const EARLY_EXIT: u16 = 0x0004;
|
||||
/// Return detailed witness
|
||||
pub const WITNESS_DETAIL: u16 = 0x0008;
|
||||
}
|
||||
|
||||
/// Response status codes
|
||||
pub mod status {
|
||||
/// Success
|
||||
pub const OK: u16 = 0;
|
||||
/// Model not found
|
||||
pub const MODEL_NOT_FOUND: u16 = 1;
|
||||
/// Shape mismatch
|
||||
pub const SHAPE_MISMATCH: u16 = 2;
|
||||
/// Gate blocked
|
||||
pub const GATE_BLOCKED: u16 = 3;
|
||||
/// Internal error
|
||||
pub const INTERNAL_ERROR: u16 = 0xFFFF;
|
||||
}
|
||||
}
|
||||
|
||||
/// Request frame for wire protocol
|
||||
#[repr(C, packed)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RequestFrame {
|
||||
/// Magic number (MAGIC)
|
||||
pub magic: u32,
|
||||
/// Protocol version
|
||||
pub protocol: u16,
|
||||
/// Sequence length
|
||||
pub seq_len: u16,
|
||||
/// Model dimension
|
||||
pub d_model: u16,
|
||||
/// Vocabulary size
|
||||
pub vocab: u16,
|
||||
/// Model ID (lower 32 bits)
|
||||
pub model_id_low: u32,
|
||||
/// Model ID (upper 32 bits)
|
||||
pub model_id_high: u32,
|
||||
/// Request flags
|
||||
pub flags: u16,
|
||||
/// Top-K count (if TOPK_ONLY flag set)
|
||||
pub topk: u16,
|
||||
}
|
||||
|
||||
impl RequestFrame {
|
||||
/// Create a new request frame
|
||||
pub fn new(
|
||||
seq_len: u16,
|
||||
d_model: u16,
|
||||
vocab: u32,
|
||||
model_id: &ModelId,
|
||||
flags: u16,
|
||||
topk: u16,
|
||||
) -> Self {
|
||||
let id_bytes = model_id.as_bytes();
|
||||
let model_id_low = u32::from_le_bytes([id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3]]);
|
||||
let model_id_high =
|
||||
u32::from_le_bytes([id_bytes[4], id_bytes[5], id_bytes[6], id_bytes[7]]);
|
||||
|
||||
Self {
|
||||
magic: protocol::MAGIC,
|
||||
protocol: protocol::VERSION,
|
||||
seq_len,
|
||||
d_model,
|
||||
vocab: (vocab & 0xFFFF) as u16,
|
||||
model_id_low,
|
||||
model_id_high,
|
||||
flags,
|
||||
topk,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to bytes
|
||||
pub fn to_bytes(&self) -> [u8; protocol::HEADER_SIZE] {
|
||||
let mut bytes = [0u8; protocol::HEADER_SIZE];
|
||||
bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
|
||||
bytes[4..6].copy_from_slice(&self.protocol.to_le_bytes());
|
||||
bytes[6..8].copy_from_slice(&self.seq_len.to_le_bytes());
|
||||
bytes[8..10].copy_from_slice(&self.d_model.to_le_bytes());
|
||||
bytes[10..12].copy_from_slice(&self.vocab.to_le_bytes());
|
||||
bytes[12..16].copy_from_slice(&self.model_id_low.to_le_bytes());
|
||||
bytes[16..20].copy_from_slice(&self.model_id_high.to_le_bytes());
|
||||
bytes[20..22].copy_from_slice(&self.flags.to_le_bytes());
|
||||
bytes[22..24].copy_from_slice(&self.topk.to_le_bytes());
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
/// Response frame from wire protocol
|
||||
#[repr(C, packed)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ResponseFrame {
|
||||
/// Status code
|
||||
pub status: u16,
|
||||
/// Latency in nanoseconds
|
||||
pub latency_ns: u32,
|
||||
/// Compute cycles
|
||||
pub cycles: u32,
|
||||
/// Gate decision (packed)
|
||||
pub gate_decision: u8,
|
||||
/// Exit layer (if early exit)
|
||||
pub exit_layer: u8,
|
||||
/// Skip reason (if skipped)
|
||||
pub skip_reason: u8,
|
||||
/// Reserved
|
||||
pub reserved: u8,
|
||||
}
|
||||
|
||||
impl ResponseFrame {
|
||||
/// Parse from bytes
|
||||
pub fn from_bytes(bytes: &[u8; 14]) -> Self {
|
||||
Self {
|
||||
status: u16::from_le_bytes([bytes[0], bytes[1]]),
|
||||
latency_ns: u32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
|
||||
cycles: u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
|
||||
gate_decision: bytes[10],
|
||||
exit_layer: bytes[11],
|
||||
skip_reason: bytes[12],
|
||||
reserved: bytes[13],
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert gate decision to enum
|
||||
pub fn to_gate_decision(&self) -> crate::types::GateDecision {
|
||||
match self.gate_decision {
|
||||
0 => crate::types::GateDecision::RanFull,
|
||||
1 => crate::types::GateDecision::EarlyExit {
|
||||
layer: self.exit_layer,
|
||||
},
|
||||
2 => crate::types::GateDecision::Skipped {
|
||||
reason: match self.skip_reason {
|
||||
0 => crate::types::SkipReason::LowCoherence,
|
||||
1 => crate::types::SkipReason::PolicyDenied,
|
||||
_ => crate::types::SkipReason::BudgetExceeded,
|
||||
},
|
||||
},
|
||||
_ => crate::types::GateDecision::RanFull,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate CRC32 checksum for frame validation
|
||||
pub fn crc32(data: &[u8]) -> u32 {
|
||||
// Simple CRC32 implementation (could use crc32fast crate in production)
|
||||
let mut crc: u32 = 0xFFFFFFFF;
|
||||
for &byte in data {
|
||||
crc ^= byte as u32;
|
||||
for _ in 0..8 {
|
||||
crc = if crc & 1 != 0 {
|
||||
(crc >> 1) ^ 0xEDB88320
|
||||
} else {
|
||||
crc >> 1
|
||||
};
|
||||
}
|
||||
}
|
||||
!crc
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Common utilities shared across backends
|
||||
// ============================================================================
|
||||
|
||||
/// Compute top-K predictions from logits
|
||||
/// Returns sorted (token_id, logit) pairs, descending by logit value
|
||||
#[inline]
|
||||
pub fn compute_topk(logits: &[i16], k: usize) -> Vec<(u16, i16)> {
|
||||
if logits.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// For small K, partial sort is faster
|
||||
if k <= 32 && logits.len() > 100 {
|
||||
// Use partial heap-based selection
|
||||
let mut heap: Vec<(i16, u16)> = Vec::with_capacity(k + 1);
|
||||
for (i, &v) in logits.iter().enumerate() {
|
||||
if heap.len() < k {
|
||||
heap.push((v, i as u16));
|
||||
if heap.len() == k {
|
||||
// Heapify
|
||||
heap.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
}
|
||||
} else if v > heap[0].0 {
|
||||
heap[0] = (v, i as u16);
|
||||
// Maintain min-heap property
|
||||
let mut idx = 0;
|
||||
while idx * 2 + 1 < heap.len() {
|
||||
let left = idx * 2 + 1;
|
||||
let right = idx * 2 + 2;
|
||||
let mut smallest = idx;
|
||||
if heap[left].0 < heap[smallest].0 {
|
||||
smallest = left;
|
||||
}
|
||||
if right < heap.len() && heap[right].0 < heap[smallest].0 {
|
||||
smallest = right;
|
||||
}
|
||||
if smallest == idx {
|
||||
break;
|
||||
}
|
||||
heap.swap(idx, smallest);
|
||||
idx = smallest;
|
||||
}
|
||||
}
|
||||
}
|
||||
heap.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
heap.into_iter().map(|(v, i)| (i, v)).collect()
|
||||
} else {
|
||||
// Full sort for small arrays
|
||||
let mut indexed: Vec<(usize, i16)> = logits.iter().cloned().enumerate().collect();
|
||||
indexed.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
indexed
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(i, v)| (i as u16, v))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to safely read from RwLock, returning error on poison
|
||||
pub fn read_lock<T, R>(
|
||||
lock: &std::sync::RwLock<T>,
|
||||
f: impl FnOnce(&T) -> R,
|
||||
) -> crate::error::Result<R> {
|
||||
lock.read()
|
||||
.map(|guard| f(&*guard))
|
||||
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (read)".into()))
|
||||
}
|
||||
|
||||
/// Helper to safely write to RwLock, returning error on poison
|
||||
pub fn write_lock<T, R>(
|
||||
lock: &std::sync::RwLock<T>,
|
||||
f: impl FnOnce(&mut T) -> R,
|
||||
) -> crate::error::Result<R> {
|
||||
lock.write()
|
||||
.map(|mut guard| f(&mut *guard))
|
||||
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (write)".into()))
|
||||
}
|
||||
|
||||
/// Validate token indices against vocabulary size
|
||||
#[inline]
|
||||
pub fn validate_tokens(tokens: &[u16], vocab_size: u32) -> crate::error::Result<()> {
|
||||
for (i, &token) in tokens.iter().enumerate() {
|
||||
if token as u32 >= vocab_size {
|
||||
return Err(crate::error::Error::InvalidConfig(format!(
|
||||
"Token {} at index {} exceeds vocabulary size {}",
|
||||
token, i, vocab_size
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build witness log from inference metadata
|
||||
pub fn build_witness(
|
||||
model_hash: [u8; 32],
|
||||
quant_hash: [u8; 32],
|
||||
backend: crate::types::BackendKind,
|
||||
cycles: u32,
|
||||
latency_ns: u32,
|
||||
gate_decision: crate::types::GateDecision,
|
||||
) -> crate::types::WitnessLog {
|
||||
crate::types::WitnessLog::new(
|
||||
model_hash,
|
||||
quant_hash,
|
||||
backend,
|
||||
cycles,
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
)
|
||||
}
|
||||
|
||||
/// Command types for daemon protocol
|
||||
pub mod commands {
|
||||
/// Load model command
|
||||
pub const LOAD_MODEL: u8 = 0x01;
|
||||
/// Unload model command
|
||||
pub const UNLOAD_MODEL: u8 = 0x02;
|
||||
/// Inference request command
|
||||
pub const INFER: u8 = 0x03;
|
||||
/// Ping/health check command
|
||||
pub const PING: u8 = 0x04;
|
||||
/// Get status command
|
||||
pub const STATUS: u8 = 0x05;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_request_frame_roundtrip() {
|
||||
let model_id = ModelId::new([0x42u8; 32]);
|
||||
let frame = RequestFrame::new(64, 256, 32000, &model_id, 0, 16);
|
||||
let bytes = frame.to_bytes();
|
||||
|
||||
assert_eq!(bytes.len(), protocol::HEADER_SIZE);
|
||||
assert_eq!(&bytes[0..4], &protocol::MAGIC.to_le_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crc32() {
|
||||
let data = b"test data";
|
||||
let crc = crc32(data);
|
||||
// CRC should be consistent
|
||||
assert_eq!(crc, crc32(data));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_topk() {
|
||||
let logits: Vec<i16> = vec![100, 50, 300, 200, 150];
|
||||
let topk = compute_topk(&logits, 3);
|
||||
|
||||
assert_eq!(topk.len(), 3);
|
||||
assert_eq!(topk[0], (2, 300)); // Index 2, value 300
|
||||
assert_eq!(topk[1], (3, 200)); // Index 3, value 200
|
||||
assert_eq!(topk[2], (4, 150)); // Index 4, value 150
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_topk_large() {
|
||||
let logits: Vec<i16> = (0..1000).map(|i| (i * 7 % 500) as i16).collect();
|
||||
let topk = compute_topk(&logits, 10);
|
||||
|
||||
assert_eq!(topk.len(), 10);
|
||||
// Should be sorted descending
|
||||
for i in 1..topk.len() {
|
||||
assert!(topk[i - 1].1 >= topk[i].1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tokens() {
|
||||
assert!(validate_tokens(&[0, 1, 2], 100).is_ok());
|
||||
assert!(validate_tokens(&[99], 100).is_ok());
|
||||
assert!(validate_tokens(&[100], 100).is_err());
|
||||
assert!(validate_tokens(&[0, 50, 101], 100).is_err());
|
||||
}
|
||||
}
|
||||
544
crates/ruvector-fpga-transformer/src/backend/native_sim.rs
Normal file
544
crates/ruvector-fpga-transformer/src/backend/native_sim.rs
Normal file
@@ -0,0 +1,544 @@
|
||||
//! Native Rust simulator backend
|
||||
//!
|
||||
//! Provides a pure-Rust implementation of the transformer inference
|
||||
//! for testing, development, and fallback when no FPGA is available.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{
|
||||
compute_topk, read_lock, validate_tokens, write_lock, BackendStats, TransformerBackend,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::gating::CoherenceGate;
|
||||
use crate::quant::{dequantize_i8, quantize_i16, softmax_lut};
|
||||
use crate::types::{
|
||||
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
|
||||
QuantSpec, SkipReason, WitnessLog,
|
||||
};
|
||||
|
||||
/// Loaded model data for native simulation
|
||||
struct LoadedModel {
|
||||
/// Model artifact (contains weights and config)
|
||||
artifact: ModelArtifact,
|
||||
/// Precomputed embedding matrix (dequantized for sim)
|
||||
embeddings: Vec<f32>,
|
||||
/// Layer weights (simplified for simulation)
|
||||
layers: Vec<LayerWeights>,
|
||||
/// Output projection
|
||||
output_proj: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Simplified layer weights for simulation
|
||||
struct LayerWeights {
|
||||
/// Attention Q projection
|
||||
wq: Vec<f32>,
|
||||
/// Attention K projection
|
||||
wk: Vec<f32>,
|
||||
/// Attention V projection
|
||||
wv: Vec<f32>,
|
||||
/// Attention output projection
|
||||
wo: Vec<f32>,
|
||||
/// FFN up projection
|
||||
w1: Vec<f32>,
|
||||
/// FFN down projection
|
||||
w2: Vec<f32>,
|
||||
/// Layer norm weights
|
||||
ln1_weight: Vec<f32>,
|
||||
ln2_weight: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Native simulator backend
|
||||
pub struct NativeSimBackend {
|
||||
/// Loaded models
|
||||
models: RwLock<HashMap<ModelId, Arc<LoadedModel>>>,
|
||||
/// Coherence gate
|
||||
gate: Arc<dyn CoherenceGate>,
|
||||
/// Statistics
|
||||
stats: RwLock<BackendStats>,
|
||||
/// Configuration
|
||||
config: NativeSimConfig,
|
||||
}
|
||||
|
||||
/// Configuration for native simulator
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NativeSimConfig {
|
||||
/// Maximum models to keep loaded
|
||||
pub max_models: usize,
|
||||
/// Enable detailed tracing
|
||||
pub trace: bool,
|
||||
/// Use LUT-based softmax
|
||||
pub lut_softmax: bool,
|
||||
/// Number of layers to simulate (0 = all)
|
||||
pub max_layers: usize,
|
||||
}
|
||||
|
||||
impl Default for NativeSimConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_models: 8,
|
||||
trace: false,
|
||||
lut_softmax: true,
|
||||
max_layers: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeSimBackend {
|
||||
/// Create a new native simulator backend
|
||||
pub fn new(gate: Arc<dyn CoherenceGate>) -> Self {
|
||||
Self::with_config(gate, NativeSimConfig::default())
|
||||
}
|
||||
|
||||
/// Create with custom configuration
|
||||
pub fn with_config(gate: Arc<dyn CoherenceGate>, config: NativeSimConfig) -> Self {
|
||||
Self {
|
||||
models: RwLock::new(HashMap::new()),
|
||||
gate,
|
||||
stats: RwLock::new(BackendStats::default()),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the core transformer inference
|
||||
fn run_inference(
|
||||
&self,
|
||||
model: &LoadedModel,
|
||||
tokens: &[u16],
|
||||
_attn_mask: &[u8],
|
||||
gate_hint: &GateHint,
|
||||
) -> Result<(Vec<i16>, GateDecision)> {
|
||||
let shape = &model.artifact.manifest.shape;
|
||||
let num_layers = model.layers.len();
|
||||
|
||||
// Check preflight gate
|
||||
let preflight = self.gate.preflight(gate_hint);
|
||||
if let GateDecision::Skipped { reason } = preflight {
|
||||
return Ok((
|
||||
vec![0i16; shape.vocab as usize],
|
||||
GateDecision::Skipped { reason },
|
||||
));
|
||||
}
|
||||
|
||||
// Initialize hidden states from embeddings
|
||||
let d_model = shape.d_model as usize;
|
||||
let seq_len = tokens.len();
|
||||
let mut hidden = vec![0.0f32; seq_len * d_model];
|
||||
|
||||
// Lookup embeddings
|
||||
for (i, &token) in tokens.iter().enumerate() {
|
||||
let offset = (token as usize) * d_model;
|
||||
if offset + d_model <= model.embeddings.len() {
|
||||
hidden[i * d_model..(i + 1) * d_model]
|
||||
.copy_from_slice(&model.embeddings[offset..offset + d_model]);
|
||||
}
|
||||
}
|
||||
|
||||
// Run through layers
|
||||
let max_layers = if self.config.max_layers > 0 {
|
||||
self.config.max_layers.min(num_layers)
|
||||
} else {
|
||||
num_layers
|
||||
};
|
||||
|
||||
for layer_idx in 0..max_layers {
|
||||
let layer = &model.layers[layer_idx];
|
||||
|
||||
// Check layer checkpoint for early exit
|
||||
let coherence_signal = self.compute_coherence_signal(&hidden);
|
||||
if let Some(decision) = self.gate.checkpoint(layer_idx as u8, coherence_signal) {
|
||||
if let GateDecision::EarlyExit { layer } = decision {
|
||||
// Early exit - compute output from current hidden state
|
||||
let logits = self.compute_output(&hidden, &model.output_proj, shape);
|
||||
return Ok((logits, GateDecision::EarlyExit { layer }));
|
||||
}
|
||||
}
|
||||
|
||||
// Simplified attention + FFN (for simulation purposes)
|
||||
hidden = self.run_layer(&hidden, layer, shape);
|
||||
}
|
||||
|
||||
// Compute output logits
|
||||
let logits = self.compute_output(&hidden, &model.output_proj, shape);
|
||||
|
||||
Ok((logits, GateDecision::RanFull))
|
||||
}
|
||||
|
||||
/// Run a single transformer layer
|
||||
fn run_layer(&self, hidden: &[f32], layer: &LayerWeights, shape: &FixedShape) -> Vec<f32> {
|
||||
let d_model = shape.d_model as usize;
|
||||
let seq_len = hidden.len() / d_model;
|
||||
|
||||
// Simplified layer computation
|
||||
// In a real implementation, this would do full attention + FFN
|
||||
|
||||
let mut output = hidden.to_vec();
|
||||
|
||||
// Layer norm 1
|
||||
for t in 0..seq_len {
|
||||
let start = t * d_model;
|
||||
let end = start + d_model;
|
||||
layer_norm_inplace(&mut output[start..end], &layer.ln1_weight);
|
||||
}
|
||||
|
||||
// Simplified attention (just apply output projection as placeholder)
|
||||
// Real implementation would compute Q, K, V, attention scores, etc.
|
||||
if !layer.wo.is_empty() {
|
||||
let mut attn_out = vec![0.0f32; output.len()];
|
||||
for t in 0..seq_len {
|
||||
for i in 0..d_model {
|
||||
let mut sum = 0.0f32;
|
||||
for j in 0..d_model.min(layer.wo.len() / d_model) {
|
||||
sum += output[t * d_model + j] * layer.wo[j * d_model + i];
|
||||
}
|
||||
attn_out[t * d_model + i] = sum;
|
||||
}
|
||||
}
|
||||
// Residual connection
|
||||
for i in 0..output.len() {
|
||||
output[i] += attn_out[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Layer norm 2
|
||||
for t in 0..seq_len {
|
||||
let start = t * d_model;
|
||||
let end = start + d_model;
|
||||
layer_norm_inplace(&mut output[start..end], &layer.ln2_weight);
|
||||
}
|
||||
|
||||
// Simplified FFN (SwiGLU-like)
|
||||
if !layer.w1.is_empty() && !layer.w2.is_empty() {
|
||||
let ffn_dim = layer.w1.len() / d_model;
|
||||
let mut ffn_out = vec![0.0f32; output.len()];
|
||||
|
||||
for t in 0..seq_len {
|
||||
// Up projection
|
||||
let mut up = vec![0.0f32; ffn_dim];
|
||||
for i in 0..ffn_dim {
|
||||
for j in 0..d_model {
|
||||
up[i] += output[t * d_model + j] * layer.w1[j * ffn_dim + i];
|
||||
}
|
||||
// SiLU activation
|
||||
up[i] = up[i] * sigmoid(up[i]);
|
||||
}
|
||||
|
||||
// Down projection
|
||||
for i in 0..d_model {
|
||||
for j in 0..ffn_dim.min(layer.w2.len() / d_model) {
|
||||
ffn_out[t * d_model + i] += up[j] * layer.w2[j * d_model + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
for i in 0..output.len() {
|
||||
output[i] += ffn_out[i];
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Compute output logits from hidden state
|
||||
fn compute_output(&self, hidden: &[f32], output_proj: &[f32], shape: &FixedShape) -> Vec<i16> {
|
||||
let d_model = shape.d_model as usize;
|
||||
let vocab = shape.vocab as usize;
|
||||
let seq_len = hidden.len() / d_model;
|
||||
|
||||
// Take last token's hidden state
|
||||
let last_hidden = &hidden[(seq_len - 1) * d_model..];
|
||||
|
||||
// Compute logits
|
||||
let mut logits = vec![0.0f32; vocab];
|
||||
if output_proj.len() >= d_model * vocab {
|
||||
for v in 0..vocab {
|
||||
for d in 0..d_model {
|
||||
logits[v] += last_hidden[d] * output_proj[d * vocab + v];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: random logits for simulation when weights not available
|
||||
for v in 0..vocab {
|
||||
logits[v] = (v as f32 * 0.01).sin();
|
||||
}
|
||||
}
|
||||
|
||||
// Apply softmax (optional) and quantize
|
||||
if self.config.lut_softmax {
|
||||
softmax_lut(&mut logits);
|
||||
} else {
|
||||
softmax_f32(&mut logits);
|
||||
}
|
||||
|
||||
// Quantize to i16
|
||||
quantize_i16(&logits)
|
||||
}
|
||||
|
||||
/// Compute coherence signal for early exit decision
|
||||
fn compute_coherence_signal(&self, hidden: &[f32]) -> i16 {
|
||||
// Simple coherence metric: variance of hidden states
|
||||
let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
|
||||
let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden.len() as f32;
|
||||
|
||||
// Scale to Q8.8 fixed point
|
||||
((variance * 256.0).clamp(-32768.0, 32767.0)) as i16
|
||||
}
|
||||
|
||||
/// Prepare model from artifact (dequantize weights for simulation)
|
||||
fn prepare_model(&self, artifact: &ModelArtifact) -> Result<LoadedModel> {
|
||||
let shape = &artifact.manifest.shape;
|
||||
let quant = &artifact.manifest.quant;
|
||||
let d_model = shape.d_model as usize;
|
||||
let vocab = shape.vocab as usize;
|
||||
|
||||
// Dequantize embeddings
|
||||
let embedding_size = vocab * d_model;
|
||||
let embeddings = if artifact.weights.len() >= embedding_size {
|
||||
dequantize_i8(&artifact.weights[..embedding_size], quant)
|
||||
} else {
|
||||
// Generate random embeddings for testing
|
||||
(0..embedding_size)
|
||||
.map(|i| ((i as f32 * 0.001).sin() * 0.1))
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Create simplified layer weights
|
||||
let num_layers = 4; // Default for simulation
|
||||
let layers: Vec<LayerWeights> = (0..num_layers)
|
||||
.map(|_| LayerWeights {
|
||||
wq: vec![0.01; d_model * d_model],
|
||||
wk: vec![0.01; d_model * d_model],
|
||||
wv: vec![0.01; d_model * d_model],
|
||||
wo: vec![0.01; d_model * d_model],
|
||||
w1: vec![0.01; d_model * 4 * d_model],
|
||||
w2: vec![0.01; 4 * d_model * d_model],
|
||||
ln1_weight: vec![1.0; d_model],
|
||||
ln2_weight: vec![1.0; d_model],
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Output projection
|
||||
let output_proj = vec![0.01; d_model * vocab];
|
||||
|
||||
Ok(LoadedModel {
|
||||
artifact: artifact.clone(),
|
||||
embeddings,
|
||||
layers,
|
||||
output_proj,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBackend for NativeSimBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
// Prepare model
|
||||
let model = self.prepare_model(artifact)?;
|
||||
let model_id = artifact.model_id();
|
||||
|
||||
// Check capacity (with poison handling)
|
||||
let at_capacity = read_lock(&self.models, |models| {
|
||||
models.len() >= self.config.max_models && !models.contains_key(&model_id)
|
||||
})?;
|
||||
|
||||
if at_capacity {
|
||||
return Err(Error::ResourceExhausted("Max models reached".into()));
|
||||
}
|
||||
|
||||
// Store model
|
||||
write_lock(&self.models, |models| {
|
||||
models.insert(model_id, Arc::new(model));
|
||||
})?;
|
||||
|
||||
// Update stats
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded += 1;
|
||||
})?;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate request
|
||||
req.validate()?;
|
||||
|
||||
// Get model (with poison handling)
|
||||
let model = read_lock(&self.models, |models| models.get(&req.model).cloned())?
|
||||
.ok_or_else(|| Error::ModelNotFound(req.model))?;
|
||||
|
||||
// Validate shape
|
||||
if model.artifact.manifest.shape != req.shape {
|
||||
return Err(Error::ShapeMismatch {
|
||||
expected: model.artifact.manifest.shape,
|
||||
actual: req.shape,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate tokens against vocabulary
|
||||
validate_tokens(req.tokens, model.artifact.manifest.shape.vocab)?;
|
||||
|
||||
// Run inference
|
||||
let (logits_q, gate_decision) =
|
||||
self.run_inference(&model, req.tokens, req.attn_mask, &req.gate_hint)?;
|
||||
|
||||
let latency_ns = start.elapsed().as_nanos() as u32;
|
||||
|
||||
// Compute top-K using common utility
|
||||
let topk = compute_topk(&logits_q, 16);
|
||||
|
||||
// Create witness
|
||||
let witness = WitnessLog::new(
|
||||
model.artifact.model_hash(),
|
||||
model.artifact.quant_hash(),
|
||||
BackendKind::NativeSim,
|
||||
0, // No cycles for simulator
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
);
|
||||
|
||||
// Update stats (with poison handling)
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.total_inferences += 1;
|
||||
let n = stats.total_inferences;
|
||||
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
|
||||
match gate_decision {
|
||||
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
|
||||
GateDecision::Skipped { .. } => stats.skipped += 1,
|
||||
_ => {}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(InferenceResult::new(logits_q, Some(topk), witness))
|
||||
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
|
||||
|
||||
if removed {
|
||||
write_lock(&self.stats, |stats| {
|
||||
stats.models_loaded = stats.models_loaded.saturating_sub(1);
|
||||
})?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn kind(&self) -> BackendKind {
|
||||
BackendKind::NativeSim
|
||||
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
fn layer_norm_inplace(x: &mut [f32], weight: &[f32]) {
|
||||
let mean = x.iter().sum::<f32>() / x.len() as f32;
|
||||
let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
|
||||
let std = (variance + 1e-5).sqrt();
|
||||
|
||||
for (i, v) in x.iter_mut().enumerate() {
|
||||
*v = (*v - mean) / std * weight.get(i).copied().unwrap_or(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
}
|
||||
|
||||
fn softmax_f32(x: &mut [f32]) {
|
||||
let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for v in x.iter_mut() {
|
||||
*v = (*v - max).exp();
|
||||
sum += *v;
|
||||
}
|
||||
if sum > 0.0 {
|
||||
for v in x.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::artifact::Manifest;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
|
||||
fn create_test_artifact() -> ModelArtifact {
|
||||
let manifest = Manifest {
|
||||
name: "test_model".into(),
|
||||
model_hash: "0".repeat(64),
|
||||
shape: FixedShape::micro(),
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
ModelArtifact {
|
||||
manifest,
|
||||
weights: vec![0u8; 4096 * 64], // Minimal embedding weights
|
||||
bitstream: None,
|
||||
calibration: None,
|
||||
test_vectors: vec![],
|
||||
signature: [0u8; 64],
|
||||
pubkey: [0u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_native_sim_load_unload() {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let artifact = create_test_artifact();
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
|
||||
assert!(backend.is_loaded(model_id));
|
||||
|
||||
backend.unload(model_id).unwrap();
|
||||
assert!(!backend.is_loaded(model_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_native_sim_inference() {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::new(gate);
|
||||
|
||||
let artifact = create_test_artifact();
|
||||
let model_id = backend.load(&artifact).unwrap();
|
||||
|
||||
let tokens: Vec<u16> = (0..32).collect();
|
||||
let mask = vec![1u8; 32];
|
||||
|
||||
let req = InferenceRequest::new(
|
||||
model_id,
|
||||
FixedShape::micro(),
|
||||
&tokens,
|
||||
&mask,
|
||||
GateHint::allow_all(),
|
||||
);
|
||||
|
||||
let result = backend.infer(req).unwrap();
|
||||
|
||||
assert!(!result.logits_q.is_empty());
|
||||
assert!(result.topk.is_some());
|
||||
assert_eq!(result.witness.backend, BackendKind::NativeSim);
|
||||
}
|
||||
}
|
||||
348
crates/ruvector-fpga-transformer/src/backend/wasm_sim.rs
Normal file
348
crates/ruvector-fpga-transformer/src/backend/wasm_sim.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
//! WASM Simulator backend
|
||||
//!
|
||||
//! Pure Rust implementation that runs in WASM environments.
|
||||
//! Uses RefCell for interior mutability since WASM is single-threaded.
|
||||
|
||||
#![cfg(feature = "wasm")]
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
use crate::artifact::ModelArtifact;
|
||||
use crate::backend::{compute_topk, validate_tokens, BackendStats, TransformerBackend};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::gating::CoherenceGate;
|
||||
use crate::quant::{dequantize_i8, quantize_i16};
|
||||
use crate::types::{
|
||||
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
|
||||
QuantSpec, WitnessLog,
|
||||
};
|
||||
|
||||
/// Loaded model for WASM simulation
|
||||
struct WasmModel {
|
||||
/// Model artifact
|
||||
artifact: ModelArtifact,
|
||||
/// Prepacked embedding table (dequantized to f32 for computation)
|
||||
embeddings: Vec<f32>,
|
||||
/// Number of layers
|
||||
num_layers: usize,
|
||||
/// Shape info
|
||||
shape: FixedShape,
|
||||
}
|
||||
|
||||
/// WASM simulator backend state (interior mutable for single-threaded WASM)
|
||||
struct WasmState {
|
||||
/// Loaded models
|
||||
models: HashMap<ModelId, WasmModel>,
|
||||
/// Statistics
|
||||
stats: BackendStats,
|
||||
}
|
||||
|
||||
/// WASM simulator backend
|
||||
///
|
||||
/// Uses RefCell for interior mutability since WASM is inherently single-threaded.
|
||||
/// This allows the TransformerBackend trait to be implemented with &self methods.
|
||||
pub struct WasmSimBackend {
|
||||
/// Interior mutable state
|
||||
state: RefCell<WasmState>,
|
||||
/// Coherence gate (immutable, shared)
|
||||
gate: Rc<dyn CoherenceGate>,
|
||||
}
|
||||
|
||||
impl WasmSimBackend {
|
||||
/// Create a new WASM simulator backend
|
||||
pub fn new(gate: Rc<dyn CoherenceGate>) -> Self {
|
||||
Self {
|
||||
state: RefCell::new(WasmState {
|
||||
models: HashMap::new(),
|
||||
stats: BackendStats::default(),
|
||||
}),
|
||||
gate,
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare model from artifact
|
||||
fn prepare_model(&self, artifact: &ModelArtifact) -> Result<WasmModel> {
|
||||
let shape = artifact.manifest.shape;
|
||||
let quant = &artifact.manifest.quant;
|
||||
let d_model = shape.d_model as usize;
|
||||
let vocab = shape.vocab as usize;
|
||||
|
||||
// Dequantize embeddings
|
||||
let embedding_size = vocab * d_model;
|
||||
let embeddings = if artifact.weights.len() >= embedding_size {
|
||||
dequantize_i8(&artifact.weights[..embedding_size], quant)
|
||||
} else {
|
||||
// Generate deterministic embeddings for testing
|
||||
(0..embedding_size)
|
||||
.map(|i| ((i as f32 * 0.001).sin() * 0.1))
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Determine number of layers from artifact or default
|
||||
let num_layers = if artifact.manifest.backend.options.early_exit {
|
||||
6
|
||||
} else {
|
||||
4
|
||||
};
|
||||
|
||||
Ok(WasmModel {
|
||||
artifact: artifact.clone(),
|
||||
embeddings,
|
||||
num_layers,
|
||||
shape,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run inference for WASM
|
||||
fn run_inference(
|
||||
&self,
|
||||
model: &WasmModel,
|
||||
tokens: &[u16],
|
||||
gate_hint: &GateHint,
|
||||
) -> (Vec<i16>, GateDecision) {
|
||||
let shape = &model.shape;
|
||||
|
||||
// Check preflight
|
||||
let preflight = self.gate.preflight(gate_hint);
|
||||
if !preflight.did_run() {
|
||||
return (vec![0i16; shape.vocab as usize], preflight);
|
||||
}
|
||||
|
||||
let vocab = shape.vocab as usize;
|
||||
let d_model = shape.d_model as usize;
|
||||
|
||||
// Initialize hidden state from embeddings
|
||||
let seq_len = tokens.len();
|
||||
let mut hidden = vec![0.0f32; seq_len * d_model];
|
||||
|
||||
// Lookup embeddings with bounds checking
|
||||
for (i, &token) in tokens.iter().enumerate() {
|
||||
let offset = (token as usize).min(vocab.saturating_sub(1)) * d_model;
|
||||
if offset + d_model <= model.embeddings.len() {
|
||||
hidden[i * d_model..(i + 1) * d_model]
|
||||
.copy_from_slice(&model.embeddings[offset..offset + d_model]);
|
||||
}
|
||||
}
|
||||
|
||||
// Run through simplified layers with early exit support
|
||||
for layer in 0..model.num_layers {
|
||||
// Simple layer computation (for WASM we keep it lightweight)
|
||||
// Apply simple transformation
|
||||
for t in 0..seq_len {
|
||||
let start = t * d_model;
|
||||
// Simple ReLU-like activation
|
||||
for i in 0..d_model {
|
||||
hidden[start + i] =
|
||||
hidden[start + i].max(0.0) * 0.99 + hidden[start + i] * 0.01;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for early exit
|
||||
let coherence_signal = compute_coherence(&hidden);
|
||||
if let Some(decision) = self.gate.checkpoint(layer as u8, coherence_signal) {
|
||||
let logits = self.compute_output(&hidden, model);
|
||||
return (logits, decision);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute output logits
|
||||
let logits = self.compute_output(&hidden, model);
|
||||
(logits, GateDecision::RanFull)
|
||||
}
|
||||
|
||||
/// Compute output logits from hidden state
|
||||
fn compute_output(&self, hidden: &[f32], model: &WasmModel) -> Vec<i16> {
|
||||
let shape = &model.shape;
|
||||
let d_model = shape.d_model as usize;
|
||||
let vocab = shape.vocab as usize;
|
||||
let seq_len = hidden.len() / d_model;
|
||||
|
||||
// Take last token's hidden state
|
||||
let last_hidden = &hidden[(seq_len.saturating_sub(1)) * d_model..];
|
||||
|
||||
// Compute logits via dot product with embedding matrix (transposed)
|
||||
let mut logits_f32 = vec![0.0f32; vocab];
|
||||
for v in 0..vocab.min(model.embeddings.len() / d_model) {
|
||||
let v_offset = v * d_model;
|
||||
let mut dot = 0.0f32;
|
||||
for d in 0..d_model.min(last_hidden.len()) {
|
||||
if v_offset + d < model.embeddings.len() {
|
||||
dot += last_hidden[d] * model.embeddings[v_offset + d];
|
||||
}
|
||||
}
|
||||
logits_f32[v] = dot;
|
||||
}
|
||||
|
||||
// Apply softmax and quantize
|
||||
softmax_inplace(&mut logits_f32);
|
||||
quantize_i16(&logits_f32)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: WASM is single-threaded, so these trait bounds are satisfied trivially
|
||||
// by never actually being used across threads
|
||||
unsafe impl Send for WasmSimBackend {}
|
||||
unsafe impl Sync for WasmSimBackend {}
|
||||
|
||||
impl TransformerBackend for WasmSimBackend {
|
||||
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
// Prepare model
|
||||
let model = self.prepare_model(artifact)?;
|
||||
let model_id = artifact.model_id();
|
||||
|
||||
// Store in state
|
||||
let mut state = self.state.borrow_mut();
|
||||
state.models.insert(model_id, model);
|
||||
state.stats.models_loaded += 1;
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
let start = js_sys::Date::now();
|
||||
|
||||
// Validate request
|
||||
req.validate()?;
|
||||
|
||||
// Get model (immutable borrow)
|
||||
let state = self.state.borrow();
|
||||
let model = state
|
||||
.models
|
||||
.get(&req.model)
|
||||
.ok_or_else(|| Error::ModelNotFound(req.model))?;
|
||||
|
||||
// Validate tokens
|
||||
validate_tokens(req.tokens, model.shape.vocab)?;
|
||||
|
||||
// Run inference
|
||||
let (logits, gate_decision) = self.run_inference(model, req.tokens, &req.gate_hint);
|
||||
|
||||
let latency_ns = ((js_sys::Date::now() - start) * 1_000_000.0) as u32;
|
||||
|
||||
// Compute top-K
|
||||
let topk = compute_topk(&logits, 16);
|
||||
|
||||
// Build witness
|
||||
let witness = WitnessLog::new(
|
||||
model.artifact.model_hash(),
|
||||
model.artifact.quant_hash(),
|
||||
BackendKind::WasmSim,
|
||||
0, // No cycles for WASM sim
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
);
|
||||
|
||||
drop(state); // Release borrow before mutable borrow
|
||||
|
||||
// Update stats
|
||||
{
|
||||
let mut state = self.state.borrow_mut();
|
||||
state.stats.total_inferences += 1;
|
||||
let n = state.stats.total_inferences;
|
||||
state.stats.avg_latency_ns =
|
||||
(state.stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
|
||||
match gate_decision {
|
||||
GateDecision::EarlyExit { .. } => state.stats.early_exits += 1,
|
||||
GateDecision::Skipped { .. } => state.stats.skipped += 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(InferenceResult::new(logits, Some(topk), witness))
|
||||
}
|
||||
|
||||
fn unload(&self, model: ModelId) -> Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
if state.models.remove(&model).is_some() {
|
||||
state.stats.models_loaded = state.stats.models_loaded.saturating_sub(1);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ModelNotFound(model))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loaded(&self, model: ModelId) -> bool {
|
||||
self.state.borrow().models.contains_key(&model)
|
||||
}
|
||||
|
||||
fn kind(&self) -> BackendKind {
|
||||
BackendKind::WasmSim
|
||||
}
|
||||
|
||||
fn stats(&self) -> BackendStats {
|
||||
self.state.borrow().stats.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute coherence signal from hidden state
|
||||
fn compute_coherence(hidden: &[f32]) -> i16 {
|
||||
if hidden.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
|
||||
let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden.len() as f32;
|
||||
((variance * 256.0).clamp(-32768.0, 32767.0)) as i16
|
||||
}
|
||||
|
||||
/// In-place softmax
|
||||
fn softmax_inplace(x: &mut [f32]) {
|
||||
if x.is_empty() {
|
||||
return;
|
||||
}
|
||||
let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for v in x.iter_mut() {
|
||||
*v = (*v - max).exp();
|
||||
sum += *v;
|
||||
}
|
||||
if sum > 0.0 {
|
||||
for v in x.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::artifact::Manifest;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
|
||||
fn create_test_artifact() -> ModelArtifact {
|
||||
let manifest = Manifest {
|
||||
name: "wasm_test".into(),
|
||||
model_hash: "0".repeat(64),
|
||||
shape: FixedShape::micro(),
|
||||
quant: QuantSpec::int8(),
|
||||
io: Default::default(),
|
||||
backend: Default::default(),
|
||||
tests: Default::default(),
|
||||
};
|
||||
|
||||
ModelArtifact {
|
||||
manifest,
|
||||
weights: vec![0u8; 4096 * 64],
|
||||
bitstream: None,
|
||||
calibration: None,
|
||||
test_vectors: vec![],
|
||||
signature: [0u8; 64],
|
||||
pubkey: [0u8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wasm_sim_prepare_model() {
|
||||
let gate = Rc::new(DefaultCoherenceGate::new());
|
||||
let backend = WasmSimBackend::new(gate);
|
||||
let artifact = create_test_artifact();
|
||||
|
||||
let model = backend.prepare_model(&artifact).unwrap();
|
||||
assert_eq!(model.shape.seq_len, 32);
|
||||
assert!(!model.embeddings.is_empty());
|
||||
}
|
||||
}
|
||||
136
crates/ruvector-fpga-transformer/src/error.rs
Normal file
136
crates/ruvector-fpga-transformer/src/error.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
//! Error types for FPGA Transformer backend
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for FPGA Transformer operations
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// FPGA Transformer error types
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
/// Model artifact is invalid or corrupted
|
||||
#[error("Invalid artifact: {0}")]
|
||||
InvalidArtifact(String),
|
||||
|
||||
/// Artifact signature verification failed
|
||||
#[error("Signature verification failed: {0}")]
|
||||
SignatureError(String),
|
||||
|
||||
/// Test vectors failed validation
|
||||
#[error("Test vector validation failed: expected max error {expected}, got {actual}")]
|
||||
TestVectorError { expected: i32, actual: i32 },
|
||||
|
||||
/// Model not found or not loaded
|
||||
#[error("Model not found: {0:?}")]
|
||||
ModelNotFound(crate::types::ModelId),
|
||||
|
||||
/// Shape mismatch between request and model
|
||||
#[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
|
||||
ShapeMismatch {
|
||||
expected: crate::types::FixedShape,
|
||||
actual: crate::types::FixedShape,
|
||||
},
|
||||
|
||||
/// Input length does not match expected sequence length
|
||||
#[error("Input length mismatch: expected {expected}, got {actual}")]
|
||||
InputLengthMismatch { expected: usize, actual: usize },
|
||||
|
||||
/// Backend communication error
|
||||
#[error("Backend error: {0}")]
|
||||
BackendError(String),
|
||||
|
||||
/// Daemon connection failed
|
||||
#[error("Daemon connection failed: {0}")]
|
||||
DaemonConnectionError(String),
|
||||
|
||||
/// PCIe communication error
|
||||
#[error("PCIe error: {0}")]
|
||||
PcieError(String),
|
||||
|
||||
/// DMA operation failed
|
||||
#[error("DMA error: {0}")]
|
||||
DmaError(String),
|
||||
|
||||
/// Gating decision blocked inference
|
||||
#[error("Inference blocked by gate: {reason:?}")]
|
||||
GateBlocked { reason: crate::types::SkipReason },
|
||||
|
||||
/// Quantization error
|
||||
#[error("Quantization error: {0}")]
|
||||
QuantizationError(String),
|
||||
|
||||
/// Overflow during fixed-point computation
|
||||
#[error("Fixed-point overflow at {location}")]
|
||||
FixedPointOverflow { location: &'static str },
|
||||
|
||||
/// Invalid configuration
|
||||
#[error("Invalid configuration: {0}")]
|
||||
InvalidConfig(String),
|
||||
|
||||
/// IO error
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
/// JSON parsing error
|
||||
#[error("JSON error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
/// Checksum mismatch
|
||||
#[error("Checksum mismatch: expected {expected:08x}, got {actual:08x}")]
|
||||
ChecksumMismatch { expected: u32, actual: u32 },
|
||||
|
||||
/// Protocol version mismatch
|
||||
#[error("Protocol version mismatch: expected {expected}, got {actual}")]
|
||||
ProtocolMismatch { expected: u16, actual: u16 },
|
||||
|
||||
/// Timeout waiting for response
|
||||
#[error("Timeout after {ms}ms")]
|
||||
Timeout { ms: u64 },
|
||||
|
||||
/// Resource exhausted (memory, slots, etc.)
|
||||
#[error("Resource exhausted: {0}")]
|
||||
ResourceExhausted(String),
|
||||
|
||||
/// Feature not available in this build
|
||||
#[error("Feature not available: {0}")]
|
||||
FeatureNotAvailable(String),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Create a new InvalidArtifact error
|
||||
pub fn invalid_artifact(msg: impl Into<String>) -> Self {
|
||||
Self::InvalidArtifact(msg.into())
|
||||
}
|
||||
|
||||
/// Create a new BackendError
|
||||
pub fn backend(msg: impl Into<String>) -> Self {
|
||||
Self::BackendError(msg.into())
|
||||
}
|
||||
|
||||
/// Create a new DaemonConnectionError
|
||||
pub fn daemon_connection(msg: impl Into<String>) -> Self {
|
||||
Self::DaemonConnectionError(msg.into())
|
||||
}
|
||||
|
||||
/// Check if this error is recoverable
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Error::Timeout { .. }
|
||||
| Error::DaemonConnectionError(_)
|
||||
| Error::BackendError(_)
|
||||
| Error::GateBlocked { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this error indicates a configuration problem
|
||||
pub fn is_config_error(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Error::InvalidConfig(_)
|
||||
| Error::ShapeMismatch { .. }
|
||||
| Error::InputLengthMismatch { .. }
|
||||
| Error::FeatureNotAvailable(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
302
crates/ruvector-fpga-transformer/src/ffi/c_abi.rs
Normal file
302
crates/ruvector-fpga-transformer/src/ffi/c_abi.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
//! C ABI bindings for FFI integration
|
||||
//!
|
||||
//! Provides a stable C interface for linking from other languages.
|
||||
|
||||
use std::ffi::{c_char, c_int, c_void, CStr};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::backend::native_sim::NativeSimBackend;
|
||||
use crate::backend::TransformerBackend;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
|
||||
|
||||
/// Opaque engine handle
|
||||
pub struct FpgaEngine {
|
||||
backend: Box<dyn TransformerBackend>,
|
||||
}
|
||||
|
||||
/// Result code
|
||||
#[repr(C)]
|
||||
pub enum FpgaResult {
|
||||
Ok = 0,
|
||||
InvalidArgument = 1,
|
||||
ModelNotFound = 2,
|
||||
InferenceFailed = 3,
|
||||
AllocationFailed = 4,
|
||||
InvalidArtifact = 5,
|
||||
}
|
||||
|
||||
/// Inference result structure
|
||||
#[repr(C)]
|
||||
pub struct FpgaInferenceResult {
|
||||
/// Status code
|
||||
pub status: FpgaResult,
|
||||
/// Logits (caller must free with fpga_free_logits)
|
||||
pub logits: *mut i16,
|
||||
/// Number of logits
|
||||
pub logits_len: usize,
|
||||
/// Top-K results (token_id, logit pairs)
|
||||
pub topk: *mut u32,
|
||||
/// Number of top-K pairs
|
||||
pub topk_len: usize,
|
||||
/// Latency in nanoseconds
|
||||
pub latency_ns: u32,
|
||||
/// Compute cycles
|
||||
pub cycles: u32,
|
||||
/// Gate decision (0=full, 1=early_exit, 2=skipped)
|
||||
pub gate_decision: u8,
|
||||
/// Exit layer (if early exit)
|
||||
pub exit_layer: u8,
|
||||
}
|
||||
|
||||
/// Create a new FPGA engine with native simulator backend
|
||||
///
|
||||
/// Returns a handle that must be freed with `fpga_engine_destroy`
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_engine_create() -> *mut FpgaEngine {
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = Box::new(NativeSimBackend::new(gate));
|
||||
|
||||
let engine = Box::new(FpgaEngine { backend });
|
||||
Box::into_raw(engine)
|
||||
}
|
||||
|
||||
/// Destroy an FPGA engine
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_engine_destroy(engine: *mut FpgaEngine) {
|
||||
if !engine.is_null() {
|
||||
unsafe {
|
||||
drop(Box::from_raw(engine));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a model artifact
|
||||
///
|
||||
/// Returns model ID bytes (32 bytes) on success, NULL on failure
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_load_artifact(
|
||||
engine: *mut FpgaEngine,
|
||||
artifact_bytes: *const u8,
|
||||
artifact_len: usize,
|
||||
model_id_out: *mut u8,
|
||||
) -> FpgaResult {
|
||||
if engine.is_null() || artifact_bytes.is_null() || model_id_out.is_null() {
|
||||
return FpgaResult::InvalidArgument;
|
||||
}
|
||||
|
||||
let engine = unsafe { &mut *engine };
|
||||
let artifact_slice = unsafe { std::slice::from_raw_parts(artifact_bytes, artifact_len) };
|
||||
|
||||
let artifact = match crate::artifact::unpack_artifact(artifact_slice) {
|
||||
Ok(a) => a,
|
||||
Err(_) => return FpgaResult::InvalidArtifact,
|
||||
};
|
||||
|
||||
match engine.backend.load(&artifact) {
|
||||
Ok(model_id) => {
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(model_id.as_bytes().as_ptr(), model_id_out, 32);
|
||||
}
|
||||
FpgaResult::Ok
|
||||
}
|
||||
Err(_) => FpgaResult::InvalidArtifact,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run inference
|
||||
///
|
||||
/// Result must be freed with `fpga_result_free`
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_infer(
|
||||
engine: *mut FpgaEngine,
|
||||
model_id: *const u8,
|
||||
tokens: *const u16,
|
||||
tokens_len: usize,
|
||||
mask: *const u8,
|
||||
mask_len: usize,
|
||||
coherence_score: i16,
|
||||
boundary_crossed: bool,
|
||||
max_compute_class: u8,
|
||||
) -> FpgaInferenceResult {
|
||||
let error_result = || FpgaInferenceResult {
|
||||
status: FpgaResult::InvalidArgument,
|
||||
logits: ptr::null_mut(),
|
||||
logits_len: 0,
|
||||
topk: ptr::null_mut(),
|
||||
topk_len: 0,
|
||||
latency_ns: 0,
|
||||
cycles: 0,
|
||||
gate_decision: 2,
|
||||
exit_layer: 0,
|
||||
};
|
||||
|
||||
if engine.is_null() || model_id.is_null() || tokens.is_null() || mask.is_null() {
|
||||
return error_result();
|
||||
}
|
||||
|
||||
let engine = unsafe { &mut *engine };
|
||||
|
||||
// Parse model ID
|
||||
let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(id_slice);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
// Parse tokens and mask
|
||||
let tokens_slice = unsafe { std::slice::from_raw_parts(tokens, tokens_len) };
|
||||
let mask_slice = unsafe { std::slice::from_raw_parts(mask, mask_len) };
|
||||
|
||||
// Build shape (micro for C API)
|
||||
let shape = FixedShape::micro();
|
||||
|
||||
// Build gate hint
|
||||
let compute_class =
|
||||
ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
|
||||
let gate_hint = GateHint::new(coherence_score, boundary_crossed, compute_class);
|
||||
|
||||
// Create request
|
||||
let req = InferenceRequest::new(model, shape, tokens_slice, mask_slice, gate_hint);
|
||||
|
||||
// Run inference
|
||||
match engine.backend.infer(req) {
|
||||
Ok(result) => {
|
||||
// Allocate logits with checked allocation (prevents panic on overflow)
|
||||
let logits_len = result.logits_q.len();
|
||||
let logits = if logits_len > 0 {
|
||||
match std::alloc::Layout::array::<i16>(logits_len) {
|
||||
Ok(layout) if layout.size() > 0 => {
|
||||
let ptr = unsafe { std::alloc::alloc(layout) as *mut i16 };
|
||||
if !ptr.is_null() {
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(result.logits_q.as_ptr(), ptr, logits_len);
|
||||
}
|
||||
}
|
||||
ptr
|
||||
}
|
||||
_ => ptr::null_mut(), // Return null on allocation failure
|
||||
}
|
||||
} else {
|
||||
ptr::null_mut()
|
||||
};
|
||||
|
||||
// Allocate top-K with checked allocation
|
||||
let (topk, topk_len) = if let Some(ref tk) = result.topk {
|
||||
let len = tk.len() * 2; // (token, logit) pairs
|
||||
match std::alloc::Layout::array::<u32>(len) {
|
||||
Ok(layout) if layout.size() > 0 => {
|
||||
let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
|
||||
if !ptr.is_null() {
|
||||
for (i, (token, logit)) in tk.iter().enumerate() {
|
||||
unsafe {
|
||||
*ptr.add(i * 2) = *token as u32;
|
||||
*ptr.add(i * 2 + 1) = *logit as u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
(ptr, tk.len())
|
||||
}
|
||||
_ => (ptr::null_mut(), 0), // Return null on allocation failure
|
||||
}
|
||||
} else {
|
||||
(ptr::null_mut(), 0)
|
||||
};
|
||||
|
||||
// Encode gate decision
|
||||
let (gate_decision, exit_layer) = match result.witness.gate_decision {
|
||||
crate::types::GateDecision::RanFull => (0, 0),
|
||||
crate::types::GateDecision::EarlyExit { layer } => (1, layer),
|
||||
crate::types::GateDecision::Skipped { .. } => (2, 0),
|
||||
};
|
||||
|
||||
FpgaInferenceResult {
|
||||
status: FpgaResult::Ok,
|
||||
logits,
|
||||
logits_len,
|
||||
topk,
|
||||
topk_len,
|
||||
latency_ns: result.witness.latency_ns,
|
||||
cycles: result.witness.cycles,
|
||||
gate_decision,
|
||||
exit_layer,
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let mut result = error_result();
|
||||
result.status = FpgaResult::InferenceFailed;
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Free inference result
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_result_free(result: *mut FpgaInferenceResult) {
|
||||
if result.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let r = &mut *result;
|
||||
|
||||
if !r.logits.is_null() && r.logits_len > 0 {
|
||||
std::alloc::dealloc(
|
||||
r.logits as *mut u8,
|
||||
std::alloc::Layout::array::<i16>(r.logits_len).unwrap(),
|
||||
);
|
||||
r.logits = ptr::null_mut();
|
||||
}
|
||||
|
||||
if !r.topk.is_null() && r.topk_len > 0 {
|
||||
std::alloc::dealloc(
|
||||
r.topk as *mut u8,
|
||||
std::alloc::Layout::array::<u32>(r.topk_len * 2).unwrap(),
|
||||
);
|
||||
r.topk = ptr::null_mut();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unload a model
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_unload(engine: *mut FpgaEngine, model_id: *const u8) -> FpgaResult {
|
||||
if engine.is_null() || model_id.is_null() {
|
||||
return FpgaResult::InvalidArgument;
|
||||
}
|
||||
|
||||
let engine = unsafe { &mut *engine };
|
||||
let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(id_slice);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
match engine.backend.unload(model) {
|
||||
Ok(()) => FpgaResult::Ok,
|
||||
Err(_) => FpgaResult::ModelNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a model is loaded
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_is_loaded(engine: *const FpgaEngine, model_id: *const u8) -> bool {
|
||||
if engine.is_null() || model_id.is_null() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let engine = unsafe { &*engine };
|
||||
let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(id_slice);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
engine.backend.is_loaded(model)
|
||||
}
|
||||
|
||||
/// Get version string
|
||||
#[no_mangle]
|
||||
pub extern "C" fn fpga_version() -> *const c_char {
|
||||
// Static string with null terminator
|
||||
static VERSION: &[u8] = b"0.1.0\0";
|
||||
VERSION.as_ptr() as *const c_char
|
||||
}
|
||||
8
crates/ruvector-fpga-transformer/src/ffi/mod.rs
Normal file
8
crates/ruvector-fpga-transformer/src/ffi/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! Foreign function interfaces for FPGA Transformer
|
||||
//!
|
||||
//! Provides C ABI and WASM bindings.
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_bindgen;
|
||||
|
||||
pub mod c_abi;
|
||||
282
crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
Normal file
282
crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
Normal file
@@ -0,0 +1,282 @@
|
||||
//! WASM bindings via wasm-bindgen
|
||||
//!
|
||||
//! Provides the same API shape for browser and Node.js environments.
|
||||
|
||||
#![cfg(feature = "wasm")]
|
||||
|
||||
use js_sys::{Array, Int16Array, Object, Reflect, Uint16Array, Uint8Array};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::artifact::{unpack_artifact, ModelArtifact};
|
||||
use crate::backend::native_sim::{NativeSimBackend, NativeSimConfig};
|
||||
use crate::backend::TransformerBackend;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// WASM Engine for transformer inference
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmEngine {
|
||||
backend: NativeSimBackend,
|
||||
loaded_models: Vec<ModelId>,
|
||||
last_witness: Option<crate::types::WitnessLog>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmEngine {
|
||||
/// Create a new WASM engine
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
// Use permissive config for WASM
|
||||
let config = NativeSimConfig {
|
||||
max_models: 4,
|
||||
trace: false,
|
||||
lut_softmax: true,
|
||||
max_layers: 0,
|
||||
};
|
||||
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::with_config(gate, config);
|
||||
|
||||
Self {
|
||||
backend,
|
||||
loaded_models: Vec::new(),
|
||||
last_witness: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a model artifact from bytes
|
||||
///
|
||||
/// Returns the model ID as a Uint8Array on success
|
||||
#[wasm_bindgen(js_name = loadArtifact)]
|
||||
pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<Uint8Array, JsValue> {
|
||||
let artifact = unpack_artifact(artifact_bytes)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to unpack artifact: {}", e)))?;
|
||||
|
||||
let model_id = self
|
||||
.backend
|
||||
.load(&artifact)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to load model: {}", e)))?;
|
||||
|
||||
self.loaded_models.push(model_id);
|
||||
|
||||
// Return model ID as Uint8Array
|
||||
let id_array = Uint8Array::new_with_length(32);
|
||||
id_array.copy_from(model_id.as_bytes());
|
||||
Ok(id_array)
|
||||
}
|
||||
|
||||
/// Run inference
|
||||
///
|
||||
/// Returns an object with logits, topk, and witness
|
||||
#[wasm_bindgen]
|
||||
pub fn infer(
|
||||
&mut self,
|
||||
model_id: &[u8],
|
||||
tokens: &[u16],
|
||||
mask: &[u8],
|
||||
coherence_score_q: i16,
|
||||
boundary_crossed: bool,
|
||||
max_compute_class: u8,
|
||||
) -> Result<JsValue, JsValue> {
|
||||
// Parse model ID
|
||||
if model_id.len() != 32 {
|
||||
return Err(JsValue::from_str("Model ID must be 32 bytes"));
|
||||
}
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(model_id);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
// Get shape from loaded model
|
||||
// For WASM, we use micro shape by default
|
||||
let shape = FixedShape::micro();
|
||||
|
||||
// Validate input length
|
||||
if tokens.len() != shape.seq_len as usize {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Token length mismatch: expected {}, got {}",
|
||||
shape.seq_len,
|
||||
tokens.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Build gate hint
|
||||
let compute_class =
|
||||
ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
|
||||
let gate_hint = GateHint::new(coherence_score_q, boundary_crossed, compute_class);
|
||||
|
||||
// Create request
|
||||
let req = InferenceRequest::new(model, shape, tokens, mask, gate_hint);
|
||||
|
||||
// Run inference
|
||||
let result = self
|
||||
.backend
|
||||
.infer(req)
|
||||
.map_err(|e| JsValue::from_str(&format!("Inference failed: {}", e)))?;
|
||||
|
||||
// Store witness
|
||||
self.last_witness = Some(result.witness.clone());
|
||||
|
||||
// Build result object
|
||||
let obj = Object::new();
|
||||
|
||||
// Add logits
|
||||
let logits = Int16Array::new_with_length(result.logits_q.len() as u32);
|
||||
logits.copy_from(&result.logits_q);
|
||||
Reflect::set(&obj, &"logits".into(), &logits)?;
|
||||
|
||||
// Add top-K if available
|
||||
if let Some(topk) = &result.topk {
|
||||
let topk_array = Array::new();
|
||||
for (token, logit) in topk {
|
||||
let pair = Array::new();
|
||||
pair.push(&JsValue::from(*token));
|
||||
pair.push(&JsValue::from(*logit));
|
||||
topk_array.push(&pair);
|
||||
}
|
||||
Reflect::set(&obj, &"topk".into(), &topk_array)?;
|
||||
}
|
||||
|
||||
// Add witness info
|
||||
let witness = Object::new();
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"backend".into(),
|
||||
&format!("{:?}", result.witness.backend).into(),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"cycles".into(),
|
||||
&JsValue::from(result.witness.cycles),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"latency_ns".into(),
|
||||
&JsValue::from(result.witness.latency_ns),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"gate_decision".into(),
|
||||
&format!("{:?}", result.witness.gate_decision).into(),
|
||||
)?;
|
||||
Reflect::set(&obj, &"witness".into(), &witness)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Get the last witness log as JSON
|
||||
#[wasm_bindgen(js_name = getWitness)]
|
||||
pub fn get_witness(&self) -> Result<JsValue, JsValue> {
|
||||
match &self.last_witness {
|
||||
Some(witness) => {
|
||||
let json = serde_json::to_string(witness)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))?;
|
||||
Ok(JsValue::from_str(&json))
|
||||
}
|
||||
None => Ok(JsValue::NULL),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of loaded model IDs
|
||||
#[wasm_bindgen(js_name = getLoadedModels)]
|
||||
pub fn get_loaded_models(&self) -> Array {
|
||||
let arr = Array::new();
|
||||
for id in &self.loaded_models {
|
||||
let id_array = Uint8Array::new_with_length(32);
|
||||
id_array.copy_from(id.as_bytes());
|
||||
arr.push(&id_array);
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
/// Unload a model
|
||||
#[wasm_bindgen]
|
||||
pub fn unload(&mut self, model_id: &[u8]) -> Result<(), JsValue> {
|
||||
if model_id.len() != 32 {
|
||||
return Err(JsValue::from_str("Model ID must be 32 bytes"));
|
||||
}
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(model_id);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
self.backend
|
||||
.unload(model)
|
||||
.map_err(|e| JsValue::from_str(&format!("Unload failed: {}", e)))?;
|
||||
|
||||
self.loaded_models.retain(|id| *id != model);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get backend statistics
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> Result<JsValue, JsValue> {
|
||||
let stats = self.backend.stats();
|
||||
let obj = Object::new();
|
||||
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"models_loaded".into(),
|
||||
&JsValue::from(stats.models_loaded as u32),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"total_inferences".into(),
|
||||
&JsValue::from(stats.total_inferences as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"avg_latency_ns".into(),
|
||||
&JsValue::from(stats.avg_latency_ns as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"early_exits".into(),
|
||||
&JsValue::from(stats.early_exits as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"skipped".into(),
|
||||
&JsValue::from(stats.skipped as f64),
|
||||
)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WasmEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to create a micro shape configuration
|
||||
#[wasm_bindgen(js_name = microShape)]
|
||||
pub fn micro_shape() -> Result<JsValue, JsValue> {
|
||||
let shape = FixedShape::micro();
|
||||
let obj = Object::new();
|
||||
|
||||
Reflect::set(&obj, &"seq_len".into(), &JsValue::from(shape.seq_len))?;
|
||||
Reflect::set(&obj, &"d_model".into(), &JsValue::from(shape.d_model))?;
|
||||
Reflect::set(&obj, &"heads".into(), &JsValue::from(shape.heads))?;
|
||||
Reflect::set(&obj, &"d_head".into(), &JsValue::from(shape.d_head))?;
|
||||
Reflect::set(&obj, &"vocab".into(), &JsValue::from(shape.vocab))?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Utility function to validate an artifact without loading
|
||||
#[wasm_bindgen(js_name = validateArtifact)]
|
||||
pub fn validate_artifact(artifact_bytes: &[u8]) -> Result<JsValue, JsValue> {
|
||||
let artifact = unpack_artifact(artifact_bytes)
|
||||
.map_err(|e| JsValue::from_str(&format!("Invalid artifact: {}", e)))?;
|
||||
|
||||
artifact
|
||||
.validate()
|
||||
.map_err(|e| JsValue::from_str(&format!("Validation failed: {}", e)))?;
|
||||
|
||||
let obj = Object::new();
|
||||
Reflect::set(&obj, &"name".into(), &artifact.manifest.name.into())?;
|
||||
Reflect::set(&obj, &"valid".into(), &JsValue::TRUE)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
301
crates/ruvector-fpga-transformer/src/gating/coherence_gate.rs
Normal file
301
crates/ruvector-fpga-transformer/src/gating/coherence_gate.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
//! Coherence-based gating for inference control
|
||||
|
||||
use crate::types::{ComputeClass, GateDecision, GateHint, SkipReason};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Trait for coherence-based gating
|
||||
pub trait CoherenceGate: Send + Sync {
|
||||
/// Preflight check before inference
|
||||
///
|
||||
/// Returns a gate decision based on coherence signals.
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision;
|
||||
|
||||
/// Layer checkpoint for early exit decisions
|
||||
///
|
||||
/// Called after each layer to determine if early exit is appropriate.
|
||||
/// Returns Some(decision) to exit early, None to continue.
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision>;
|
||||
|
||||
/// Check if write is allowed based on witness
|
||||
///
|
||||
/// Used to gate state changes in memory systems.
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool;
|
||||
}
|
||||
|
||||
/// Configuration for coherence gate
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CoherenceConfig {
|
||||
/// Minimum coherence score to run (Q8.8)
|
||||
pub min_coherence: i16,
|
||||
/// Coherence threshold for early exit
|
||||
pub early_exit_threshold: i16,
|
||||
/// Enable early exit
|
||||
pub early_exit_enabled: bool,
|
||||
/// Minimum layers before early exit
|
||||
pub min_layers: u8,
|
||||
/// Require stable coherence for writes
|
||||
pub require_stable_for_write: bool,
|
||||
/// Minimum coherence for writes (Q8.8)
|
||||
pub min_write_coherence: i16,
|
||||
}
|
||||
|
||||
impl Default for CoherenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_coherence: -256, // -1.0 in Q8.8, very permissive
|
||||
early_exit_threshold: 512, // 2.0 in Q8.8
|
||||
early_exit_enabled: true,
|
||||
min_layers: 2,
|
||||
require_stable_for_write: true,
|
||||
min_write_coherence: 0, // Require non-negative coherence
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceConfig {
|
||||
/// Create a strict configuration
|
||||
pub fn strict() -> Self {
|
||||
Self {
|
||||
min_coherence: 0,
|
||||
early_exit_threshold: 256,
|
||||
early_exit_enabled: true,
|
||||
min_layers: 4,
|
||||
require_stable_for_write: true,
|
||||
min_write_coherence: 128,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permissive configuration (always allows)
|
||||
pub fn permissive() -> Self {
|
||||
Self {
|
||||
min_coherence: i16::MIN,
|
||||
early_exit_threshold: i16::MAX,
|
||||
early_exit_enabled: false,
|
||||
min_layers: 0,
|
||||
require_stable_for_write: false,
|
||||
min_write_coherence: i16::MIN,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default coherence gate implementation
|
||||
pub struct DefaultCoherenceGate {
|
||||
config: CoherenceConfig,
|
||||
}
|
||||
|
||||
impl DefaultCoherenceGate {
|
||||
/// Create with default config
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(CoherenceConfig::default())
|
||||
}
|
||||
|
||||
/// Create with custom config
|
||||
pub fn with_config(config: CoherenceConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Check if compute class allows operation
|
||||
fn check_compute_class(&self, hint: &GateHint) -> bool {
|
||||
// Reflex class can always run (fast path)
|
||||
// Higher classes require sufficient coherence
|
||||
match hint.max_compute_class {
|
||||
ComputeClass::Reflex => true,
|
||||
ComputeClass::Associative => hint.coherence_score_q >= self.config.min_coherence / 2,
|
||||
ComputeClass::Deliberative => hint.coherence_score_q >= self.config.min_coherence,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultCoherenceGate {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for DefaultCoherenceGate {
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// Check minimum coherence
|
||||
if hint.coherence_score_q < self.config.min_coherence {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence,
|
||||
};
|
||||
}
|
||||
|
||||
// Check compute class restrictions
|
||||
if !self.check_compute_class(hint) {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::BudgetExceeded,
|
||||
};
|
||||
}
|
||||
|
||||
// Allow full inference
|
||||
GateDecision::RanFull
|
||||
}
|
||||
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
if !self.config.early_exit_enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
if layer < self.config.min_layers {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if coherence signal is high enough to exit early
|
||||
if signal_q >= self.config.early_exit_threshold {
|
||||
return Some(GateDecision::EarlyExit { layer });
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Skip writes if inference was skipped
|
||||
if !witness.gate_decision.did_run() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we require stable coherence, only allow writes after full run
|
||||
if self.config.require_stable_for_write {
|
||||
matches!(witness.gate_decision, GateDecision::RanFull)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mincut-aware coherence gate
|
||||
///
|
||||
/// Uses mincut signals to make more informed gating decisions.
|
||||
pub struct MincutCoherenceGate {
|
||||
base: DefaultCoherenceGate,
|
||||
/// Minimum lambda (mincut value) for inference
|
||||
pub min_lambda: i16,
|
||||
/// Lambda threshold for early exit
|
||||
pub lambda_exit_threshold: i16,
|
||||
}
|
||||
|
||||
impl MincutCoherenceGate {
|
||||
/// Create a new mincut-aware gate
|
||||
pub fn new(config: CoherenceConfig, min_lambda: i16, lambda_exit_threshold: i16) -> Self {
|
||||
Self {
|
||||
base: DefaultCoherenceGate::with_config(config),
|
||||
min_lambda,
|
||||
lambda_exit_threshold,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for MincutCoherenceGate {
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// Use base coherence check
|
||||
let base_decision = self.base.preflight(hint);
|
||||
if !base_decision.did_run() {
|
||||
return base_decision;
|
||||
}
|
||||
|
||||
// Additional mincut check
|
||||
// If boundary was crossed and coherence is low, skip
|
||||
if hint.boundary_crossed && hint.coherence_score_q < 0 {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence,
|
||||
};
|
||||
}
|
||||
|
||||
GateDecision::RanFull
|
||||
}
|
||||
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
// Use base checkpoint with mincut-adjusted threshold
|
||||
let adjusted_threshold = if signal_q > self.lambda_exit_threshold {
|
||||
// High lambda suggests stable state, lower exit threshold
|
||||
self.base.config.early_exit_threshold / 2
|
||||
} else {
|
||||
self.base.config.early_exit_threshold
|
||||
};
|
||||
|
||||
if layer >= self.base.config.min_layers && signal_q >= adjusted_threshold {
|
||||
return Some(GateDecision::EarlyExit { layer });
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Use base write check
|
||||
self.base.allow_write(witness)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_gate_preflight() {
|
||||
let gate = DefaultCoherenceGate::new();
|
||||
|
||||
// High coherence should pass
|
||||
let hint = GateHint::new(256, false, ComputeClass::Deliberative);
|
||||
assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
|
||||
|
||||
// Low coherence should fail
|
||||
let hint = GateHint::new(-512, false, ComputeClass::Deliberative);
|
||||
assert!(matches!(
|
||||
gate.preflight(&hint),
|
||||
GateDecision::Skipped { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_exit_checkpoint() {
|
||||
let gate = DefaultCoherenceGate::new();
|
||||
|
||||
// Layer 0 - too early
|
||||
assert!(gate.checkpoint(0, 1000).is_none());
|
||||
|
||||
// Layer 4 with high signal - should exit
|
||||
let decision = gate.checkpoint(4, 1000);
|
||||
assert!(matches!(
|
||||
decision,
|
||||
Some(GateDecision::EarlyExit { layer: 4 })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_gating() {
|
||||
let gate = DefaultCoherenceGate::new();
|
||||
|
||||
// Full run should allow writes
|
||||
let witness = crate::witness::WitnessLog::empty();
|
||||
assert!(gate.allow_write(&witness));
|
||||
|
||||
// Skipped should not allow writes
|
||||
let mut skipped_witness = crate::witness::WitnessLog::empty();
|
||||
skipped_witness.gate_decision = GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence,
|
||||
};
|
||||
assert!(!gate.allow_write(&skipped_witness));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strict_config() {
|
||||
let gate = DefaultCoherenceGate::with_config(CoherenceConfig::strict());
|
||||
|
||||
// Strict should require positive coherence
|
||||
let hint = GateHint::new(-1, false, ComputeClass::Deliberative);
|
||||
assert!(matches!(
|
||||
gate.preflight(&hint),
|
||||
GateDecision::Skipped { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permissive_config() {
|
||||
let gate = DefaultCoherenceGate::with_config(CoherenceConfig::permissive());
|
||||
|
||||
// Permissive should allow anything
|
||||
let hint = GateHint::new(i16::MIN, true, ComputeClass::Reflex);
|
||||
assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
|
||||
}
|
||||
}
|
||||
95
crates/ruvector-fpga-transformer/src/gating/mod.rs
Normal file
95
crates/ruvector-fpga-transformer/src/gating/mod.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
//! Gating subsystem for coherence-based inference control
|
||||
//!
|
||||
//! Provides preflight and postflight gates that integrate mincut signals
|
||||
//! and write policies for memory safety.
|
||||
|
||||
pub mod coherence_gate;
|
||||
pub mod policy_gate;
|
||||
|
||||
pub use coherence_gate::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate};
|
||||
pub use policy_gate::{DefaultPolicyGate, PolicyGate, WritePolicy};
|
||||
|
||||
use crate::types::{GateDecision, GateHint, SkipReason};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Combined gate that checks both coherence and policy
|
||||
pub struct CombinedGate {
|
||||
coherence: Box<dyn CoherenceGate>,
|
||||
policy: Box<dyn PolicyGate>,
|
||||
}
|
||||
|
||||
impl CombinedGate {
|
||||
/// Create a new combined gate
|
||||
pub fn new(coherence: Box<dyn CoherenceGate>, policy: Box<dyn PolicyGate>) -> Self {
|
||||
Self { coherence, policy }
|
||||
}
|
||||
|
||||
/// Create with default implementations
|
||||
pub fn default_gates() -> Self {
|
||||
Self::new(
|
||||
Box::new(DefaultCoherenceGate::new()),
|
||||
Box::new(DefaultPolicyGate::new()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Preflight check before inference
|
||||
pub fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
// First check policy
|
||||
if !self.policy.allow_inference(hint) {
|
||||
return GateDecision::Skipped {
|
||||
reason: SkipReason::PolicyDenied,
|
||||
};
|
||||
}
|
||||
|
||||
// Then check coherence
|
||||
self.coherence.preflight(hint)
|
||||
}
|
||||
|
||||
/// Checkpoint during inference
|
||||
pub fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
self.coherence.checkpoint(layer, signal_q)
|
||||
}
|
||||
|
||||
/// Check if write is allowed after inference
|
||||
pub fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
self.coherence.allow_write(witness) && self.policy.allow_write(witness)
|
||||
}
|
||||
}
|
||||
|
||||
impl CoherenceGate for CombinedGate {
|
||||
fn preflight(&self, hint: &GateHint) -> GateDecision {
|
||||
CombinedGate::preflight(self, hint)
|
||||
}
|
||||
|
||||
fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
|
||||
CombinedGate::checkpoint(self, layer, signal_q)
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
CombinedGate::allow_write(self, witness)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_combined_gate_preflight() {
|
||||
let gate = CombinedGate::default_gates();
|
||||
|
||||
// Allow all hint should pass
|
||||
let decision = gate.preflight(&GateHint::allow_all());
|
||||
assert!(matches!(decision, GateDecision::RanFull));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_combined_gate_low_coherence() {
|
||||
let gate = CombinedGate::default_gates();
|
||||
|
||||
// Very low coherence should skip
|
||||
let hint = GateHint::new(-1000, false, crate::types::ComputeClass::Reflex);
|
||||
let decision = gate.preflight(&hint);
|
||||
assert!(matches!(decision, GateDecision::Skipped { .. }));
|
||||
}
|
||||
}
|
||||
305
crates/ruvector-fpga-transformer/src/gating/policy_gate.rs
Normal file
305
crates/ruvector-fpga-transformer/src/gating/policy_gate.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
//! Policy-based gating for access control and resource management
|
||||
|
||||
use crate::types::{ComputeClass, GateHint};
|
||||
use crate::witness::WitnessLog;
|
||||
|
||||
/// Trait for policy-based gating
|
||||
pub trait PolicyGate: Send + Sync {
|
||||
/// Check if inference is allowed
|
||||
fn allow_inference(&self, hint: &GateHint) -> bool;
|
||||
|
||||
/// Check if write is allowed after inference
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool;
|
||||
|
||||
/// Get remaining compute budget
|
||||
fn remaining_budget(&self) -> Option<u64>;
|
||||
|
||||
/// Record compute usage
|
||||
fn record_usage(&self, cycles: u32);
|
||||
}
|
||||
|
||||
/// Write policy configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WritePolicy {
|
||||
/// Allow writes after early exit
|
||||
pub allow_early_exit_writes: bool,
|
||||
/// Maximum latency (ns) for write eligibility
|
||||
pub max_latency_ns: u32,
|
||||
/// Require specific backend
|
||||
pub required_backend: Option<crate::types::BackendKind>,
|
||||
/// Minimum compute class for writes
|
||||
pub min_compute_class: ComputeClass,
|
||||
}
|
||||
|
||||
impl Default for WritePolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: false,
|
||||
max_latency_ns: u32::MAX,
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WritePolicy {
|
||||
/// Create a strict write policy
|
||||
pub fn strict() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: false,
|
||||
max_latency_ns: 10_000_000, // 10ms
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Deliberative,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permissive write policy
|
||||
pub fn permissive() -> Self {
|
||||
Self {
|
||||
allow_early_exit_writes: true,
|
||||
max_latency_ns: u32::MAX,
|
||||
required_backend: None,
|
||||
min_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
|
||||
/// Require FPGA backend for writes
|
||||
pub fn require_fpga(mut self) -> Self {
|
||||
self.required_backend = Some(crate::types::BackendKind::FpgaPcie);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Default policy gate implementation
|
||||
pub struct DefaultPolicyGate {
|
||||
write_policy: WritePolicy,
|
||||
/// Compute budget (total cycles allowed, 0 = unlimited)
|
||||
budget_cycles: std::sync::atomic::AtomicU64,
|
||||
/// Used cycles
|
||||
used_cycles: std::sync::atomic::AtomicU64,
|
||||
}
|
||||
|
||||
impl DefaultPolicyGate {
|
||||
/// Create with default policy
|
||||
pub fn new() -> Self {
|
||||
Self::with_policy(WritePolicy::default())
|
||||
}
|
||||
|
||||
/// Create with custom write policy
|
||||
pub fn with_policy(write_policy: WritePolicy) -> Self {
|
||||
Self {
|
||||
write_policy,
|
||||
budget_cycles: std::sync::atomic::AtomicU64::new(0),
|
||||
used_cycles: std::sync::atomic::AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set compute budget
|
||||
pub fn set_budget(&self, cycles: u64) {
|
||||
self.budget_cycles
|
||||
.store(cycles, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Reset used cycles
|
||||
pub fn reset_usage(&self) {
|
||||
self.used_cycles
|
||||
.store(0, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultPolicyGate {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyGate for DefaultPolicyGate {
|
||||
fn allow_inference(&self, hint: &GateHint) -> bool {
|
||||
// Check compute budget
|
||||
let budget = self.budget_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if budget > 0 {
|
||||
let used = self.used_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if used >= budget {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check compute class restrictions
|
||||
// Always allow reflex, check others based on config
|
||||
hint.max_compute_class >= ComputeClass::Reflex
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
// Check if inference ran
|
||||
if !witness.gate_decision.did_run() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check early exit policy
|
||||
if !self.write_policy.allow_early_exit_writes {
|
||||
if let crate::types::GateDecision::EarlyExit { .. } = witness.gate_decision {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check latency
|
||||
if witness.latency_ns > self.write_policy.max_latency_ns {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check backend requirement
|
||||
if let Some(required) = self.write_policy.required_backend {
|
||||
if witness.backend != required {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn remaining_budget(&self) -> Option<u64> {
|
||||
let budget = self.budget_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
if budget == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let used = self.used_cycles.load(std::sync::atomic::Ordering::SeqCst);
|
||||
Some(budget.saturating_sub(used))
|
||||
}
|
||||
|
||||
fn record_usage(&self, cycles: u32) {
|
||||
self.used_cycles
|
||||
.fetch_add(cycles as u64, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate-limited policy gate
|
||||
pub struct RateLimitedPolicyGate {
|
||||
base: DefaultPolicyGate,
|
||||
/// Maximum inferences per second
|
||||
max_inferences_per_sec: u32,
|
||||
/// Inference count in current window
|
||||
inference_count: std::sync::atomic::AtomicU32,
|
||||
/// Window start time
|
||||
window_start: std::sync::RwLock<std::time::Instant>,
|
||||
}
|
||||
|
||||
impl RateLimitedPolicyGate {
|
||||
/// Create with rate limit
|
||||
pub fn new(max_inferences_per_sec: u32, write_policy: WritePolicy) -> Self {
|
||||
Self {
|
||||
base: DefaultPolicyGate::with_policy(write_policy),
|
||||
max_inferences_per_sec,
|
||||
inference_count: std::sync::atomic::AtomicU32::new(0),
|
||||
window_start: std::sync::RwLock::new(std::time::Instant::now()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check and update rate limit
|
||||
fn check_rate_limit(&self) -> bool {
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// Check if we need to reset the window
|
||||
{
|
||||
let window_start = self.window_start.read().unwrap();
|
||||
if now.duration_since(*window_start).as_secs() >= 1 {
|
||||
drop(window_start);
|
||||
let mut window_start = self.window_start.write().unwrap();
|
||||
*window_start = now;
|
||||
self.inference_count
|
||||
.store(0, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
// Check count
|
||||
let count = self
|
||||
.inference_count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
count < self.max_inferences_per_sec
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyGate for RateLimitedPolicyGate {
|
||||
fn allow_inference(&self, hint: &GateHint) -> bool {
|
||||
// Check rate limit first
|
||||
if !self.check_rate_limit() {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.base.allow_inference(hint)
|
||||
}
|
||||
|
||||
fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
self.base.allow_write(witness)
|
||||
}
|
||||
|
||||
fn remaining_budget(&self) -> Option<u64> {
|
||||
self.base.remaining_budget()
|
||||
}
|
||||
|
||||
fn record_usage(&self, cycles: u32) {
|
||||
self.base.record_usage(cycles);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_policy_allows_inference() {
|
||||
let gate = DefaultPolicyGate::new();
|
||||
let hint = GateHint::allow_all();
|
||||
assert!(gate.allow_inference(&hint));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_budget_limiting() {
|
||||
let gate = DefaultPolicyGate::new();
|
||||
gate.set_budget(1000);
|
||||
|
||||
let hint = GateHint::allow_all();
|
||||
|
||||
// Should allow initially
|
||||
assert!(gate.allow_inference(&hint));
|
||||
|
||||
// Record usage exceeding budget
|
||||
gate.record_usage(1500);
|
||||
|
||||
// Should deny now
|
||||
assert!(!gate.allow_inference(&hint));
|
||||
|
||||
// Reset and check again
|
||||
gate.reset_usage();
|
||||
assert!(gate.allow_inference(&hint));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_policy_early_exit() {
|
||||
let gate = DefaultPolicyGate::with_policy(WritePolicy::default());
|
||||
|
||||
let mut witness = crate::witness::WitnessLog::empty();
|
||||
witness.gate_decision = crate::types::GateDecision::EarlyExit { layer: 3 };
|
||||
|
||||
// Default policy denies early exit writes
|
||||
assert!(!gate.allow_write(&witness));
|
||||
|
||||
// Permissive policy allows
|
||||
let permissive = DefaultPolicyGate::with_policy(WritePolicy::permissive());
|
||||
assert!(permissive.allow_write(&witness));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_policy_latency() {
|
||||
let mut policy = WritePolicy::default();
|
||||
policy.max_latency_ns = 1000;
|
||||
let gate = DefaultPolicyGate::with_policy(policy);
|
||||
|
||||
let mut witness = crate::witness::WitnessLog::empty();
|
||||
witness.latency_ns = 500;
|
||||
assert!(gate.allow_write(&witness));
|
||||
|
||||
witness.latency_ns = 2000;
|
||||
assert!(!gate.allow_write(&witness));
|
||||
}
|
||||
}
|
||||
331
crates/ruvector-fpga-transformer/src/lib.rs
Normal file
331
crates/ruvector-fpga-transformer/src/lib.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
//! # FPGA Transformer Backend
|
||||
//!
|
||||
//! Ultra low latency transformer inference with FPGA acceleration,
|
||||
//! coherence gating, and deterministic execution.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Deterministic latency paths**: Fixed shape inference with bounded timing
|
||||
//! - **Quantization first design**: Explicit INT4/INT8 quantization with reproducible math
|
||||
//! - **Zero allocation hot path**: No heap allocations during inference
|
||||
//! - **Coherence gating**: Mincut-integrated gate decisions
|
||||
//! - **Multiple backends**: FPGA PCIe, FPGA Daemon, Native Sim, WASM Sim
|
||||
//! - **Witness logging**: Auditable inference with ReasoningBank integration
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use ruvector_fpga_transformer::{Engine, artifact::ModelArtifact};
|
||||
//! use ruvector_fpga_transformer::backend::native_sim::NativeSimBackend;
|
||||
//! use ruvector_fpga_transformer::gating::DefaultCoherenceGate;
|
||||
//! use ruvector_fpga_transformer::types::{InferenceRequest, GateHint, FixedShape};
|
||||
//! use std::sync::Arc;
|
||||
//!
|
||||
//! // Create backend and gate
|
||||
//! let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
//! let backend = NativeSimBackend::new(gate.clone());
|
||||
//!
|
||||
//! // Create engine
|
||||
//! let mut engine = Engine::new(Box::new(backend), gate);
|
||||
//!
|
||||
//! // Load artifact (from file or bytes)
|
||||
//! // let model_id = engine.load_artifact(&artifact_bytes)?;
|
||||
//!
|
||||
//! // Run inference
|
||||
//! // let result = engine.infer(request)?;
|
||||
//! ```
|
||||
//!
|
||||
//! ## Backend Selection
|
||||
//!
|
||||
//! The crate supports multiple backends selected at runtime:
|
||||
//!
|
||||
//! - `FpgaPcie`: Direct PCIe access to FPGA (requires `pcie` feature)
|
||||
//! - `FpgaDaemon`: Communication via local daemon (requires `daemon` feature)
|
||||
//! - `NativeSim`: Pure Rust simulator (requires `native_sim` feature)
|
||||
//! - `WasmSim`: WASM-compatible simulator (requires `wasm` feature)
|
||||
//!
|
||||
//! ## Artifact Format
|
||||
//!
|
||||
//! Models are packaged as signed artifacts containing:
|
||||
//! - Manifest with shape and quantization metadata
|
||||
//! - Quantized weights
|
||||
//! - Optional FPGA bitstream
|
||||
//! - Test vectors for validation
|
||||
//! - Ed25519 signature
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(feature = "wasm", allow(unused_imports))]
|
||||
|
||||
pub mod artifact;
|
||||
pub mod backend;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
pub mod gating;
|
||||
pub mod quant;
|
||||
pub mod types;
|
||||
pub mod witness;
|
||||
|
||||
pub use artifact::ModelArtifact;
|
||||
pub use backend::TransformerBackend;
|
||||
pub use error::{Error, Result};
|
||||
pub use gating::CoherenceGate;
|
||||
pub use types::{
|
||||
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
|
||||
InferenceResult, Layout, ModelId, QuantSpec, QuantizedTensor, SkipReason, WitnessLog,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Crate version
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Main engine for FPGA transformer inference
|
||||
///
|
||||
/// The engine combines a backend (FPGA, simulator, etc.) with a coherence gate
|
||||
/// for controlled inference execution.
|
||||
pub struct Engine {
|
||||
/// Backend for inference execution
|
||||
backend: Box<dyn TransformerBackend>,
|
||||
/// Coherence gate for decision making
|
||||
gate: Arc<dyn CoherenceGate>,
|
||||
/// Loaded models
|
||||
models: std::collections::HashMap<ModelId, ModelInfo>,
|
||||
/// Inference statistics
|
||||
stats: EngineStats,
|
||||
}
|
||||
|
||||
/// Information about a loaded model
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelInfo {
|
||||
/// Model artifact
|
||||
pub artifact: ModelArtifact,
|
||||
/// Shape configuration
|
||||
pub shape: FixedShape,
|
||||
/// Quantization spec
|
||||
pub quant: QuantSpec,
|
||||
}
|
||||
|
||||
/// Engine statistics
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct EngineStats {
|
||||
/// Total inferences
|
||||
pub total_inferences: u64,
|
||||
/// Successful inferences
|
||||
pub successful: u64,
|
||||
/// Skipped inferences
|
||||
pub skipped: u64,
|
||||
/// Early exits
|
||||
pub early_exits: u64,
|
||||
/// Total latency (ns)
|
||||
pub total_latency_ns: u64,
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
/// Create a new engine with the specified backend and gate
|
||||
pub fn new(backend: Box<dyn TransformerBackend>, gate: Arc<dyn CoherenceGate>) -> Self {
|
||||
Self {
|
||||
backend,
|
||||
gate,
|
||||
models: std::collections::HashMap::new(),
|
||||
stats: EngineStats::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default native simulator backend
|
||||
#[cfg(feature = "native_sim")]
|
||||
pub fn native_sim() -> Self {
|
||||
let gate = Arc::new(gating::DefaultCoherenceGate::new());
|
||||
let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
|
||||
Self::new(backend, gate)
|
||||
}
|
||||
|
||||
/// Load a model artifact from bytes
|
||||
pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<ModelId> {
|
||||
let artifact = artifact::unpack_artifact(artifact_bytes)?;
|
||||
self.load(&artifact)
|
||||
}
|
||||
|
||||
/// Load a model artifact
|
||||
pub fn load(&mut self, artifact: &ModelArtifact) -> Result<ModelId> {
|
||||
// Validate artifact
|
||||
artifact.validate()?;
|
||||
|
||||
// Load into backend
|
||||
let model_id = self.backend.load(artifact)?;
|
||||
|
||||
// Store info
|
||||
self.models.insert(
|
||||
model_id,
|
||||
ModelInfo {
|
||||
artifact: artifact.clone(),
|
||||
shape: artifact.manifest.shape,
|
||||
quant: artifact.manifest.quant,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(model_id)
|
||||
}
|
||||
|
||||
/// Run inference
|
||||
pub fn infer(&mut self, req: InferenceRequest) -> Result<InferenceResult> {
|
||||
self.stats.total_inferences += 1;
|
||||
|
||||
// Check preflight gate
|
||||
let preflight = self.gate.preflight(&req.gate_hint);
|
||||
if let GateDecision::Skipped { reason } = preflight {
|
||||
self.stats.skipped += 1;
|
||||
return Err(Error::GateBlocked { reason });
|
||||
}
|
||||
|
||||
// Run inference
|
||||
let result = self.backend.infer(req)?;
|
||||
|
||||
// Update stats
|
||||
self.stats.total_latency_ns += result.witness.latency_ns as u64;
|
||||
match result.witness.gate_decision {
|
||||
GateDecision::RanFull => self.stats.successful += 1,
|
||||
GateDecision::EarlyExit { .. } => {
|
||||
self.stats.successful += 1;
|
||||
self.stats.early_exits += 1;
|
||||
}
|
||||
GateDecision::Skipped { .. } => self.stats.skipped += 1,
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Unload a model
|
||||
pub fn unload(&mut self, model: ModelId) -> Result<()> {
|
||||
self.backend.unload(model)?;
|
||||
self.models.remove(&model);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get model shape
|
||||
pub fn shape(&self, model: ModelId) -> Result<FixedShape> {
|
||||
self.models
|
||||
.get(&model)
|
||||
.map(|info| info.shape)
|
||||
.ok_or_else(|| Error::ModelNotFound(model))
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
pub fn model_info(&self, model: ModelId) -> Option<&ModelInfo> {
|
||||
self.models.get(&model)
|
||||
}
|
||||
|
||||
/// Check if model is loaded
|
||||
pub fn is_loaded(&self, model: ModelId) -> bool {
|
||||
self.models.contains_key(&model)
|
||||
}
|
||||
|
||||
/// Get list of loaded models
|
||||
pub fn loaded_models(&self) -> Vec<ModelId> {
|
||||
self.models.keys().copied().collect()
|
||||
}
|
||||
|
||||
/// Get engine statistics
|
||||
pub fn stats(&self) -> &EngineStats {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
/// Get backend statistics
|
||||
pub fn backend_stats(&self) -> backend::BackendStats {
|
||||
self.backend.stats()
|
||||
}
|
||||
|
||||
/// Get backend kind
|
||||
pub fn backend_kind(&self) -> BackendKind {
|
||||
self.backend.kind()
|
||||
}
|
||||
|
||||
/// Check if write is allowed based on witness
|
||||
pub fn allow_write(&self, witness: &WitnessLog) -> bool {
|
||||
self.gate.allow_write(witness)
|
||||
}
|
||||
|
||||
/// Reset statistics
|
||||
pub fn reset_stats(&mut self) {
|
||||
self.stats = EngineStats::default();
|
||||
}
|
||||
}
|
||||
|
||||
impl EngineStats {
|
||||
/// Get average latency in nanoseconds
|
||||
pub fn avg_latency_ns(&self) -> f64 {
|
||||
if self.successful == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_latency_ns as f64 / self.successful as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average latency in milliseconds
|
||||
pub fn avg_latency_ms(&self) -> f64 {
|
||||
self.avg_latency_ns() / 1_000_000.0
|
||||
}
|
||||
|
||||
/// Get success rate
|
||||
pub fn success_rate(&self) -> f64 {
|
||||
if self.total_inferences == 0 {
|
||||
1.0
|
||||
} else {
|
||||
self.successful as f64 / self.total_inferences as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get early exit rate
|
||||
pub fn early_exit_rate(&self) -> f64 {
|
||||
if self.successful == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.early_exits as f64 / self.successful as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prelude for convenient imports
|
||||
pub mod prelude {
|
||||
pub use crate::{
|
||||
artifact::ModelArtifact,
|
||||
backend::TransformerBackend,
|
||||
gating::CoherenceGate,
|
||||
types::{
|
||||
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
|
||||
InferenceResult, ModelId, QuantSpec, SkipReason, WitnessLog,
|
||||
},
|
||||
Engine, Error, Result,
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_engine_creation() {
|
||||
let gate = Arc::new(gating::DefaultCoherenceGate::new());
|
||||
|
||||
#[cfg(feature = "native_sim")]
|
||||
{
|
||||
let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
|
||||
let engine = Engine::new(backend, gate);
|
||||
assert!(engine.loaded_models().is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_engine_stats() {
|
||||
let stats = EngineStats {
|
||||
total_inferences: 100,
|
||||
successful: 80,
|
||||
skipped: 10,
|
||||
early_exits: 20,
|
||||
total_latency_ns: 8_000_000,
|
||||
};
|
||||
|
||||
assert!((stats.success_rate() - 0.8).abs() < 0.01);
|
||||
assert!((stats.early_exit_rate() - 0.25).abs() < 0.01);
|
||||
assert!((stats.avg_latency_ns() - 100_000.0).abs() < 1.0);
|
||||
}
|
||||
}
|
||||
303
crates/ruvector-fpga-transformer/src/quant/calib.rs
Normal file
303
crates/ruvector-fpga-transformer/src/quant/calib.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
//! Calibration data for quantization
|
||||
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Calibration data for a model
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CalibrationData {
|
||||
/// Layer-wise activation statistics
|
||||
pub layers: Vec<LayerCalibration>,
|
||||
/// Global input statistics
|
||||
pub input_stats: ActivationStats,
|
||||
/// Number of calibration samples used
|
||||
pub num_samples: usize,
|
||||
/// Calibration method used
|
||||
pub method: CalibrationMethod,
|
||||
}
|
||||
|
||||
/// Per-layer calibration data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LayerCalibration {
|
||||
/// Layer index
|
||||
pub layer_idx: usize,
|
||||
/// Layer name
|
||||
pub name: String,
|
||||
/// Activation statistics after this layer
|
||||
pub activation_stats: ActivationStats,
|
||||
/// Weight statistics for this layer
|
||||
pub weight_stats: WeightStats,
|
||||
/// Optimal scale for activations (Q16.16)
|
||||
pub act_scale: i32,
|
||||
/// Optimal scale for weights (Q16.16)
|
||||
pub weight_scale: i32,
|
||||
}
|
||||
|
||||
/// Activation statistics
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ActivationStats {
|
||||
/// Minimum value seen
|
||||
pub min: f32,
|
||||
/// Maximum value seen
|
||||
pub max: f32,
|
||||
/// Mean value
|
||||
pub mean: f32,
|
||||
/// Standard deviation
|
||||
pub std: f32,
|
||||
/// Histogram bins (for entropy calibration)
|
||||
#[serde(default)]
|
||||
pub histogram: Vec<u32>,
|
||||
/// Histogram bin edges
|
||||
#[serde(default)]
|
||||
pub bin_edges: Vec<f32>,
|
||||
}
|
||||
|
||||
impl ActivationStats {
|
||||
/// Create empty stats
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Update stats with a batch of values
|
||||
pub fn update(&mut self, values: &[f32]) {
|
||||
if values.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update min/max
|
||||
for &v in values {
|
||||
if v < self.min || self.min == 0.0 {
|
||||
self.min = v;
|
||||
}
|
||||
if v > self.max {
|
||||
self.max = v;
|
||||
}
|
||||
}
|
||||
|
||||
// Update running mean and std
|
||||
let n = values.len() as f32;
|
||||
let batch_mean = values.iter().sum::<f32>() / n;
|
||||
let batch_var = values.iter().map(|v| (v - batch_mean).powi(2)).sum::<f32>() / n;
|
||||
|
||||
// Simple update (not online algorithm)
|
||||
self.mean = batch_mean;
|
||||
self.std = batch_var.sqrt();
|
||||
}
|
||||
|
||||
/// Compute optimal scale for symmetric quantization to n bits
|
||||
pub fn optimal_scale(&self, bits: u8) -> f32 {
|
||||
let max_range = self.max.abs().max(self.min.abs());
|
||||
let qmax = (1 << (bits - 1)) as f32 - 1.0;
|
||||
max_range / qmax
|
||||
}
|
||||
}
|
||||
|
||||
/// Weight statistics
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct WeightStats {
|
||||
/// Min weight value
|
||||
pub min: f32,
|
||||
/// Max weight value
|
||||
pub max: f32,
|
||||
/// Sparsity (fraction of zeros)
|
||||
pub sparsity: f32,
|
||||
}
|
||||
|
||||
impl WeightStats {
|
||||
/// Compute from weight tensor
|
||||
pub fn from_weights(weights: &[f32]) -> Self {
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
let mut zeros = 0usize;
|
||||
|
||||
for &w in weights {
|
||||
if w < min {
|
||||
min = w;
|
||||
}
|
||||
if w > max {
|
||||
max = w;
|
||||
}
|
||||
if w.abs() < 1e-6 {
|
||||
zeros += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
min,
|
||||
max,
|
||||
sparsity: zeros as f32 / weights.len() as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calibration method
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum CalibrationMethod {
|
||||
/// Use min/max of observed values
|
||||
MinMax,
|
||||
/// Use percentile clipping (e.g., 99.9%)
|
||||
Percentile(u32), // 999 = 99.9%
|
||||
/// Entropy-based calibration (KL divergence)
|
||||
Entropy,
|
||||
/// Mean-squared error minimization
|
||||
Mse,
|
||||
}
|
||||
|
||||
impl Default for CalibrationMethod {
|
||||
fn default() -> Self {
|
||||
Self::MinMax
|
||||
}
|
||||
}
|
||||
|
||||
impl CalibrationData {
|
||||
/// Create empty calibration data
|
||||
pub fn new(method: CalibrationMethod) -> Self {
|
||||
Self {
|
||||
layers: Vec::new(),
|
||||
input_stats: ActivationStats::new(),
|
||||
num_samples: 0,
|
||||
method,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add layer calibration
|
||||
pub fn add_layer(&mut self, calib: LayerCalibration) {
|
||||
self.layers.push(calib);
|
||||
}
|
||||
|
||||
/// Serialize to bytes
|
||||
pub fn to_bytes(&self) -> Result<Vec<u8>> {
|
||||
Ok(serde_json::to_vec(self)?)
|
||||
}
|
||||
|
||||
/// Deserialize from bytes
|
||||
pub fn from_bytes(data: &[u8]) -> Result<Self> {
|
||||
Ok(serde_json::from_slice(data)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calibrate a model by collecting activation statistics
|
||||
pub fn calibrate_model<F>(
|
||||
run_inference: F,
|
||||
calibration_inputs: &[Vec<u16>],
|
||||
num_layers: usize,
|
||||
method: CalibrationMethod,
|
||||
) -> Result<CalibrationData>
|
||||
where
|
||||
F: Fn(&[u16]) -> Result<Vec<Vec<f32>>>, // Returns activations per layer
|
||||
{
|
||||
let mut calibration = CalibrationData::new(method);
|
||||
|
||||
// Initialize layer stats
|
||||
let mut layer_stats: Vec<ActivationStats> =
|
||||
(0..num_layers).map(|_| ActivationStats::new()).collect();
|
||||
|
||||
// Run calibration passes
|
||||
for input in calibration_inputs {
|
||||
// Run inference and collect activations
|
||||
let activations = run_inference(input)?;
|
||||
|
||||
// Update statistics
|
||||
for (layer_idx, layer_act) in activations.iter().enumerate() {
|
||||
if layer_idx < num_layers {
|
||||
layer_stats[layer_idx].update(layer_act);
|
||||
}
|
||||
}
|
||||
|
||||
calibration.num_samples += 1;
|
||||
}
|
||||
|
||||
// Create layer calibrations
|
||||
for (layer_idx, stats) in layer_stats.into_iter().enumerate() {
|
||||
let act_scale = match method {
|
||||
CalibrationMethod::MinMax => stats.optimal_scale(8),
|
||||
CalibrationMethod::Percentile(_) => stats.optimal_scale(8) * 0.99,
|
||||
CalibrationMethod::Entropy => stats.optimal_scale(8),
|
||||
CalibrationMethod::Mse => stats.optimal_scale(8),
|
||||
};
|
||||
|
||||
calibration.add_layer(LayerCalibration {
|
||||
layer_idx,
|
||||
name: format!("layer_{}", layer_idx),
|
||||
activation_stats: stats,
|
||||
weight_stats: WeightStats::default(),
|
||||
act_scale: (act_scale * 65536.0) as i32,
|
||||
weight_scale: 65536, // Default 1.0
|
||||
});
|
||||
}
|
||||
|
||||
Ok(calibration)
|
||||
}
|
||||
|
||||
/// Apply percentile clipping to calibration
|
||||
pub fn apply_percentile(stats: &ActivationStats, percentile: f32) -> (f32, f32) {
|
||||
if stats.histogram.is_empty() || stats.bin_edges.len() < 2 {
|
||||
return (stats.min, stats.max);
|
||||
}
|
||||
|
||||
let total: u32 = stats.histogram.iter().sum();
|
||||
let target_low = (total as f32 * (1.0 - percentile) / 2.0) as u32;
|
||||
let target_high = (total as f32 * (1.0 + percentile) / 2.0) as u32;
|
||||
|
||||
let mut cumsum = 0u32;
|
||||
let mut low_idx = 0;
|
||||
let mut high_idx = stats.histogram.len() - 1;
|
||||
|
||||
for (i, &count) in stats.histogram.iter().enumerate() {
|
||||
cumsum += count;
|
||||
if cumsum >= target_low && low_idx == 0 {
|
||||
low_idx = i;
|
||||
}
|
||||
if cumsum >= target_high {
|
||||
high_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(stats.bin_edges[low_idx], stats.bin_edges[high_idx + 1])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_activation_stats_update() {
|
||||
let mut stats = ActivationStats::new();
|
||||
stats.update(&[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
|
||||
assert_eq!(stats.min, 1.0);
|
||||
assert_eq!(stats.max, 5.0);
|
||||
assert!((stats.mean - 3.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_optimal_scale() {
|
||||
let mut stats = ActivationStats::new();
|
||||
stats.min = -1.0;
|
||||
stats.max = 1.0;
|
||||
|
||||
let scale = stats.optimal_scale(8);
|
||||
// For 8-bit, qmax = 127, so scale should be 1.0/127 ≈ 0.00787
|
||||
assert!((scale - 1.0 / 127.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weight_stats() {
|
||||
let weights = vec![0.0, 0.1, -0.1, 0.5, -0.5, 0.0];
|
||||
let stats = WeightStats::from_weights(&weights);
|
||||
|
||||
assert_eq!(stats.min, -0.5);
|
||||
assert_eq!(stats.max, 0.5);
|
||||
assert!((stats.sparsity - 2.0 / 6.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calibration_serialization() {
|
||||
let calib = CalibrationData::new(CalibrationMethod::MinMax);
|
||||
let bytes = calib.to_bytes().unwrap();
|
||||
let restored = CalibrationData::from_bytes(&bytes).unwrap();
|
||||
|
||||
assert_eq!(calib.method, restored.method);
|
||||
}
|
||||
}
|
||||
336
crates/ruvector-fpga-transformer/src/quant/lut.rs
Normal file
336
crates/ruvector-fpga-transformer/src/quant/lut.rs
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Lookup table implementations for fixed-point operations
|
||||
//!
|
||||
//! Provides LUT-based exp, log, and softmax for deterministic computation.
|
||||
|
||||
/// LUT-based exponential function
|
||||
/// Input: Q8.8 fixed point [-16, 16)
|
||||
/// Output: Q0.16 fixed point [0, 1)
|
||||
const EXP_LUT_SIZE: usize = 256;
|
||||
const EXP_LUT_SHIFT: i32 = 8; // Q8.8 input
|
||||
|
||||
/// Precomputed exp LUT for range [-8, 8) in Q8.8
|
||||
static EXP_LUT: [u16; EXP_LUT_SIZE] = generate_exp_lut();
|
||||
|
||||
/// Generate exp LUT at compile time
|
||||
const fn generate_exp_lut() -> [u16; EXP_LUT_SIZE] {
|
||||
let mut lut = [0u16; EXP_LUT_SIZE];
|
||||
let mut i = 0;
|
||||
while i < EXP_LUT_SIZE {
|
||||
// Convert index to Q8.8 value (range -128..128 in fixed point = -0.5..0.5)
|
||||
let x_q = (i as i32) - 128;
|
||||
// Scale to get reasonable exp range
|
||||
let x_f = (x_q as f64) / 32.0; // x in [-4, 4)
|
||||
|
||||
// Compute exp and scale to Q0.16
|
||||
let exp_val = const_exp(x_f);
|
||||
let scaled = exp_val / (1.0 + const_exp(4.0)); // Normalize
|
||||
|
||||
// Convert to u16
|
||||
lut[i] = if scaled > 1.0 {
|
||||
65535
|
||||
} else if scaled < 0.0 {
|
||||
0
|
||||
} else {
|
||||
(scaled * 65535.0) as u16
|
||||
};
|
||||
|
||||
i += 1;
|
||||
}
|
||||
lut
|
||||
}
|
||||
|
||||
/// Const-compatible exp approximation using Taylor series
|
||||
const fn const_exp(x: f64) -> f64 {
|
||||
// exp(x) ≈ 1 + x + x²/2 + x³/6 + x⁴/24 + x⁵/120
|
||||
let x2 = x * x;
|
||||
let x3 = x2 * x;
|
||||
let x4 = x3 * x;
|
||||
let x5 = x4 * x;
|
||||
|
||||
1.0 + x + x2 / 2.0 + x3 / 6.0 + x4 / 24.0 + x5 / 120.0
|
||||
}
|
||||
|
||||
/// LUT-based exponential
|
||||
/// Input: i16 in Q8.8 format
|
||||
/// Output: u16 in Q0.16 format
|
||||
#[inline]
|
||||
pub fn exp_lut(x: i16) -> u16 {
|
||||
// Clamp to LUT range
|
||||
let clamped = x.clamp(-128 * 256, 127 * 256);
|
||||
// Scale to LUT index
|
||||
let idx = ((clamped >> EXP_LUT_SHIFT) + 128) as usize;
|
||||
EXP_LUT[idx.min(EXP_LUT_SIZE - 1)]
|
||||
}
|
||||
|
||||
/// Log LUT for Q0.16 input
|
||||
static LOG_LUT: [i16; 256] = generate_log_lut();
|
||||
|
||||
const fn generate_log_lut() -> [i16; 256] {
|
||||
let mut lut = [0i16; 256];
|
||||
let mut i = 1;
|
||||
while i < 256 {
|
||||
// Input is scaled by 256, so x = i/256 in [0.004, 1)
|
||||
let x = (i as f64) / 256.0;
|
||||
// log(x) in Q8.8 format
|
||||
let log_val = const_ln(x);
|
||||
lut[i] = (log_val * 256.0) as i16;
|
||||
i += 1;
|
||||
}
|
||||
lut[0] = i16::MIN; // log(0) = -inf, use min value
|
||||
lut
|
||||
}
|
||||
|
||||
/// Const-compatible natural log approximation
|
||||
const fn const_ln(x: f64) -> f64 {
|
||||
if x <= 0.0 {
|
||||
return f64::NEG_INFINITY;
|
||||
}
|
||||
// Use series expansion around x=1: ln(x) = 2 * sum((x-1)/(x+1))^(2n+1)/(2n+1)
|
||||
let y = (x - 1.0) / (x + 1.0);
|
||||
let y2 = y * y;
|
||||
|
||||
// ln(x) ≈ 2 * (y + y³/3 + y⁵/5 + y⁷/7 + y⁹/9)
|
||||
let y3 = y2 * y;
|
||||
let y5 = y3 * y2;
|
||||
let y7 = y5 * y2;
|
||||
let y9 = y7 * y2;
|
||||
|
||||
2.0 * (y + y3 / 3.0 + y5 / 5.0 + y7 / 7.0 + y9 / 9.0)
|
||||
}
|
||||
|
||||
/// LUT-based natural log
|
||||
/// Input: u16 in Q0.16 format (0 to 65535 = 0.0 to ~1.0)
|
||||
/// Output: i16 in Q8.8 format
|
||||
#[inline]
|
||||
pub fn log_lut(x: u16) -> i16 {
|
||||
if x == 0 {
|
||||
return i16::MIN;
|
||||
}
|
||||
// Scale to LUT index
|
||||
let idx = (x >> 8) as usize;
|
||||
LOG_LUT[idx.min(255)]
|
||||
}
|
||||
|
||||
/// Softmax using LUT-based exp
|
||||
/// Operates in-place on Q8.8 logits, outputs Q0.16 probabilities
|
||||
pub fn softmax_lut_q(logits: &mut [i16]) {
|
||||
if logits.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find max for numerical stability
|
||||
let max = *logits.iter().max().unwrap_or(&0);
|
||||
|
||||
// Compute exp(x - max) using LUT
|
||||
let mut sum: u32 = 0;
|
||||
let mut exp_values: Vec<u16> = Vec::with_capacity(logits.len());
|
||||
|
||||
for &logit in logits.iter() {
|
||||
let shifted = logit.saturating_sub(max);
|
||||
let exp_val = exp_lut(shifted);
|
||||
exp_values.push(exp_val);
|
||||
sum += exp_val as u32;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
if sum == 0 {
|
||||
sum = 1;
|
||||
}
|
||||
|
||||
for (i, logit) in logits.iter_mut().enumerate() {
|
||||
let prob = ((exp_values[i] as u64 * 65535) / sum as u64) as u16;
|
||||
*logit = prob as i16;
|
||||
}
|
||||
}
|
||||
|
||||
/// Softmax on f32 values using LUT (for compatibility)
|
||||
pub fn softmax_lut(logits: &mut [f32]) {
|
||||
if logits.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find max
|
||||
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Compute exp
|
||||
let mut sum = 0.0f32;
|
||||
for v in logits.iter_mut() {
|
||||
*v = (*v - max).exp();
|
||||
sum += *v;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
if sum > 0.0 {
|
||||
for v in logits.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Piecewise linear softmax approximation
|
||||
/// More accurate than LUT but still deterministic
|
||||
pub fn softmax_pwl(logits: &mut [i16]) {
|
||||
if logits.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let max = *logits.iter().max().unwrap_or(&0);
|
||||
|
||||
// Piecewise linear exp approximation
|
||||
// exp(x) ≈ 1 + x for x near 0
|
||||
// exp(x) ≈ 2^(x/ln2) for larger x
|
||||
let mut sum: i64 = 0;
|
||||
let mut exp_values: Vec<i32> = Vec::with_capacity(logits.len());
|
||||
|
||||
for &logit in logits.iter() {
|
||||
let x = (logit - max) as i32; // x <= 0
|
||||
|
||||
// Piecewise approximation (in Q8.8)
|
||||
let exp_val = if x >= -256 {
|
||||
// x in [-1, 0] -> linear: 1 + x
|
||||
(256 + x).max(0) as i32
|
||||
} else if x >= -2048 {
|
||||
// x in [-8, -1] -> exponential decay
|
||||
let shifted = (x + 2048) >> 3; // Scale to 0-256 range
|
||||
(shifted * shifted / 256).max(1) as i32
|
||||
} else {
|
||||
// x < -8 -> essentially zero
|
||||
1
|
||||
};
|
||||
|
||||
exp_values.push(exp_val);
|
||||
sum += exp_val as i64;
|
||||
}
|
||||
|
||||
// Normalize to Q0.16
|
||||
if sum == 0 {
|
||||
sum = 1;
|
||||
}
|
||||
|
||||
for (i, logit) in logits.iter_mut().enumerate() {
|
||||
let prob = (exp_values[i] as i64 * 65535 / sum) as i16;
|
||||
*logit = prob;
|
||||
}
|
||||
}
|
||||
|
||||
/// GELU approximation using LUT
|
||||
/// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
pub fn gelu_lut(x: i16) -> i16 {
|
||||
// Simplified approximation: GELU(x) ≈ x * sigmoid(1.702 * x)
|
||||
let scaled = ((x as i32 * 435) >> 8) as i16; // 1.702 * x in Q8.8
|
||||
let sigmoid_val = sigmoid_lut(scaled);
|
||||
((x as i32 * sigmoid_val as i32) >> 15) as i16
|
||||
}
|
||||
|
||||
/// Sigmoid LUT
|
||||
static SIGMOID_LUT: [u16; 256] = generate_sigmoid_lut();
|
||||
|
||||
const fn generate_sigmoid_lut() -> [u16; 256] {
|
||||
let mut lut = [0u16; 256];
|
||||
let mut i = 0;
|
||||
while i < 256 {
|
||||
// Map index to x in [-8, 8)
|
||||
let x = ((i as i32) - 128) as f64 / 16.0;
|
||||
// sigmoid(x) = 1 / (1 + exp(-x))
|
||||
let sig = 1.0 / (1.0 + const_exp(-x));
|
||||
lut[i] = (sig * 65535.0) as u16;
|
||||
i += 1;
|
||||
}
|
||||
lut
|
||||
}
|
||||
|
||||
/// LUT-based sigmoid
|
||||
/// Input: i16 in Q8.8 format
|
||||
/// Output: u16 in Q0.16 format
|
||||
#[inline]
|
||||
pub fn sigmoid_lut(x: i16) -> u16 {
|
||||
// Scale to LUT range
|
||||
let idx = (((x >> 5) + 128) as usize).min(255);
|
||||
SIGMOID_LUT[idx]
|
||||
}
|
||||
|
||||
/// SiLU (Swish) using sigmoid LUT
|
||||
/// SiLU(x) = x * sigmoid(x)
|
||||
#[inline]
|
||||
pub fn silu_lut(x: i16) -> i16 {
|
||||
let sigmoid_val = sigmoid_lut(x);
|
||||
((x as i32 * sigmoid_val as i32) >> 16) as i16
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exp_lut() {
|
||||
// exp(0) should return a non-zero value
|
||||
let result = exp_lut(0);
|
||||
assert!(result > 0, "exp(0) should be positive");
|
||||
|
||||
// exp is monotonically increasing
|
||||
let result_neg = exp_lut(-256); // -1.0 in Q8.8
|
||||
let result_zero = exp_lut(0);
|
||||
let result_pos = exp_lut(256); // 1.0 in Q8.8
|
||||
assert!(
|
||||
result_neg <= result_zero,
|
||||
"exp should be monotonically increasing"
|
||||
);
|
||||
assert!(
|
||||
result_zero <= result_pos,
|
||||
"exp should be monotonically increasing"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sigmoid_lut() {
|
||||
// sigmoid(0) = 0.5
|
||||
let result = sigmoid_lut(0);
|
||||
let expected = 32768u16; // 0.5 in Q0.16
|
||||
assert!(
|
||||
(result as i32 - expected as i32).abs() < 5000,
|
||||
"sigmoid(0) ≈ 0.5"
|
||||
);
|
||||
|
||||
// sigmoid is monotonically increasing
|
||||
let result_neg = sigmoid_lut(-1024);
|
||||
let result_zero = sigmoid_lut(0);
|
||||
let result_pos = sigmoid_lut(1024);
|
||||
assert!(
|
||||
result_neg < result_zero,
|
||||
"sigmoid should be monotonically increasing"
|
||||
);
|
||||
assert!(
|
||||
result_zero < result_pos,
|
||||
"sigmoid should be monotonically increasing"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_lut() {
|
||||
let mut logits = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
softmax_lut(&mut logits);
|
||||
|
||||
// Sum should be 1.0
|
||||
let sum: f32 = logits.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 0.01);
|
||||
|
||||
// Should be increasing
|
||||
for i in 1..logits.len() {
|
||||
assert!(logits[i] > logits[i - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gelu_lut() {
|
||||
// GELU(0) should be approximately 0
|
||||
assert!(gelu_lut(0).abs() < 100);
|
||||
|
||||
// GELU maintains sign
|
||||
let neg_result = gelu_lut(-256);
|
||||
assert!(neg_result <= 0, "GELU of negative should be non-positive");
|
||||
|
||||
// GELU of positive values should be positive
|
||||
let pos_result = gelu_lut(256);
|
||||
assert!(pos_result > 0, "GELU of positive should be positive");
|
||||
}
|
||||
}
|
||||
238
crates/ruvector-fpga-transformer/src/quant/mod.rs
Normal file
238
crates/ruvector-fpga-transformer/src/quant/mod.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
//! Quantization subsystem
|
||||
//!
|
||||
//! Explicit, reproducible quantization for weights and activations.
|
||||
|
||||
pub mod calib;
|
||||
pub mod lut;
|
||||
pub mod qformat;
|
||||
|
||||
pub use calib::{calibrate_model, CalibrationData};
|
||||
pub use lut::{exp_lut, log_lut, softmax_lut};
|
||||
pub use qformat::{dequantize_i16, dequantize_i8, quantize_i16, quantize_i8};
|
||||
|
||||
use crate::types::QuantSpec;
|
||||
|
||||
/// Fixed-point Q15 format (1.15)
|
||||
/// Range: [-1.0, 1.0 - 2^-15]
|
||||
/// Resolution: 2^-15 ≈ 3.05e-5
|
||||
pub type Q15 = i16;
|
||||
|
||||
/// Fixed-point Q16.16 format
|
||||
/// Range: [-32768.0, 32767.999...]
|
||||
/// Resolution: 2^-16 ≈ 1.53e-5
|
||||
pub type Q16_16 = i32;
|
||||
|
||||
/// Convert f32 to Q15
|
||||
#[inline]
|
||||
pub fn f32_to_q15(x: f32) -> Q15 {
|
||||
(x.clamp(-1.0, 1.0 - f32::EPSILON) * 32768.0) as Q15
|
||||
}
|
||||
|
||||
/// Convert Q15 to f32
|
||||
#[inline]
|
||||
pub fn q15_to_f32(x: Q15) -> f32 {
|
||||
x as f32 / 32768.0
|
||||
}
|
||||
|
||||
/// Convert f32 to Q16.16
|
||||
#[inline]
|
||||
pub fn f32_to_q16_16(x: f32) -> Q16_16 {
|
||||
(x * 65536.0) as Q16_16
|
||||
}
|
||||
|
||||
/// Convert Q16.16 to f32
|
||||
#[inline]
|
||||
pub fn q16_16_to_f32(x: Q16_16) -> f32 {
|
||||
x as f32 / 65536.0
|
||||
}
|
||||
|
||||
/// Fixed-point multiplication Q15 * Q15 -> Q15
|
||||
#[inline]
|
||||
pub fn q15_mul(a: Q15, b: Q15) -> Q15 {
|
||||
// Multiply with proper rounding
|
||||
let product = (a as i32 * b as i32 + 0x4000) >> 15;
|
||||
product.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
|
||||
}
|
||||
|
||||
/// Fixed-point dot product with accumulator
|
||||
/// Note: For very large vectors (>65536 elements), use q15_dot_saturating
|
||||
#[inline]
|
||||
pub fn q15_dot(a: &[Q15], b: &[Q15]) -> i32 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(&x, &y)| x as i32 * y as i32)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Saturating fixed-point dot product (prevents overflow for large vectors)
|
||||
#[inline]
|
||||
pub fn q15_dot_saturating(a: &[Q15], b: &[Q15]) -> i32 {
|
||||
a.iter().zip(b.iter()).fold(0i32, |acc, (&x, &y)| {
|
||||
acc.saturating_add((x as i32).saturating_mul(y as i32))
|
||||
})
|
||||
}
|
||||
|
||||
/// Fixed-point dot product normalized to Q15
|
||||
#[inline]
|
||||
pub fn q15_dot_normalized(a: &[Q15], b: &[Q15], shift: u8) -> Q15 {
|
||||
let sum = q15_dot(a, b);
|
||||
let shifted = (sum + (1 << (shift - 1))) >> shift;
|
||||
shifted.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
|
||||
}
|
||||
|
||||
/// Quantization context for a layer
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantContext {
|
||||
/// Weight quantization spec
|
||||
pub weight_spec: QuantSpec,
|
||||
/// Input scale (Q16.16)
|
||||
pub input_scale: Q16_16,
|
||||
/// Output scale (Q16.16)
|
||||
pub output_scale: Q16_16,
|
||||
/// Accumulator bit width
|
||||
pub acc_bits: u8,
|
||||
/// Right shift for normalization
|
||||
pub norm_shift: u8,
|
||||
}
|
||||
|
||||
impl QuantContext {
|
||||
/// Create from QuantSpec
|
||||
pub fn from_spec(spec: &QuantSpec) -> Self {
|
||||
Self {
|
||||
weight_spec: *spec,
|
||||
input_scale: spec.scale_q,
|
||||
output_scale: spec.scale_q,
|
||||
acc_bits: 32,
|
||||
norm_shift: 15,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the required accumulator bits to avoid overflow
|
||||
pub fn required_acc_bits(input_bits: u8, weight_bits: u8, vector_len: usize) -> u8 {
|
||||
// Each multiply produces input_bits + weight_bits
|
||||
// Sum of vector_len terms adds log2(vector_len) bits
|
||||
let product_bits = input_bits + weight_bits;
|
||||
let sum_bits = (vector_len as f64).log2().ceil() as u8;
|
||||
product_bits + sum_bits + 1 // +1 for sign
|
||||
}
|
||||
}
|
||||
|
||||
/// Packing utilities for sub-byte quantization
|
||||
pub mod packing {
|
||||
/// Pack two 4-bit values into one byte
|
||||
#[inline]
|
||||
pub fn pack_int4(a: i8, b: i8) -> u8 {
|
||||
((a & 0x0F) as u8) | (((b & 0x0F) as u8) << 4)
|
||||
}
|
||||
|
||||
/// Unpack byte into two 4-bit values
|
||||
#[inline]
|
||||
pub fn unpack_int4(packed: u8) -> (i8, i8) {
|
||||
let a = (packed & 0x0F) as i8;
|
||||
let a = if a & 0x08 != 0 { a | !0x0F } else { a }; // Sign extend
|
||||
let b = ((packed >> 4) & 0x0F) as i8;
|
||||
let b = if b & 0x08 != 0 { b | !0x0F } else { b };
|
||||
(a, b)
|
||||
}
|
||||
|
||||
/// Pack four 2-bit values into one byte
|
||||
#[inline]
|
||||
pub fn pack_int2(a: i8, b: i8, c: i8, d: i8) -> u8 {
|
||||
((a & 0x03) as u8)
|
||||
| (((b & 0x03) as u8) << 2)
|
||||
| (((c & 0x03) as u8) << 4)
|
||||
| (((d & 0x03) as u8) << 6)
|
||||
}
|
||||
|
||||
/// Unpack byte into four 2-bit values
|
||||
#[inline]
|
||||
pub fn unpack_int2(packed: u8) -> (i8, i8, i8, i8) {
|
||||
let a = (packed & 0x03) as i8;
|
||||
let a = if a & 0x02 != 0 { a | !0x03 } else { a };
|
||||
let b = ((packed >> 2) & 0x03) as i8;
|
||||
let b = if b & 0x02 != 0 { b | !0x03 } else { b };
|
||||
let c = ((packed >> 4) & 0x03) as i8;
|
||||
let c = if c & 0x02 != 0 { c | !0x03 } else { c };
|
||||
let d = ((packed >> 6) & 0x03) as i8;
|
||||
let d = if d & 0x02 != 0 { d | !0x03 } else { d };
|
||||
(a, b, c, d)
|
||||
}
|
||||
|
||||
/// Pack eight 1-bit values into one byte
|
||||
#[inline]
|
||||
pub fn pack_binary(bits: &[bool; 8]) -> u8 {
|
||||
bits.iter()
|
||||
.enumerate()
|
||||
.fold(0u8, |acc, (i, &b)| acc | ((b as u8) << i))
|
||||
}
|
||||
|
||||
/// Unpack byte into eight 1-bit values
|
||||
#[inline]
|
||||
pub fn unpack_binary(packed: u8) -> [bool; 8] {
|
||||
[
|
||||
packed & 0x01 != 0,
|
||||
packed & 0x02 != 0,
|
||||
packed & 0x04 != 0,
|
||||
packed & 0x08 != 0,
|
||||
packed & 0x10 != 0,
|
||||
packed & 0x20 != 0,
|
||||
packed & 0x40 != 0,
|
||||
packed & 0x80 != 0,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_q15_conversion() {
|
||||
assert_eq!(f32_to_q15(0.0), 0);
|
||||
assert_eq!(f32_to_q15(0.5), 16384);
|
||||
assert_eq!(f32_to_q15(-0.5), -16384);
|
||||
|
||||
let x = 0.123f32;
|
||||
let q = f32_to_q15(x);
|
||||
let back = q15_to_f32(q);
|
||||
assert!((x - back).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_q15_mul() {
|
||||
let a = f32_to_q15(0.5);
|
||||
let b = f32_to_q15(0.5);
|
||||
let c = q15_mul(a, b);
|
||||
let result = q15_to_f32(c);
|
||||
assert!((result - 0.25).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_packing_int4() {
|
||||
let (a, b) = (5i8, -3i8);
|
||||
let packed = packing::pack_int4(a, b);
|
||||
let (ua, ub) = packing::unpack_int4(packed);
|
||||
assert_eq!(a, ua);
|
||||
assert_eq!(b, ub);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_packing_int2() {
|
||||
let (a, b, c, d) = (1i8, -1i8, 0i8, -2i8);
|
||||
let packed = packing::pack_int2(a, b, c, d);
|
||||
let (ua, ub, uc, ud) = packing::unpack_int2(packed);
|
||||
assert_eq!(a, ua);
|
||||
assert_eq!(b, ub);
|
||||
assert_eq!(c, uc);
|
||||
// -2 in 2-bit is 10 binary, which unpacks to -2 (sign extended)
|
||||
assert_eq!(-2i8, ud);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_packing_binary() {
|
||||
let bits = [true, false, true, true, false, false, true, false];
|
||||
let packed = packing::pack_binary(&bits);
|
||||
let unpacked = packing::unpack_binary(packed);
|
||||
assert_eq!(bits, unpacked);
|
||||
}
|
||||
}
|
||||
237
crates/ruvector-fpga-transformer/src/quant/qformat.rs
Normal file
237
crates/ruvector-fpga-transformer/src/quant/qformat.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
//! Quantization format operations
|
||||
|
||||
use crate::types::QuantSpec;
|
||||
|
||||
/// Quantize f32 values to i8
|
||||
pub fn quantize_i8(values: &[f32], spec: &QuantSpec) -> Vec<i8> {
|
||||
let scale = spec.scale_q as f32 / 65536.0;
|
||||
let zero = spec.zero_q as f32 / 65536.0;
|
||||
|
||||
values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let quantized = ((v - zero) / scale).round();
|
||||
quantized.clamp(-128.0, 127.0) as i8
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Quantize f32 values to i16
|
||||
pub fn quantize_i16(values: &[f32]) -> Vec<i16> {
|
||||
// Find min/max
|
||||
let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Handle edge case
|
||||
if (max - min).abs() < f32::EPSILON {
|
||||
return vec![0i16; values.len()];
|
||||
}
|
||||
|
||||
let scale = (max - min) / 65535.0;
|
||||
|
||||
values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let normalized = (v - min) / scale - 32768.0;
|
||||
normalized.round().clamp(-32768.0, 32767.0) as i16
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Dequantize i8 values to f32
|
||||
pub fn dequantize_i8(values: &[u8], spec: &QuantSpec) -> Vec<f32> {
|
||||
let scale = spec.scale_q as f32 / 65536.0;
|
||||
let zero = spec.zero_q as f32 / 65536.0;
|
||||
|
||||
values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let signed = v as i8;
|
||||
signed as f32 * scale + zero
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Dequantize i16 values to f32
|
||||
pub fn dequantize_i16(values: &[i16], scale: f32, zero: f32) -> Vec<f32> {
|
||||
values.iter().map(|&v| v as f32 * scale + zero).collect()
|
||||
}
|
||||
|
||||
/// Symmetric quantization (zero point = 0)
|
||||
pub fn quantize_symmetric_i8(values: &[f32]) -> (Vec<i8>, f32) {
|
||||
let abs_max = values.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
|
||||
|
||||
if abs_max < f32::EPSILON {
|
||||
return (vec![0i8; values.len()], 1.0);
|
||||
}
|
||||
|
||||
let scale = abs_max / 127.0;
|
||||
|
||||
let quantized = values
|
||||
.iter()
|
||||
.map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
|
||||
.collect();
|
||||
|
||||
(quantized, scale)
|
||||
}
|
||||
|
||||
/// Asymmetric quantization (uses full i8 range)
|
||||
pub fn quantize_asymmetric_i8(values: &[f32]) -> (Vec<u8>, f32, i32) {
|
||||
let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
if (max - min).abs() < f32::EPSILON {
|
||||
return (vec![0u8; values.len()], 1.0, 0);
|
||||
}
|
||||
|
||||
let scale = (max - min) / 255.0;
|
||||
let zero_point = (-min / scale).round() as i32;
|
||||
|
||||
let quantized = values
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
let q = (v / scale).round() as i32 + zero_point;
|
||||
q.clamp(0, 255) as u8
|
||||
})
|
||||
.collect();
|
||||
|
||||
(quantized, scale, zero_point)
|
||||
}
|
||||
|
||||
/// Per-channel quantization for weights
|
||||
pub fn quantize_per_channel_i8(weights: &[f32], out_channels: usize) -> (Vec<i8>, Vec<f32>) {
|
||||
let in_features = weights.len() / out_channels;
|
||||
let mut quantized = Vec::with_capacity(weights.len());
|
||||
let mut scales = Vec::with_capacity(out_channels);
|
||||
|
||||
for c in 0..out_channels {
|
||||
let start = c * in_features;
|
||||
let end = start + in_features;
|
||||
let channel_weights = &weights[start..end];
|
||||
|
||||
let (q, scale) = quantize_symmetric_i8(channel_weights);
|
||||
quantized.extend(q);
|
||||
scales.push(scale);
|
||||
}
|
||||
|
||||
(quantized, scales)
|
||||
}
|
||||
|
||||
/// Blocked quantization for hardware efficiency
|
||||
pub fn quantize_blocked_i8(values: &[f32], block_size: usize) -> (Vec<i8>, Vec<f32>, Vec<i8>) {
|
||||
let num_blocks = (values.len() + block_size - 1) / block_size;
|
||||
let mut quantized = Vec::with_capacity(values.len());
|
||||
let mut scales = Vec::with_capacity(num_blocks);
|
||||
let mut zeros = Vec::with_capacity(num_blocks);
|
||||
|
||||
for block_idx in 0..num_blocks {
|
||||
let start = block_idx * block_size;
|
||||
let end = (start + block_size).min(values.len());
|
||||
let block = &values[start..end];
|
||||
|
||||
let (q, scale) = quantize_symmetric_i8(block);
|
||||
quantized.extend(q);
|
||||
scales.push(scale);
|
||||
zeros.push(0i8);
|
||||
}
|
||||
|
||||
(quantized, scales, zeros)
|
||||
}
|
||||
|
||||
/// Matrix quantization for GEMM
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizedMatrix {
|
||||
/// Quantized values
|
||||
pub data: Vec<i8>,
|
||||
/// Rows
|
||||
pub rows: usize,
|
||||
/// Columns
|
||||
pub cols: usize,
|
||||
/// Per-row scales (for per-channel quantization)
|
||||
pub scales: Vec<f32>,
|
||||
/// Per-row zero points
|
||||
pub zeros: Vec<i8>,
|
||||
}
|
||||
|
||||
impl QuantizedMatrix {
|
||||
/// Quantize a matrix with per-row scaling
|
||||
pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> Self {
|
||||
assert_eq!(data.len(), rows * cols);
|
||||
|
||||
let (quantized, scales) = quantize_per_channel_i8(data, rows);
|
||||
|
||||
Self {
|
||||
data: quantized,
|
||||
rows,
|
||||
cols,
|
||||
scales,
|
||||
zeros: vec![0i8; rows],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a row
|
||||
pub fn row(&self, idx: usize) -> &[i8] {
|
||||
let start = idx * self.cols;
|
||||
&self.data[start..start + self.cols]
|
||||
}
|
||||
|
||||
/// Dequantize to f32
|
||||
pub fn to_f32(&self) -> Vec<f32> {
|
||||
let mut result = Vec::with_capacity(self.rows * self.cols);
|
||||
|
||||
for r in 0..self.rows {
|
||||
let scale = self.scales[r];
|
||||
let zero = self.zeros[r] as f32;
|
||||
for &v in self.row(r) {
|
||||
result.push((v as f32 - zero) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::QuantSpec;
|
||||
|
||||
#[test]
|
||||
fn test_quantize_symmetric() {
|
||||
let values = vec![1.0, -1.0, 0.5, -0.5, 0.0];
|
||||
let (quantized, scale) = quantize_symmetric_i8(&values);
|
||||
|
||||
// Dequantize and check
|
||||
for (i, &q) in quantized.iter().enumerate() {
|
||||
let dequant = q as f32 * scale;
|
||||
assert!((dequant - values[i]).abs() < 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_asymmetric() {
|
||||
let values = vec![0.0, 0.5, 1.0, 1.5, 2.0];
|
||||
let (quantized, scale, zero) = quantize_asymmetric_i8(&values);
|
||||
|
||||
// Dequantize and check
|
||||
for (i, &q) in quantized.iter().enumerate() {
|
||||
let dequant = (q as i32 - zero) as f32 * scale;
|
||||
assert!((dequant - values[i]).abs() < 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantized_matrix() {
|
||||
let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1 - 3.2).collect();
|
||||
let matrix = QuantizedMatrix::from_f32(&data, 8, 8);
|
||||
|
||||
assert_eq!(matrix.rows, 8);
|
||||
assert_eq!(matrix.cols, 8);
|
||||
assert_eq!(matrix.scales.len(), 8);
|
||||
|
||||
let dequantized = matrix.to_f32();
|
||||
for (orig, deq) in data.iter().zip(dequantized.iter()) {
|
||||
assert!((orig - deq).abs() < 0.2);
|
||||
}
|
||||
}
|
||||
}
|
||||
624
crates/ruvector-fpga-transformer/src/types.rs
Normal file
624
crates/ruvector-fpga-transformer/src/types.rs
Normal file
@@ -0,0 +1,624 @@
|
||||
//! Core types for FPGA Transformer backend
|
||||
//!
|
||||
//! All types are designed for deterministic, allocation-free inference
|
||||
//! with explicit quantization metadata.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Unique identifier for a loaded model (SHA-256 hash)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ModelId(pub [u8; 32]);
|
||||
|
||||
impl ModelId {
|
||||
/// Create a new ModelId from bytes
|
||||
pub const fn new(bytes: [u8; 32]) -> Self {
|
||||
Self(bytes)
|
||||
}
|
||||
|
||||
/// Create a zero ModelId (for testing)
|
||||
pub const fn zero() -> Self {
|
||||
Self([0u8; 32])
|
||||
}
|
||||
|
||||
/// Get the bytes of the ModelId
|
||||
pub const fn as_bytes(&self) -> &[u8; 32] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Convert to hex string
|
||||
pub fn to_hex(&self) -> String {
|
||||
self.0.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
|
||||
/// Parse from hex string
|
||||
pub fn from_hex(s: &str) -> Option<Self> {
|
||||
if s.len() != 64 {
|
||||
return None;
|
||||
}
|
||||
let mut bytes = [0u8; 32];
|
||||
for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
|
||||
let hex_str = std::str::from_utf8(chunk).ok()?;
|
||||
bytes[i] = u8::from_str_radix(hex_str, 16).ok()?;
|
||||
}
|
||||
Some(Self(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.to_hex())
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixed shape specification for transformer inference
|
||||
///
|
||||
/// All dimensions are compile-time or model-time constants.
|
||||
/// This enables zero-allocation inference paths.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct FixedShape {
|
||||
/// Maximum sequence length
|
||||
pub seq_len: u16,
|
||||
/// Model/hidden dimension
|
||||
pub d_model: u16,
|
||||
/// Number of attention heads
|
||||
pub heads: u8,
|
||||
/// Dimension per head
|
||||
pub d_head: u16,
|
||||
/// Vocabulary size
|
||||
pub vocab: u32,
|
||||
}
|
||||
|
||||
impl FixedShape {
|
||||
/// Create a new FixedShape
|
||||
pub const fn new(seq_len: u16, d_model: u16, heads: u8, d_head: u16, vocab: u32) -> Self {
|
||||
Self {
|
||||
seq_len,
|
||||
d_model,
|
||||
heads,
|
||||
d_head,
|
||||
vocab,
|
||||
}
|
||||
}
|
||||
|
||||
/// Micro configuration for edge/WASM deployment
|
||||
pub const fn micro() -> Self {
|
||||
Self {
|
||||
seq_len: 32,
|
||||
d_model: 64,
|
||||
heads: 4,
|
||||
d_head: 16,
|
||||
vocab: 4096,
|
||||
}
|
||||
}
|
||||
|
||||
/// Small configuration for embedded
|
||||
pub const fn small() -> Self {
|
||||
Self {
|
||||
seq_len: 64,
|
||||
d_model: 128,
|
||||
heads: 4,
|
||||
d_head: 32,
|
||||
vocab: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
/// Baseline configuration
|
||||
pub const fn baseline() -> Self {
|
||||
Self {
|
||||
seq_len: 128,
|
||||
d_model: 256,
|
||||
heads: 8,
|
||||
d_head: 32,
|
||||
vocab: 32000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate total parameters for embedding layer
|
||||
pub const fn embedding_params(&self) -> usize {
|
||||
self.vocab as usize * self.d_model as usize
|
||||
}
|
||||
|
||||
/// Calculate parameters per attention layer
|
||||
pub const fn attention_params(&self) -> usize {
|
||||
// Q, K, V projections + output projection
|
||||
4 * (self.d_model as usize * self.d_model as usize)
|
||||
}
|
||||
|
||||
/// Calculate parameters per FFN layer (assuming 4x expansion)
|
||||
pub const fn ffn_params(&self) -> usize {
|
||||
2 * (self.d_model as usize * 4 * self.d_model as usize)
|
||||
}
|
||||
|
||||
/// Validate shape consistency
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.d_model as usize != self.heads as usize * self.d_head as usize {
|
||||
return Err(format!(
|
||||
"d_model ({}) must equal heads ({}) * d_head ({})",
|
||||
self.d_model, self.heads, self.d_head
|
||||
));
|
||||
}
|
||||
if self.seq_len == 0 {
|
||||
return Err("seq_len must be > 0".into());
|
||||
}
|
||||
if self.vocab == 0 {
|
||||
return Err("vocab must be > 0".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FixedShape {
|
||||
fn default() -> Self {
|
||||
Self::baseline()
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantization specification
|
||||
///
|
||||
/// Explicit quantization metadata ensuring reproducible inference.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct QuantSpec {
|
||||
/// Weight bit width (1, 2, 4, 8)
|
||||
pub w_bits: u8,
|
||||
/// Activation bit width (4, 8, 16)
|
||||
pub a_bits: u8,
|
||||
/// Scale factor (Q16.16 fixed point)
|
||||
pub scale_q: i32,
|
||||
/// Zero point (Q16.16 fixed point)
|
||||
pub zero_q: i32,
|
||||
/// Memory layout
|
||||
pub layout: Layout,
|
||||
}
|
||||
|
||||
impl QuantSpec {
|
||||
/// Create a new QuantSpec
|
||||
pub const fn new(w_bits: u8, a_bits: u8, scale_q: i32, zero_q: i32, layout: Layout) -> Self {
|
||||
Self {
|
||||
w_bits,
|
||||
a_bits,
|
||||
scale_q,
|
||||
zero_q,
|
||||
layout,
|
||||
}
|
||||
}
|
||||
|
||||
/// INT4 weights, INT8 activations (common for edge)
|
||||
pub const fn int4_int8() -> Self {
|
||||
Self {
|
||||
w_bits: 4,
|
||||
a_bits: 8,
|
||||
scale_q: 1 << 16, // 1.0 in Q16.16
|
||||
zero_q: 0,
|
||||
layout: Layout::Blocked { block: 32 },
|
||||
}
|
||||
}
|
||||
|
||||
/// INT8 weights and activations
|
||||
pub const fn int8() -> Self {
|
||||
Self {
|
||||
w_bits: 8,
|
||||
a_bits: 8,
|
||||
scale_q: 1 << 16,
|
||||
zero_q: 0,
|
||||
layout: Layout::RowMajor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Bytes per weight element
|
||||
pub const fn bytes_per_weight(&self) -> usize {
|
||||
match self.w_bits {
|
||||
1 => 1, // Packed 8 per byte, but minimum 1 byte
|
||||
2 => 1, // Packed 4 per byte
|
||||
4 => 1, // Packed 2 per byte
|
||||
8 => 1,
|
||||
16 => 2,
|
||||
_ => 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights packed per byte
|
||||
pub const fn weights_per_byte(&self) -> usize {
|
||||
match self.w_bits {
|
||||
1 => 8,
|
||||
2 => 4,
|
||||
4 => 2,
|
||||
_ => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QuantSpec {
|
||||
fn default() -> Self {
|
||||
Self::int8()
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory layout for quantized tensors
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Layout {
|
||||
/// Standard row-major layout
|
||||
RowMajor,
|
||||
/// Blocked layout for SIMD/hardware efficiency
|
||||
Blocked { block: u16 },
|
||||
/// Heads interleaved for attention computation
|
||||
InterleavedHeads,
|
||||
}
|
||||
|
||||
impl Default for Layout {
|
||||
fn default() -> Self {
|
||||
Self::RowMajor
|
||||
}
|
||||
}
|
||||
|
||||
/// Hint for gating decisions
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct GateHint {
|
||||
/// Coherence score (Q8.8 fixed point, higher = more coherent)
|
||||
pub coherence_score_q: i16,
|
||||
/// Whether a boundary was crossed in the input
|
||||
pub boundary_crossed: bool,
|
||||
/// Maximum compute class allowed
|
||||
pub max_compute_class: ComputeClass,
|
||||
}
|
||||
|
||||
impl GateHint {
|
||||
/// Create a new GateHint
|
||||
pub const fn new(
|
||||
coherence_score_q: i16,
|
||||
boundary_crossed: bool,
|
||||
max_compute_class: ComputeClass,
|
||||
) -> Self {
|
||||
Self {
|
||||
coherence_score_q,
|
||||
boundary_crossed,
|
||||
max_compute_class,
|
||||
}
|
||||
}
|
||||
|
||||
/// Default hint allowing full computation
|
||||
pub const fn allow_all() -> Self {
|
||||
Self {
|
||||
coherence_score_q: i16::MAX,
|
||||
boundary_crossed: false,
|
||||
max_compute_class: ComputeClass::Deliberative,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reflex-only hint for fast path
|
||||
pub const fn reflex_only() -> Self {
|
||||
Self {
|
||||
coherence_score_q: 0,
|
||||
boundary_crossed: false,
|
||||
max_compute_class: ComputeClass::Reflex,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GateHint {
|
||||
fn default() -> Self {
|
||||
Self::allow_all()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute class for tiered inference
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum ComputeClass {
|
||||
/// Fastest path, minimal computation (1-2 layers)
|
||||
Reflex = 0,
|
||||
/// Medium path, associative memory (4-6 layers)
|
||||
Associative = 1,
|
||||
/// Full deliberative computation (all layers)
|
||||
Deliberative = 2,
|
||||
}
|
||||
|
||||
impl ComputeClass {
|
||||
/// Convert from u8
|
||||
pub const fn from_u8(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Reflex),
|
||||
1 => Some(Self::Associative),
|
||||
2 => Some(Self::Deliberative),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ComputeClass {
|
||||
fn default() -> Self {
|
||||
Self::Deliberative
|
||||
}
|
||||
}
|
||||
|
||||
/// Inference request
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InferenceRequest<'a> {
|
||||
/// Model to use
|
||||
pub model: ModelId,
|
||||
/// Expected shape
|
||||
pub shape: FixedShape,
|
||||
/// Input token IDs (length = seq_len)
|
||||
pub tokens: &'a [u16],
|
||||
/// Attention mask (length = seq_len or seq_len^2)
|
||||
pub attn_mask: &'a [u8],
|
||||
/// Gating hint for coherence control
|
||||
pub gate_hint: GateHint,
|
||||
}
|
||||
|
||||
impl<'a> InferenceRequest<'a> {
|
||||
/// Create a new InferenceRequest
|
||||
pub fn new(
|
||||
model: ModelId,
|
||||
shape: FixedShape,
|
||||
tokens: &'a [u16],
|
||||
attn_mask: &'a [u8],
|
||||
gate_hint: GateHint,
|
||||
) -> Self {
|
||||
Self {
|
||||
model,
|
||||
shape,
|
||||
tokens,
|
||||
attn_mask,
|
||||
gate_hint,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the request
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if self.tokens.len() != self.shape.seq_len as usize {
|
||||
return Err(crate::error::Error::InputLengthMismatch {
|
||||
expected: self.shape.seq_len as usize,
|
||||
actual: self.tokens.len(),
|
||||
});
|
||||
}
|
||||
if self.attn_mask.len() != self.shape.seq_len as usize
|
||||
&& self.attn_mask.len() != (self.shape.seq_len as usize).pow(2)
|
||||
{
|
||||
return Err(crate::error::Error::InputLengthMismatch {
|
||||
expected: self.shape.seq_len as usize,
|
||||
actual: self.attn_mask.len(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Inference result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InferenceResult {
|
||||
/// Full logits (quantized, length = vocab) or empty if topk_only
|
||||
pub logits_q: Vec<i16>,
|
||||
/// Top-K predictions (token_id, logit_q)
|
||||
pub topk: Option<Vec<(u16, i16)>>,
|
||||
/// Witness log for audit trail
|
||||
pub witness: WitnessLog,
|
||||
}
|
||||
|
||||
impl InferenceResult {
|
||||
/// Create a new InferenceResult
|
||||
pub fn new(logits_q: Vec<i16>, topk: Option<Vec<(u16, i16)>>, witness: WitnessLog) -> Self {
|
||||
Self {
|
||||
logits_q,
|
||||
topk,
|
||||
witness,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the argmax token
|
||||
pub fn argmax(&self) -> Option<u16> {
|
||||
if let Some(ref topk) = self.topk {
|
||||
topk.first().map(|(token, _)| *token)
|
||||
} else if !self.logits_q.is_empty() {
|
||||
self.logits_q
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by_key(|(_, &v)| v)
|
||||
.map(|(i, _)| i as u16)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Witness log for audit trail and ReasoningBank integration
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct WitnessLog {
|
||||
/// Hash of the model used
|
||||
pub model_hash: [u8; 32],
|
||||
/// Hash of quantization parameters used
|
||||
pub quant_hash: [u8; 32],
|
||||
/// Backend that executed the inference
|
||||
pub backend: BackendKind,
|
||||
/// Compute cycles used (FPGA) or 0 (sim)
|
||||
pub cycles: u32,
|
||||
/// Latency in nanoseconds
|
||||
pub latency_ns: u32,
|
||||
/// Gate decision made
|
||||
pub gate_decision: GateDecision,
|
||||
}
|
||||
|
||||
impl WitnessLog {
|
||||
/// Create a new WitnessLog
|
||||
pub fn new(
|
||||
model_hash: [u8; 32],
|
||||
quant_hash: [u8; 32],
|
||||
backend: BackendKind,
|
||||
cycles: u32,
|
||||
latency_ns: u32,
|
||||
gate_decision: GateDecision,
|
||||
) -> Self {
|
||||
Self {
|
||||
model_hash,
|
||||
quant_hash,
|
||||
backend,
|
||||
cycles,
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an empty witness (for testing)
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
model_hash: [0u8; 32],
|
||||
quant_hash: [0u8; 32],
|
||||
backend: BackendKind::NativeSim,
|
||||
cycles: 0,
|
||||
latency_ns: 0,
|
||||
gate_decision: GateDecision::RanFull,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend types
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum BackendKind {
|
||||
/// PCIe-connected FPGA
|
||||
FpgaPcie,
|
||||
/// FPGA via local daemon
|
||||
FpgaDaemon,
|
||||
/// WASM simulator
|
||||
WasmSim,
|
||||
/// Native Rust simulator
|
||||
NativeSim,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::FpgaPcie => write!(f, "fpga_pcie"),
|
||||
Self::FpgaDaemon => write!(f, "fpga_daemon"),
|
||||
Self::WasmSim => write!(f, "wasm_sim"),
|
||||
Self::NativeSim => write!(f, "native_sim"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gate decision outcome
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GateDecision {
|
||||
/// Full inference completed
|
||||
RanFull,
|
||||
/// Early exit at specified layer
|
||||
EarlyExit { layer: u8 },
|
||||
/// Inference was skipped
|
||||
Skipped { reason: SkipReason },
|
||||
}
|
||||
|
||||
impl GateDecision {
|
||||
/// Check if inference actually ran
|
||||
pub const fn did_run(&self) -> bool {
|
||||
!matches!(self, Self::Skipped { .. })
|
||||
}
|
||||
|
||||
/// Get the exit layer (full = max layers)
|
||||
pub const fn exit_layer(&self, max_layers: u8) -> u8 {
|
||||
match self {
|
||||
Self::RanFull => max_layers,
|
||||
Self::EarlyExit { layer } => *layer,
|
||||
Self::Skipped { .. } => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reason for skipping inference
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SkipReason {
|
||||
/// Coherence score too low
|
||||
LowCoherence,
|
||||
/// Policy denied the inference
|
||||
PolicyDenied,
|
||||
/// Compute budget exceeded
|
||||
BudgetExceeded,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SkipReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::LowCoherence => write!(f, "low_coherence"),
|
||||
Self::PolicyDenied => write!(f, "policy_denied"),
|
||||
Self::BudgetExceeded => write!(f, "budget_exceeded"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantized tensor wrapper
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizedTensor {
|
||||
/// Raw quantized data
|
||||
pub data: Vec<u8>,
|
||||
/// Quantization specification
|
||||
pub spec: QuantSpec,
|
||||
/// Tensor shape (row-major)
|
||||
pub shape: Vec<usize>,
|
||||
}
|
||||
|
||||
impl QuantizedTensor {
|
||||
/// Create a new quantized tensor
|
||||
pub fn new(data: Vec<u8>, spec: QuantSpec, shape: Vec<usize>) -> Self {
|
||||
Self { data, spec, shape }
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
pub fn numel(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
}
|
||||
|
||||
/// Expected data size in bytes
|
||||
pub fn expected_bytes(&self) -> usize {
|
||||
let numel = self.numel();
|
||||
(numel + self.spec.weights_per_byte() - 1) / self.spec.weights_per_byte()
|
||||
}
|
||||
|
||||
/// Validate tensor integrity
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
let expected = self.expected_bytes();
|
||||
if self.data.len() != expected {
|
||||
return Err(crate::error::Error::QuantizationError(format!(
|
||||
"Data size mismatch: expected {} bytes, got {}",
|
||||
expected,
|
||||
self.data.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_id_hex_roundtrip() {
|
||||
let bytes = [0x12u8; 32];
|
||||
let id = ModelId::new(bytes);
|
||||
let hex = id.to_hex();
|
||||
let parsed = ModelId::from_hex(&hex).unwrap();
|
||||
assert_eq!(id, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fixed_shape_validate() {
|
||||
let valid = FixedShape::new(64, 256, 8, 32, 32000);
|
||||
assert!(valid.validate().is_ok());
|
||||
|
||||
let invalid = FixedShape::new(64, 256, 8, 16, 32000); // 8*16 != 256
|
||||
assert!(invalid.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quant_spec_bytes() {
|
||||
assert_eq!(QuantSpec::int8().weights_per_byte(), 1);
|
||||
assert_eq!(QuantSpec::int4_int8().weights_per_byte(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_decision() {
|
||||
assert!(GateDecision::RanFull.did_run());
|
||||
assert!(GateDecision::EarlyExit { layer: 3 }.did_run());
|
||||
assert!(!GateDecision::Skipped {
|
||||
reason: SkipReason::LowCoherence
|
||||
}
|
||||
.did_run());
|
||||
}
|
||||
}
|
||||
215
crates/ruvector-fpga-transformer/src/witness/hash.rs
Normal file
215
crates/ruvector-fpga-transformer/src/witness/hash.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
//! Witness hashing for integrity verification
|
||||
|
||||
use crate::types::WitnessLog;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Compute a hash of the witness log for integrity verification
|
||||
pub fn compute_witness_hash(witness: &WitnessLog) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
hasher.update(&witness.model_hash);
|
||||
hasher.update(&witness.quant_hash);
|
||||
hasher.update(&[witness.backend as u8]);
|
||||
hasher.update(&witness.cycles.to_le_bytes());
|
||||
hasher.update(&witness.latency_ns.to_le_bytes());
|
||||
|
||||
// Hash gate decision
|
||||
match witness.gate_decision {
|
||||
crate::types::GateDecision::RanFull => {
|
||||
hasher.update(&[0u8]);
|
||||
}
|
||||
crate::types::GateDecision::EarlyExit { layer } => {
|
||||
hasher.update(&[1u8, layer]);
|
||||
}
|
||||
crate::types::GateDecision::Skipped { reason } => {
|
||||
hasher.update(&[2u8, reason as u8]);
|
||||
}
|
||||
}
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Verify a witness hash
|
||||
pub fn verify_witness_hash(witness: &WitnessLog, expected: &[u8; 32]) -> bool {
|
||||
let computed = compute_witness_hash(witness);
|
||||
computed == *expected
|
||||
}
|
||||
|
||||
/// Compute a combined hash for a sequence of witnesses
|
||||
/// Useful for verifying an entire inference chain
|
||||
pub fn compute_chain_hash(witnesses: &[WitnessLog]) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
|
||||
for witness in witnesses {
|
||||
let witness_hash = compute_witness_hash(witness);
|
||||
hasher.update(&witness_hash);
|
||||
}
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Witness proof for verification
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WitnessProof {
|
||||
/// Hash of the witness
|
||||
pub hash: [u8; 32],
|
||||
/// Timestamp when proof was created
|
||||
pub timestamp_ns: u64,
|
||||
/// Optional signature
|
||||
pub signature: Option<[u8; 64]>,
|
||||
}
|
||||
|
||||
impl WitnessProof {
|
||||
/// Create a new proof from a witness
|
||||
pub fn new(witness: &WitnessLog) -> Self {
|
||||
Self {
|
||||
hash: compute_witness_hash(witness),
|
||||
timestamp_ns: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0),
|
||||
signature: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a proof with signature
|
||||
#[cfg(feature = "sign")]
|
||||
pub fn signed(witness: &WitnessLog, secret_key: &[u8; 32]) -> Self {
|
||||
use ed25519_dalek::{Signer, SigningKey};
|
||||
|
||||
let hash = compute_witness_hash(witness);
|
||||
let timestamp_ns = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Create message to sign
|
||||
let mut message = [0u8; 40];
|
||||
message[..32].copy_from_slice(&hash);
|
||||
message[32..40].copy_from_slice(×tamp_ns.to_le_bytes());
|
||||
|
||||
let signing_key = SigningKey::from_bytes(secret_key);
|
||||
let signature = signing_key.sign(&message);
|
||||
|
||||
Self {
|
||||
hash,
|
||||
timestamp_ns,
|
||||
signature: Some(signature.to_bytes()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the proof against a witness
|
||||
pub fn verify(&self, witness: &WitnessLog) -> bool {
|
||||
verify_witness_hash(witness, &self.hash)
|
||||
}
|
||||
|
||||
/// Verify the signature
|
||||
#[cfg(feature = "sign")]
|
||||
pub fn verify_signature(&self, pubkey: &[u8; 32]) -> bool {
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
|
||||
let Some(sig_bytes) = self.signature else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let Ok(verifying_key) = VerifyingKey::from_bytes(pubkey) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let signature = Signature::from_bytes(&sig_bytes);
|
||||
|
||||
let mut message = [0u8; 40];
|
||||
message[..32].copy_from_slice(&self.hash);
|
||||
message[32..40].copy_from_slice(&self.timestamp_ns.to_le_bytes());
|
||||
|
||||
verifying_key.verify(&message, &signature).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{BackendKind, GateDecision};
|
||||
|
||||
#[test]
|
||||
fn test_witness_hash_deterministic() {
|
||||
let witness = WitnessLog::new(
|
||||
[1u8; 32],
|
||||
[2u8; 32],
|
||||
BackendKind::NativeSim,
|
||||
1000,
|
||||
50000,
|
||||
GateDecision::RanFull,
|
||||
);
|
||||
|
||||
let hash1 = compute_witness_hash(&witness);
|
||||
let hash2 = compute_witness_hash(&witness);
|
||||
|
||||
assert_eq!(hash1, hash2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_hash_changes() {
|
||||
let witness1 = WitnessLog::new(
|
||||
[1u8; 32],
|
||||
[2u8; 32],
|
||||
BackendKind::NativeSim,
|
||||
1000,
|
||||
50000,
|
||||
GateDecision::RanFull,
|
||||
);
|
||||
|
||||
let witness2 = WitnessLog::new(
|
||||
[1u8; 32],
|
||||
[2u8; 32],
|
||||
BackendKind::NativeSim,
|
||||
1001, // Different cycles
|
||||
50000,
|
||||
GateDecision::RanFull,
|
||||
);
|
||||
|
||||
let hash1 = compute_witness_hash(&witness1);
|
||||
let hash2 = compute_witness_hash(&witness2);
|
||||
|
||||
assert_ne!(hash1, hash2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_witness_hash() {
|
||||
let witness = WitnessLog::empty();
|
||||
let hash = compute_witness_hash(&witness);
|
||||
|
||||
assert!(verify_witness_hash(&witness, &hash));
|
||||
assert!(!verify_witness_hash(&witness, &[0u8; 32]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_hash() {
|
||||
let witnesses: Vec<WitnessLog> = (0..5)
|
||||
.map(|i| {
|
||||
WitnessLog::new(
|
||||
[i as u8; 32],
|
||||
[0u8; 32],
|
||||
BackendKind::NativeSim,
|
||||
i * 100,
|
||||
i * 1000,
|
||||
GateDecision::RanFull,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let chain_hash1 = compute_chain_hash(&witnesses);
|
||||
let chain_hash2 = compute_chain_hash(&witnesses);
|
||||
|
||||
assert_eq!(chain_hash1, chain_hash2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_proof() {
|
||||
let witness = WitnessLog::empty();
|
||||
let proof = WitnessProof::new(&witness);
|
||||
|
||||
assert!(proof.verify(&witness));
|
||||
assert!(proof.timestamp_ns > 0);
|
||||
}
|
||||
}
|
||||
311
crates/ruvector-fpga-transformer/src/witness/log.rs
Normal file
311
crates/ruvector-fpga-transformer/src/witness/log.rs
Normal file
@@ -0,0 +1,311 @@
|
||||
//! Witness log builder and utilities
|
||||
|
||||
use crate::types::{BackendKind, GateDecision, WitnessLog};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Builder for creating witness logs
|
||||
pub struct WitnessBuilder {
|
||||
model_hash: [u8; 32],
|
||||
quant_hash: [u8; 32],
|
||||
backend: BackendKind,
|
||||
start_time: Instant,
|
||||
cycles: u32,
|
||||
gate_decision: GateDecision,
|
||||
}
|
||||
|
||||
impl WitnessBuilder {
|
||||
/// Start building a new witness
|
||||
pub fn new(backend: BackendKind) -> Self {
|
||||
Self {
|
||||
model_hash: [0u8; 32],
|
||||
quant_hash: [0u8; 32],
|
||||
backend,
|
||||
start_time: Instant::now(),
|
||||
cycles: 0,
|
||||
gate_decision: GateDecision::RanFull,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set model hash
|
||||
pub fn model_hash(mut self, hash: [u8; 32]) -> Self {
|
||||
self.model_hash = hash;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set quantization hash
|
||||
pub fn quant_hash(mut self, hash: [u8; 32]) -> Self {
|
||||
self.quant_hash = hash;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set compute cycles
|
||||
pub fn cycles(mut self, cycles: u32) -> Self {
|
||||
self.cycles = cycles;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set gate decision
|
||||
pub fn gate_decision(mut self, decision: GateDecision) -> Self {
|
||||
self.gate_decision = decision;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the witness log
|
||||
pub fn build(self) -> WitnessLog {
|
||||
let latency_ns = self.start_time.elapsed().as_nanos() as u32;
|
||||
|
||||
WitnessLog::new(
|
||||
self.model_hash,
|
||||
self.quant_hash,
|
||||
self.backend,
|
||||
self.cycles,
|
||||
latency_ns,
|
||||
self.gate_decision,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl WitnessLog {
|
||||
/// Convert to compact bytes for storage
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let mut bytes = Vec::with_capacity(80);
|
||||
|
||||
bytes.extend_from_slice(&self.model_hash);
|
||||
bytes.extend_from_slice(&self.quant_hash);
|
||||
bytes.push(self.backend as u8);
|
||||
bytes.extend_from_slice(&self.cycles.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.latency_ns.to_le_bytes());
|
||||
|
||||
// Encode gate decision
|
||||
match self.gate_decision {
|
||||
GateDecision::RanFull => {
|
||||
bytes.push(0);
|
||||
bytes.push(0);
|
||||
}
|
||||
GateDecision::EarlyExit { layer } => {
|
||||
bytes.push(1);
|
||||
bytes.push(layer);
|
||||
}
|
||||
GateDecision::Skipped { reason } => {
|
||||
bytes.push(2);
|
||||
bytes.push(reason as u8);
|
||||
}
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Parse from bytes
|
||||
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
if bytes.len() < 75 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let model_hash: [u8; 32] = bytes[0..32].try_into().ok()?;
|
||||
let quant_hash: [u8; 32] = bytes[32..64].try_into().ok()?;
|
||||
|
||||
let backend = match bytes[64] {
|
||||
0 => BackendKind::FpgaPcie,
|
||||
1 => BackendKind::FpgaDaemon,
|
||||
2 => BackendKind::WasmSim,
|
||||
3 => BackendKind::NativeSim,
|
||||
_ => BackendKind::NativeSim,
|
||||
};
|
||||
|
||||
let cycles = u32::from_le_bytes(bytes[65..69].try_into().ok()?);
|
||||
let latency_ns = u32::from_le_bytes(bytes[69..73].try_into().ok()?);
|
||||
|
||||
let gate_decision = match bytes[73] {
|
||||
0 => GateDecision::RanFull,
|
||||
1 => GateDecision::EarlyExit { layer: bytes[74] },
|
||||
2 => GateDecision::Skipped {
|
||||
reason: match bytes[74] {
|
||||
0 => crate::types::SkipReason::LowCoherence,
|
||||
1 => crate::types::SkipReason::PolicyDenied,
|
||||
_ => crate::types::SkipReason::BudgetExceeded,
|
||||
},
|
||||
},
|
||||
_ => GateDecision::RanFull,
|
||||
};
|
||||
|
||||
Some(Self {
|
||||
model_hash,
|
||||
quant_hash,
|
||||
backend,
|
||||
cycles,
|
||||
latency_ns,
|
||||
gate_decision,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get latency in microseconds
|
||||
pub fn latency_us(&self) -> f64 {
|
||||
self.latency_ns as f64 / 1000.0
|
||||
}
|
||||
|
||||
/// Get latency in milliseconds
|
||||
pub fn latency_ms(&self) -> f64 {
|
||||
self.latency_ns as f64 / 1_000_000.0
|
||||
}
|
||||
|
||||
/// Check if this was a successful full inference
|
||||
pub fn is_full_inference(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::RanFull)
|
||||
}
|
||||
|
||||
/// Check if this was an early exit
|
||||
pub fn is_early_exit(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::EarlyExit { .. })
|
||||
}
|
||||
|
||||
/// Check if inference was skipped
|
||||
pub fn is_skipped(&self) -> bool {
|
||||
matches!(self.gate_decision, GateDecision::Skipped { .. })
|
||||
}
|
||||
}
|
||||
|
||||
/// Witness log aggregator for collecting statistics
|
||||
#[derive(Debug, Default)]
|
||||
pub struct WitnessAggregator {
|
||||
/// Total inferences
|
||||
pub count: u64,
|
||||
/// Total cycles
|
||||
pub total_cycles: u64,
|
||||
/// Total latency (ns)
|
||||
pub total_latency_ns: u64,
|
||||
/// Full inference count
|
||||
pub full_count: u64,
|
||||
/// Early exit count
|
||||
pub early_exit_count: u64,
|
||||
/// Skipped count
|
||||
pub skipped_count: u64,
|
||||
/// Sum of squares for variance calculation
|
||||
latency_sq_sum: u128,
|
||||
}
|
||||
|
||||
impl WitnessAggregator {
|
||||
/// Create a new aggregator
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Add a witness to the aggregate
|
||||
pub fn add(&mut self, witness: &WitnessLog) {
|
||||
self.count += 1;
|
||||
self.total_cycles += witness.cycles as u64;
|
||||
self.total_latency_ns += witness.latency_ns as u64;
|
||||
self.latency_sq_sum += (witness.latency_ns as u128).pow(2);
|
||||
|
||||
match witness.gate_decision {
|
||||
GateDecision::RanFull => self.full_count += 1,
|
||||
GateDecision::EarlyExit { .. } => self.early_exit_count += 1,
|
||||
GateDecision::Skipped { .. } => self.skipped_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average latency (ns)
|
||||
pub fn avg_latency_ns(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_latency_ns as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average cycles
|
||||
pub fn avg_cycles(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_cycles as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get latency standard deviation (ns)
|
||||
pub fn latency_std_ns(&self) -> f64 {
|
||||
if self.count <= 1 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = self.avg_latency_ns();
|
||||
let variance = (self.latency_sq_sum as f64 / self.count as f64) - (mean * mean);
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
/// Get early exit rate
|
||||
pub fn early_exit_rate(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.early_exit_count as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Get skip rate
|
||||
pub fn skip_rate(&self) -> f64 {
|
||||
if self.count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.skipped_count as f64 / self.count as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_witness_builder() {
|
||||
let witness = WitnessBuilder::new(BackendKind::NativeSim)
|
||||
.model_hash([1u8; 32])
|
||||
.quant_hash([2u8; 32])
|
||||
.cycles(1000)
|
||||
.gate_decision(GateDecision::RanFull)
|
||||
.build();
|
||||
|
||||
assert_eq!(witness.model_hash, [1u8; 32]);
|
||||
assert_eq!(witness.backend, BackendKind::NativeSim);
|
||||
assert_eq!(witness.cycles, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_bytes_roundtrip() {
|
||||
let witness = WitnessLog::new(
|
||||
[0x42u8; 32],
|
||||
[0x24u8; 32],
|
||||
BackendKind::FpgaDaemon,
|
||||
5000,
|
||||
100_000,
|
||||
GateDecision::EarlyExit { layer: 4 },
|
||||
);
|
||||
|
||||
let bytes = witness.to_bytes();
|
||||
let parsed = WitnessLog::from_bytes(&bytes).unwrap();
|
||||
|
||||
assert_eq!(witness.model_hash, parsed.model_hash);
|
||||
assert_eq!(witness.quant_hash, parsed.quant_hash);
|
||||
assert_eq!(witness.backend, parsed.backend);
|
||||
assert_eq!(witness.cycles, parsed.cycles);
|
||||
assert_eq!(witness.latency_ns, parsed.latency_ns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_aggregator() {
|
||||
let mut agg = WitnessAggregator::new();
|
||||
|
||||
for i in 0..10 {
|
||||
let mut witness = WitnessLog::empty();
|
||||
witness.latency_ns = 1000 * (i + 1);
|
||||
witness.cycles = 100 * (i + 1);
|
||||
if i < 3 {
|
||||
witness.gate_decision = GateDecision::EarlyExit { layer: 2 };
|
||||
}
|
||||
agg.add(&witness);
|
||||
}
|
||||
|
||||
assert_eq!(agg.count, 10);
|
||||
assert_eq!(agg.early_exit_count, 3);
|
||||
assert!((agg.early_exit_rate() - 0.3).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
12
crates/ruvector-fpga-transformer/src/witness/mod.rs
Normal file
12
crates/ruvector-fpga-transformer/src/witness/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Witness logging for audit trails and ReasoningBank integration
|
||||
//!
|
||||
//! Every inference produces a small witness bundle that records
|
||||
//! what happened and enables verification and replay.
|
||||
|
||||
pub mod hash;
|
||||
pub mod log;
|
||||
|
||||
// Re-export WitnessLog from types as the canonical location
|
||||
pub use crate::types::WitnessLog;
|
||||
pub use hash::{compute_witness_hash, verify_witness_hash};
|
||||
pub use log::{WitnessAggregator, WitnessBuilder};
|
||||
Reference in New Issue
Block a user