Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,94 @@
[package]
name = "ruvector-fpga-transformer"
version = "0.1.0"
edition = "2021"
rust-version = "1.77"
authors = ["RuVector Team"]
license = "MIT OR Apache-2.0"
description = "FPGA Transformer backend with deterministic latency, quantization-first design, and coherence gating"
repository = "https://github.com/ruvnet/ruvector"
keywords = ["fpga", "transformer", "inference", "quantization", "low-latency"]
categories = ["algorithms", "embedded", "hardware-support"]
readme = "README.md"

[lib]
crate-type = ["rlib"]

[features]
default = ["daemon", "native_sim", "witness"]
# Backend selection
# `daemon` must enable the optional tokio dependency: tokio is marked
# optional and commented "Async for daemon communication", but no feature
# previously enabled it (compare `pcie = ["memmap2"]`), so daemon builds
# could not resolve tokio.
daemon = ["tokio"]
native_sim = []
pcie = ["memmap2"]
# WASM support
wasm = ["wasm-bindgen", "getrandom/js", "js-sys"]
# Verification
strict_verify = []
witness = []
# Inference options
topk_only = []
lut_softmax = []
pwl_softmax = []
# Development
trace = []

[dependencies]
# Core
thiserror = "2.0"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Crypto for artifact verification
sha2 = "0.10"
ed25519-dalek = { version = "2.1", features = ["rand_core"] }
rand_core = { version = "0.6", features = ["getrandom"] }
rand = "0.8"
hex = "0.4"
serde_bytes = "0.11"
# Async for daemon communication
tokio = { version = "1.41", features = ["io-util", "net", "sync", "rt"], optional = true }
# Memory mapping for PCIe
memmap2 = { version = "0.9", optional = true }
# WASM bindings
wasm-bindgen = { version = "0.2", optional = true }
js-sys = { version = "0.3", optional = true }
getrandom = { version = "0.2", optional = true }
# Tracing for development
tracing = "0.1"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
proptest = "1.5"
rand = "0.8"
tokio = { version = "1.41", features = ["rt-multi-thread", "macros"] }

[[bench]]
name = "latency"
harness = false

[[bench]]
name = "correctness"
harness = false

[[bench]]
name = "gating"
harness = false

[[example]]
name = "basic_inference"
path = "examples/basic_inference.rs"

[[example]]
name = "daemon_client"
path = "examples/daemon_client.rs"
required-features = ["daemon"]

View File

@@ -0,0 +1,301 @@
# FPGA Transformer
**Run AI models on specialized hardware with predictable, ultra-low latency.**
FPGA Transformer is a Rust library that lets you run transformer neural networks (like those used in ChatGPT, code completion, and other AI applications) on FPGA hardware instead of GPUs. This gives you consistent, predictable response times - essential for real-time applications.
## Why Use This?
| Problem | Solution |
|---------|----------|
| GPU inference has unpredictable latency spikes | FPGAs provide deterministic, bounded timing |
| Cloud AI is too slow for edge devices | Run models locally on low-power FPGAs |
| Need to verify AI didn't hallucinate | Witness logging proves what computation ran |
| Want to skip unnecessary computation | Coherence gating exits early when confident |
## Quick Start
```bash
# Add to your Cargo.toml
cargo add ruvector-fpga-transformer
```
```rust
use ruvector_fpga_transformer::prelude::*;
use std::sync::Arc;
fn main() -> Result<()> {
// Create an engine (uses CPU simulator by default)
let mut engine = Engine::native_sim();
// Load your model
let model_bytes = std::fs::read("model.rvt")?;
let model_id = engine.load_artifact(&model_bytes)?;
// Prepare input tokens
let tokens: Vec<u16> = vec![1, 2, 3, 4, 5]; // Your tokenized input
let mask = vec![1u8; tokens.len()]; // Attention mask
// Run inference
let request = InferenceRequest::new(
model_id,
FixedShape::micro(), // 32 seq, 64 dim, 4096 vocab
&tokens,
&mask,
GateHint::allow_all(),
);
let result = engine.infer(request)?;
// Get predictions
println!("Top prediction: token {}", result.topk.unwrap()[0].0);
println!("Latency: {}ns", result.witness.latency_ns);
Ok(())
}
```
## Features
### Core Capabilities
| Feature | Description |
|---------|-------------|
| **Deterministic Latency** | Fixed execution time - no surprise slowdowns |
| **Quantization-First** | INT4/INT8 math for 4-8x memory savings |
| **Zero Allocation Hot Path** | No garbage collection pauses during inference |
| **Early Exit** | Stop computation when the model is already confident |
| **Witness Logging** | Cryptographic proof of what ran and when |
### Supported Backends
| Backend | Use Case | Feature Flag |
|---------|----------|--------------|
| **NativeSim** | Development & testing on any CPU | `native_sim` |
| **WasmSim** | Run in web browsers | `wasm` |
| **FpgaDaemon** | Connect to FPGA via network | `daemon` |
| **FpgaPcie** | Direct PCIe access (fastest) | `pcie` |
## Model Shapes
Pre-defined configurations for common use cases:
| Shape | Sequence | Dimensions | Vocab | Use Case |
|-------|----------|------------|-------|----------|
| `micro()` | 32 | 64 | 4,096 | Testing, tiny models |
| `small()` | 128 | 256 | 32,768 | Edge devices |
| `medium()` | 512 | 512 | 50,257 | Standard inference |
| `large()` | 2,048 | 1,024 | 50,257 | High-quality output |
```rust
// Use predefined shapes
let shape = FixedShape::small();
// Or create custom
let custom = FixedShape {
seq_len: 256,
d_model: 384,
vocab: 16000,
};
```
## Coherence Gating
Skip unnecessary computation when the model is already confident:
```rust
use ruvector_fpga_transformer::gating::{GatingConfig, PolicyGate};
// Configure early exit behavior
let config = GatingConfig {
min_coherence: 0.7, // Require 70% confidence to exit early
max_compute_class: 3, // Allow up to 3 layers before forcing exit
allow_writes: true, // Allow writes if confidence is high
..Default::default()
};
let gate = PolicyGate::new(config);
```
**Gate Decisions:**
- `RanFull` - Model ran all layers
- `EarlyExit { layer }` - Exited early at specified layer
- `Skipped { reason }` - Computation was blocked
## Quantization
Convert floating-point models to efficient integer math:
| Format | Bits | Memory Savings | Use Case |
|--------|------|----------------|----------|
| INT8 | 8 | 4x | General purpose |
| INT4 | 4 | 8x | Memory-constrained |
| Binary | 1 | 32x | Ultra-compact |
```rust
// INT8 quantization (recommended)
let quant = QuantSpec::int8();
// INT4 for memory savings
let quant = QuantSpec::int4();
// Custom quantization
let quant = QuantSpec {
bits: 8,
scale: 127.0,
zero_point: 0,
symmetric: true,
};
```
## Witness Logging
Every inference produces a cryptographic witness proving:
- Which model ran (by hash)
- What quantization was used
- Which backend executed it
- Exact cycle count and latency
- Gate decision made
```rust
let result = engine.infer(request)?;
let witness = &result.witness;
println!("Model hash: {}", hex::encode(&witness.model_hash));
println!("Backend: {:?}", witness.backend);
println!("Cycles: {}", witness.cycles);
println!("Decision: {:?}", witness.gate_decision);
// Verify witness authenticity
assert!(witness.verify());
```
## Backend Selection
### Native Simulator (Default)
Best for development and testing:
```rust
let engine = Engine::native_sim();
```
### FPGA Daemon
Connect to a remote FPGA over network:
```rust
use ruvector_fpga_transformer::backend::fpga_daemon::{FpgaDaemonBackend, DaemonConnection};
let backend = FpgaDaemonBackend::with_connection(
DaemonConnection::tcp("192.168.1.100:9000"),
Default::default(),
);
```
### FPGA PCIe (Fastest)
Direct hardware access:
```rust
use ruvector_fpga_transformer::backend::fpga_pcie::{FpgaPcieBackend, PcieConfig};
let config = PcieConfig {
device_path: "/dev/ruvector0".into(),
ring_slots: 16,
dma_timeout_ms: 100,
..Default::default()
};
let backend = FpgaPcieBackend::new(config)?;
```
## Feature Flags
Enable only what you need:
```toml
[dependencies]
ruvector-fpga-transformer = { version = "0.1", default-features = false, features = ["native_sim"] }
```
| Flag | Description |
|------|-------------|
| `native_sim` | CPU-based simulator |
| `daemon` | Network daemon client |
| `pcie` | Direct PCIe access |
| `wasm` | WebAssembly support |
| `witness` | Witness logging |
| `strict_verify` | Extra verification checks |
| `topk_only` | Return only top-K predictions |
| `lut_softmax` | LUT-based softmax (faster) |
| `pwl_softmax` | Piecewise-linear softmax approximation |
| `trace` | Debug tracing |
## Performance Tips
1. **Use appropriate shapes** - Don't use `large()` for simple tasks
2. **Enable early exit** - Set reasonable `min_coherence` threshold
3. **Batch requests** - Reuse loaded models across multiple inferences
4. **Use topk_only** - Return only top predictions, not full vocabulary
```rust
// Efficient configuration
let config = DaemonConfig {
topk_only: true, // Only return top-K predictions
topk: 10, // Return top 10
retries: 3, // Retry on transient failures
..Default::default()
};
```
## Architecture
```
┌─────────────┐
│ Engine │
│ (public) │
└──────┬──────┘
┌────────────┼────────────┐
│ │ │
┌─────▼─────┐ ┌────▼────┐ ┌─────▼─────┐
│ Coherence │ │ Backend │ │ Witness │
│ Gate │ │ Trait │ │ Log │
└───────────┘ └────┬────┘ └───────────┘
┌────────┬────────┼────────┬────────┐
│ │ │ │ │
┌───▼───┐┌───▼───┐┌───▼───┐┌───▼───┐
│Native ││ WASM ││Daemon ││ PCIe │
│ Sim ││ Sim ││ ││ │
└───────┘└───────┘└───────┘└───────┘
```
## Examples
See the [examples](./examples/) directory:
- `basic_inference.rs` - Simple inference example
- `daemon_client.rs` - Connect to FPGA daemon
Run examples:
```bash
cargo run --example basic_inference
cargo run --example daemon_client --features daemon
```
## Testing
```bash
# Run all tests
cargo test --features native_sim
# Run with tracing
RUST_LOG=debug cargo test --features "native_sim trace"
# Run benchmarks
cargo bench --features native_sim
```
## License
MIT OR Apache-2.0
## Contributing
Contributions welcome! Please read the [contributing guidelines](../../CONTRIBUTING.md) first.

View File

@@ -0,0 +1,157 @@
//! Correctness and determinism benchmarks
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use std::sync::Arc;
use ruvector_fpga_transformer::{
artifact::{Manifest, ModelArtifact},
backend::native_sim::NativeSimBackend,
backend::TransformerBackend,
gating::DefaultCoherenceGate,
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
};
/// Build a micro-shaped INT8 artifact whose weight bytes follow the
/// deterministic pattern `i % 256`, so every run sees identical weights.
fn create_test_artifact() -> ModelArtifact {
    let shape = FixedShape::micro();
    let manifest = Manifest {
        name: "determinism_test".into(),
        model_hash: String::new(),
        shape,
        quant: QuantSpec::int8(),
        io: Default::default(),
        backend: Default::default(),
        tests: Default::default(),
    };
    // One embedding table: vocab rows of d_model bytes each.
    let n_weights = shape.vocab as usize * shape.d_model as usize;
    let mut weights = Vec::with_capacity(n_weights);
    for i in 0..n_weights {
        weights.push((i % 256) as u8);
    }
    ModelArtifact::new(manifest, weights, None, None, vec![])
}
/// Verify bit-exact determinism: 1000 identical inferences must produce the
/// same logits hash. The whole 1000-inference batch runs inside `b.iter`, so
/// the reported time covers the full batch, not a single inference.
fn bench_determinism(c: &mut Criterion) {
    let gate = Arc::new(DefaultCoherenceGate::new());
    let backend = NativeSimBackend::new(gate);
    let artifact = create_test_artifact();
    let model_id = backend.load(&artifact).unwrap();
    let shape = FixedShape::micro();
    // Deterministic token pattern kept inside the vocab range via modulo.
    let tokens: Vec<u16> = (0..shape.seq_len)
        .map(|i| (i * 7) % shape.vocab as u16)
        .collect();
    let mask = vec![1u8; shape.seq_len as usize];
    c.bench_function("determinism_check_1000", |b| {
        b.iter(|| {
            let mut first_hash: Option<u64> = None;
            for _ in 0..1000 {
                let req = InferenceRequest::new(
                    model_id,
                    shape,
                    black_box(&tokens),
                    &mask,
                    GateHint::allow_all(),
                );
                let result = backend.infer(req).unwrap();
                // Hash the logits (order-sensitive polynomial hash, base 31,
                // with wrapping arithmetic so overflow is well-defined).
                let hash = result
                    .logits_q
                    .iter()
                    .fold(0u64, |acc, &v| acc.wrapping_mul(31).wrapping_add(v as u64));
                match first_hash {
                    // First run fixes the reference; all later runs must match.
                    None => first_hash = Some(hash),
                    Some(expected) => assert_eq!(hash, expected, "Non-deterministic output"),
                }
            }
        })
    });
}
/// Golden-vector regression check: precompute 128 reference outputs outside
/// the timed loop, then re-run them and require an exact match (max abs
/// error of zero) on every iteration.
fn bench_golden_vectors(c: &mut Criterion) {
    let gate = Arc::new(DefaultCoherenceGate::new());
    let backend = NativeSimBackend::new(gate);
    let artifact = create_test_artifact();
    let model_id = backend.load(&artifact).unwrap();
    let shape = FixedShape::micro();
    // Create golden vectors: 128 deterministic token sequences, seeded 0..128.
    let test_inputs: Vec<Vec<u16>> = (0..128)
        .map(|seed| {
            (0..shape.seq_len)
                .map(|i| ((i as usize * seed + 1) % shape.vocab as usize) as u16)
                .collect()
        })
        .collect();
    let mask = vec![1u8; shape.seq_len as usize];
    // Compute expected outputs once, before benchmarking starts.
    let expected: Vec<Vec<i16>> = test_inputs
        .iter()
        .map(|tokens| {
            let req = InferenceRequest::new(model_id, shape, tokens, &mask, GateHint::allow_all());
            backend.infer(req).unwrap().logits_q
        })
        .collect();
    c.bench_function("golden_vector_validation", |b| {
        b.iter(|| {
            for (tokens, exp) in test_inputs.iter().zip(&expected) {
                let req = InferenceRequest::new(
                    model_id,
                    shape,
                    black_box(tokens),
                    &mask,
                    GateHint::allow_all(),
                );
                let result = backend.infer(req).unwrap();
                // Compute max abs error against the golden output.
                let max_err: i32 = result
                    .logits_q
                    .iter()
                    .zip(exp)
                    .map(|(&a, &b)| (a as i32 - b as i32).abs())
                    .max()
                    .unwrap_or(0);
                // Exact reproducibility is required, not just a tolerance.
                assert_eq!(max_err, 0, "Golden vector mismatch");
            }
        })
    });
}
/// Benchmark a 256x256 f32 -> INT8 quantize/dequantize round trip while
/// asserting the reconstruction error stays below 0.1.
fn bench_quantization_accuracy(c: &mut Criterion) {
    // Only `QuantizedMatrix` is exercised here; the previously imported
    // `quantize_symmetric_i8` was unused and has been dropped.
    use ruvector_fpga_transformer::quant::qformat::QuantizedMatrix;
    c.bench_function("quantize_matrix_256x256", |b| {
        // Smooth synthetic data in [-1, 1] via sin(), so symmetric
        // quantization has a well-conditioned range.
        let data: Vec<f32> = (0..256 * 256).map(|i| (i as f32 * 0.001).sin()).collect();
        b.iter(|| {
            let matrix = QuantizedMatrix::from_f32(black_box(&data), 256, 256);
            let dequant = matrix.to_f32();
            // Check reconstruction error; guards against quantization
            // regressions while the round trip is being timed.
            let max_err: f32 = data
                .iter()
                .zip(&dequant)
                .map(|(a, b)| (a - b).abs())
                .fold(0.0f32, f32::max);
            assert!(max_err < 0.1, "Quantization error too high: {}", max_err);
        })
    });
}
// Register all correctness benches with criterion's default configuration.
criterion_group!(
    benches,
    bench_determinism,
    bench_golden_vectors,
    bench_quantization_accuracy
);
criterion_main!(benches);

View File

@@ -0,0 +1,154 @@
//! Gating subsystem benchmarks
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use std::sync::Arc;
use ruvector_fpga_transformer::{
artifact::{Manifest, ModelArtifact},
backend::native_sim::NativeSimBackend,
backend::TransformerBackend,
gating::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate},
types::{ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest, QuantSpec},
};
/// Measure preflight cost over a uniform sweep of coherence signals,
/// tallying how many hints the gate skips versus lets run.
fn bench_skip_rate_distribution(c: &mut Criterion) {
    let gate = DefaultCoherenceGate::new();
    // Uniform synthetic coherence distribution over [-500, 500).
    let coherence_values: Vec<i16> = (-500..500).collect();
    c.bench_function("skip_rate_uniform_distribution", |b| {
        b.iter(|| {
            // Fold the sweep into a (skipped, ran) tally; returning it keeps
            // the work observable to criterion.
            coherence_values
                .iter()
                .fold((0u32, 0u32), |(skipped, ran), &coherence| {
                    let hint = GateHint::new(coherence, false, ComputeClass::Deliberative);
                    if matches!(gate.preflight(black_box(&hint)), GateDecision::Skipped { .. }) {
                        (skipped + 1, ran)
                    } else {
                        (skipped, ran + 1)
                    }
                })
        })
    });
}
/// Sweep coherence hints from -500 to 2000 and benchmark full inference at
/// each level, surfacing how the gate's early-exit decision varies with the
/// supplied coherence signal.
fn bench_early_exit_histogram(c: &mut Criterion) {
    let gate = Arc::new(DefaultCoherenceGate::new());
    let backend = NativeSimBackend::new(gate);
    let shape = FixedShape::micro();
    let manifest = Manifest {
        name: "early_exit_test".into(),
        model_hash: String::new(),
        shape,
        quant: QuantSpec::int8(),
        io: Default::default(),
        backend: Default::default(),
        tests: Default::default(),
    };
    // All-zero embedding weights sized vocab * d_model.
    let embedding_size = shape.vocab as usize * shape.d_model as usize;
    let artifact = ModelArtifact::new(manifest, vec![0u8; embedding_size], None, None, vec![]);
    let model_id = backend.load(&artifact).unwrap();
    let tokens: Vec<u16> = (0..shape.seq_len).collect();
    let mask = vec![1u8; shape.seq_len as usize];
    // Test with varying coherence levels
    let coherence_levels: Vec<i16> = vec![-500, -200, 0, 200, 500, 1000, 2000];
    let mut group = c.benchmark_group("early_exit_by_coherence");
    for coherence in coherence_levels {
        group.bench_with_input(
            BenchmarkId::new("coherence", coherence),
            &coherence,
            |b, &coherence| {
                let hint = GateHint::new(coherence, false, ComputeClass::Deliberative);
                b.iter(|| {
                    let req =
                        InferenceRequest::new(model_id, shape, black_box(&tokens), &mask, hint);
                    let result = backend.infer(req).unwrap();
                    // Return the decision so criterion can't optimize it away.
                    result.witness.gate_decision
                })
            },
        );
    }
    group.finish();
}
/// Time the per-layer `checkpoint` call under three gate configurations,
/// simulating a coherence signal that rises by 150 per layer across 8 layers.
fn bench_checkpoint_overhead(c: &mut Criterion) {
    let configs = [
        ("default", CoherenceConfig::default()),
        ("strict", CoherenceConfig::strict()),
        ("permissive", CoherenceConfig::permissive()),
    ];
    let mut group = c.benchmark_group("checkpoint_overhead");
    for (name, config) in configs {
        let gate = DefaultCoherenceGate::with_config(config);
        group.bench_with_input(BenchmarkId::new("config", name), &gate, |b, gate| {
            b.iter(|| {
                let mut decision = None;
                // Stop at the first layer where the gate returns a decision.
                for layer in 0u8..8 {
                    let signal = (layer as i16) * 150;
                    if let Some(d) = gate.checkpoint(black_box(layer), black_box(signal)) {
                        decision = Some(d);
                        break;
                    }
                }
                // Returned so the loop's result stays observable to criterion.
                decision
            })
        });
    }
    group.finish();
}
/// Benchmark the mincut-based gate's preflight path for three representative
/// hint profiles (high/low lambda and a crossed boundary).
fn bench_mincut_gating(c: &mut Criterion) {
    use ruvector_fpga_transformer::gating::coherence_gate::MincutCoherenceGate;
    let config = CoherenceConfig::default();
    // NOTE(review): the `50` / `200` constructor arguments are assumed to be
    // mincut thresholds — confirm against `MincutCoherenceGate::new`'s docs.
    let gate = MincutCoherenceGate::new(config, 50, 200);
    let hints = [
        (
            "high_lambda",
            GateHint::new(500, false, ComputeClass::Deliberative),
        ),
        (
            "low_lambda",
            GateHint::new(100, false, ComputeClass::Deliberative),
        ),
        (
            // Same coherence band but with the boundary-crossed flag set.
            "boundary_crossed",
            GateHint::new(300, true, ComputeClass::Deliberative),
        ),
    ];
    let mut group = c.benchmark_group("mincut_gating");
    for (name, hint) in hints {
        group.bench_with_input(BenchmarkId::new("preflight", name), &hint, |b, hint| {
            b.iter(|| gate.preflight(black_box(hint)))
        });
    }
    group.finish();
}
// Register all gating benches with criterion's default configuration.
criterion_group!(
    benches,
    bench_skip_rate_distribution,
    bench_early_exit_histogram,
    bench_checkpoint_overhead,
    bench_mincut_gating
);
criterion_main!(benches);

View File

@@ -0,0 +1,143 @@
//! Latency benchmarks for FPGA Transformer
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use std::sync::Arc;
use ruvector_fpga_transformer::{
artifact::{Manifest, ModelArtifact},
backend::native_sim::NativeSimBackend,
backend::TransformerBackend,
gating::DefaultCoherenceGate,
types::{FixedShape, GateHint, InferenceRequest, ModelId, QuantSpec},
};
/// Build a minimal INT8 artifact (all-zero weights) for the given shape.
fn create_test_artifact(shape: FixedShape) -> ModelArtifact {
    let manifest = Manifest {
        name: "bench_model".into(),
        model_hash: String::new(),
        shape,
        quant: QuantSpec::int8(),
        io: Default::default(),
        backend: Default::default(),
        tests: Default::default(),
    };
    // Create minimal weights: one embedding table of vocab * d_model bytes.
    let embedding_size = shape.vocab as usize * shape.d_model as usize;
    let weights = vec![0u8; embedding_size];
    ModelArtifact::new(manifest, weights, None, None, vec![])
}
/// Single-model, micro-shape inference latency on the native simulator.
fn bench_inference(c: &mut Criterion) {
    let coherence_gate = Arc::new(DefaultCoherenceGate::new());
    let sim = NativeSimBackend::new(coherence_gate);
    let shape = FixedShape::micro();
    let model_id = sim.load(&create_test_artifact(shape)).unwrap();
    // Sequential token ids paired with a fully-enabled attention mask.
    let input_tokens: Vec<u16> = (0..shape.seq_len).collect();
    let attn_mask: Vec<u8> = std::iter::repeat(1u8)
        .take(shape.seq_len as usize)
        .collect();
    c.bench_function("native_sim_micro_inference", |b| {
        b.iter(|| {
            // Request construction is included in the measurement, matching
            // what a caller pays per inference.
            let request = InferenceRequest::new(
                model_id,
                shape,
                black_box(&input_tokens),
                &attn_mask,
                GateHint::allow_all(),
            );
            sim.infer(request).unwrap()
        })
    });
}
/// Inference latency across the predefined shapes (micro/small/baseline),
/// with a freshly-loaded backend per shape so models don't accumulate.
fn bench_inference_shapes(c: &mut Criterion) {
    let gate = Arc::new(DefaultCoherenceGate::new());
    let shapes = [
        ("micro", FixedShape::micro()),
        ("small", FixedShape::small()),
        ("baseline", FixedShape::baseline()),
    ];
    let mut group = c.benchmark_group("inference_by_shape");
    for (name, shape) in shapes {
        let backend = NativeSimBackend::new(gate.clone());
        let artifact = create_test_artifact(shape);
        let model_id = backend.load(&artifact).unwrap();
        let tokens: Vec<u16> = (0..shape.seq_len).collect();
        let mask = vec![1u8; shape.seq_len as usize];
        group.bench_with_input(BenchmarkId::new("native_sim", name), &shape, |b, &shape| {
            b.iter(|| {
                let req = InferenceRequest::new(
                    model_id,
                    shape,
                    black_box(&tokens),
                    &mask,
                    GateHint::allow_all(),
                );
                backend.infer(req).unwrap()
            })
        });
    }
    group.finish();
}
/// Time a full load/unload round trip for a micro-shaped artifact.
fn bench_load_unload(c: &mut Criterion) {
    let coherence_gate = Arc::new(DefaultCoherenceGate::new());
    let sim = NativeSimBackend::new(coherence_gate);
    let micro_artifact = create_test_artifact(FixedShape::micro());
    c.bench_function("model_load", |b| {
        b.iter(|| {
            // Pair each load with an unload so backend state never
            // accumulates across iterations.
            let handle = sim.load(black_box(&micro_artifact)).unwrap();
            sim.unload(handle).unwrap();
        })
    });
}
/// Preflight-only gating cost for three representative hints; no inference
/// is executed, so this isolates the gate itself.
fn bench_gating(c: &mut Criterion) {
    use ruvector_fpga_transformer::gating::{CoherenceConfig, CoherenceGate};
    let gate = DefaultCoherenceGate::with_config(CoherenceConfig::default());
    let hints = [
        ("allow_all", GateHint::allow_all()),
        ("reflex_only", GateHint::reflex_only()),
        (
            "low_coherence",
            GateHint::new(
                -500,
                true,
                ruvector_fpga_transformer::types::ComputeClass::Deliberative,
            ),
        ),
    ];
    let mut group = c.benchmark_group("gating_preflight");
    for (name, hint) in hints {
        group.bench_with_input(BenchmarkId::new("preflight", name), &hint, |b, hint| {
            b.iter(|| gate.preflight(black_box(hint)))
        });
    }
    group.finish();
}
// Register all latency benches with criterion's default configuration.
criterion_group!(
    benches,
    bench_inference,
    bench_inference_shapes,
    bench_load_unload,
    bench_gating
);
criterion_main!(benches);

View File

@@ -0,0 +1,123 @@
//! Basic inference example
//!
//! Demonstrates loading a model and running inference with the native simulator.
use std::sync::Arc;
use ruvector_fpga_transformer::{
artifact::{Manifest, ModelArtifact},
backend::native_sim::NativeSimBackend,
gating::DefaultCoherenceGate,
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
Engine,
};
/// Demo entry point: build a synthetic micro model, run inference under three
/// gate hints, and print per-run results plus aggregate engine statistics.
fn main() -> anyhow::Result<()> {
    println!("FPGA Transformer - Basic Inference Example");
    println!("==========================================\n");
    // Create a micro-sized model for demonstration
    let shape = FixedShape::micro();
    println!("Model shape: {:?}", shape);
    // Create manifest
    let manifest = Manifest {
        name: "demo_reflex_transformer".into(),
        model_hash: String::new(),
        shape,
        quant: QuantSpec::int8(),
        io: Default::default(),
        backend: Default::default(),
        tests: Default::default(),
    };
    // Create minimal weights (random for demo)
    let embedding_size = shape.vocab as usize * shape.d_model as usize;
    let weights: Vec<u8> = (0..embedding_size)
        .map(|i| ((i * 7 + 13) % 256) as u8)
        .collect();
    println!("Weight size: {} bytes", weights.len());
    // Create artifact
    let artifact = ModelArtifact::new(manifest, weights, None, None, vec![]);
    println!("Artifact created, model ID: {}", artifact.model_id());
    // Create backend and engine
    let gate = Arc::new(DefaultCoherenceGate::new());
    let backend = Box::new(NativeSimBackend::new(gate.clone()));
    let mut engine = Engine::new(backend, gate);
    // Load model
    let model_id = engine.load(&artifact)?;
    println!("Model loaded successfully\n");
    // Prepare input
    let tokens: Vec<u16> = (0..shape.seq_len).collect();
    let mask = vec![1u8; shape.seq_len as usize];
    println!("Running inference...");
    println!(" Input tokens: {:?}...", &tokens[..4.min(tokens.len())]);
    // Run inference with different coherence levels.
    // NOTE(review): GateHint::new's second flag appears to be a
    // "boundary crossed" indicator (named so in the gating benches) — confirm.
    let coherence_levels = [
        (
            "High coherence",
            GateHint::new(
                500,
                false,
                ruvector_fpga_transformer::ComputeClass::Deliberative,
            ),
        ),
        (
            "Medium coherence",
            GateHint::new(
                100,
                false,
                ruvector_fpga_transformer::ComputeClass::Associative,
            ),
        ),
        (
            "Low coherence",
            GateHint::new(-100, true, ruvector_fpga_transformer::ComputeClass::Reflex),
        ),
    ];
    for (name, hint) in coherence_levels {
        let req = InferenceRequest::new(model_id, shape, &tokens, &mask, hint);
        match engine.infer(req) {
            Ok(result) => {
                println!("\n{}", name);
                println!(" Gate decision: {:?}", result.witness.gate_decision);
                println!(
                    " Latency: {:.2}ms",
                    result.witness.latency_ns as f64 / 1_000_000.0
                );
                if let Some(topk) = &result.topk {
                    println!(" Top-3 predictions:");
                    for (i, (token, score)) in topk.iter().take(3).enumerate() {
                        println!(" {}. Token {} (score: {})", i + 1, token, score);
                    }
                }
            }
            Err(e) => {
                // A gated/blocked inference surfaces here as an error.
                println!("\n{}: Skipped - {:?}", name, e);
            }
        }
    }
    // Print statistics
    println!("\n==========================================");
    println!("Engine Statistics:");
    let stats = engine.stats();
    println!(" Total inferences: {}", stats.total_inferences);
    println!(" Successful: {}", stats.successful);
    println!(" Skipped: {}", stats.skipped);
    println!(" Early exits: {}", stats.early_exits);
    println!(" Success rate: {:.1}%", stats.success_rate() * 100.0);
    println!(" Avg latency: {:.2}ms", stats.avg_latency_ms());
    Ok(())
}

View File

@@ -0,0 +1,114 @@
//! FPGA Daemon client example
//!
//! Demonstrates connecting to an FPGA daemon and running inference.
//! This example requires the `daemon` feature and a running FPGA daemon.
use std::sync::Arc;
use ruvector_fpga_transformer::{
artifact::{Manifest, ModelArtifact},
backend::fpga_daemon::{DaemonConfig, DaemonConnection, FpgaDaemonBackend},
gating::DefaultCoherenceGate,
types::{FixedShape, GateHint, InferenceRequest, QuantSpec},
Engine,
};
/// Demo entry point: connect to an FPGA daemon over a Unix socket, load a
/// synthetic INT4 model, run one inference, and print the witness details.
fn main() -> anyhow::Result<()> {
    println!("FPGA Transformer - Daemon Client Example");
    println!("=========================================\n");
    // Configure daemon connection; socket path is overridable via env var.
    let socket_path = std::env::var("RUVECTOR_FPGA_SOCKET")
        .unwrap_or_else(|_| "/var/run/ruvector_fpga.sock".into());
    println!("Connecting to daemon at: {}", socket_path);
    let connection = DaemonConnection::unix(&socket_path);
    let config = DaemonConfig {
        connect_timeout_ms: 5000,
        read_timeout_ms: 10000,
        write_timeout_ms: 5000,
        retries: 3,
        backoff_multiplier: 2.0,
        // Ask the daemon for top-K only to keep responses small.
        topk_only: true,
        topk: 16,
    };
    // Create backend
    let backend = Box::new(FpgaDaemonBackend::with_connection(connection, config));
    // Create gate and engine
    let gate = Arc::new(DefaultCoherenceGate::new());
    let mut engine = Engine::new(backend, gate);
    // Create a test model
    let shape = FixedShape::micro();
    let manifest = Manifest {
        name: "fpga_test_model".into(),
        model_hash: String::new(),
        shape,
        quant: QuantSpec::int4_int8(),
        io: Default::default(),
        backend: Default::default(),
        tests: Default::default(),
    };
    let embedding_size = shape.vocab as usize * shape.d_model as usize / 2; // INT4 packed
    let weights: Vec<u8> = (0..embedding_size)
        .map(|i| ((i * 11 + 7) % 256) as u8)
        .collect();
    let artifact = ModelArtifact::new(manifest, weights, None, None, vec![]);
    // Try to load the model
    println!("Loading model...");
    match engine.load(&artifact) {
        Ok(model_id) => {
            println!("Model loaded: {}", model_id);
            // Prepare input
            let tokens: Vec<u16> = (0..shape.seq_len).map(|i| i * 2).collect();
            let mask = vec![1u8; shape.seq_len as usize];
            // Run inference
            println!("\nRunning FPGA inference...");
            let req = InferenceRequest::new(model_id, shape, &tokens, &mask, GateHint::allow_all());
            match engine.infer(req) {
                Ok(result) => {
                    println!("Inference successful!");
                    println!(" Backend: {:?}", result.witness.backend);
                    println!(" Cycles: {}", result.witness.cycles);
                    println!(
                        " Latency: {}ns ({:.3}ms)",
                        result.witness.latency_ns,
                        result.witness.latency_ns as f64 / 1_000_000.0
                    );
                    println!(" Gate decision: {:?}", result.witness.gate_decision);
                    if let Some(topk) = &result.topk {
                        println!("\n Top-5 predictions:");
                        for (i, (token, score)) in topk.iter().take(5).enumerate() {
                            println!(" {}. Token {} (score: {})", i + 1, token, score);
                        }
                    }
                }
                Err(e) => {
                    // Inference failure is reported but not fatal; the model
                    // is still unloaded below.
                    eprintln!("Inference failed: {}", e);
                }
            }
            // Unload model
            engine.unload(model_id)?;
            println!("\nModel unloaded");
        }
        Err(e) => {
            // Load failure usually means no daemon is listening on the socket.
            eprintln!("Failed to load model: {}", e);
            eprintln!("\nMake sure the FPGA daemon is running:");
            eprintln!(" ruvector-fpga-daemon --socket {}", socket_path);
            return Err(e.into());
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,266 @@
//! Manifest schema for model artifacts
use crate::error::{Error, Result};
use crate::types::{FixedShape, Layout, QuantSpec};
use serde::{Deserialize, Serialize};
/// Model manifest containing all metadata
///
/// Serialized via serde; see `to_json`/`from_json` on the impl for the
/// JSON form used in packed artifacts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
    /// Model name
    pub name: String,
    /// SHA-256 hash of model (hex string)
    pub model_hash: String,
    /// Fixed shape specification
    pub shape: FixedShape,
    /// Quantization specification
    pub quant: QuantSpec,
    /// I/O configuration
    pub io: IoSpec,
    /// Backend configuration
    pub backend: BackendSpec,
    /// Test vector specification
    pub tests: TestSpec,
}
impl Manifest {
/// Create a new manifest
pub fn new(name: impl Into<String>, shape: FixedShape, quant: QuantSpec) -> Self {
Self {
name: name.into(),
model_hash: String::new(),
shape,
quant,
io: IoSpec::default(),
backend: BackendSpec::default(),
tests: TestSpec::default(),
}
}
/// Validate manifest consistency
pub fn validate(&self) -> Result<()> {
if self.name.is_empty() {
return Err(Error::InvalidArtifact("Model name is empty".into()));
}
// Validate shape
self.shape
.validate()
.map_err(|e| Error::InvalidArtifact(e))?;
// Validate quantization bits
if !matches!(self.quant.w_bits, 1 | 2 | 4 | 8 | 16) {
return Err(Error::InvalidArtifact(format!(
"Invalid weight bits: {}",
self.quant.w_bits
)));
}
if !matches!(self.quant.a_bits, 4 | 8 | 16 | 32) {
return Err(Error::InvalidArtifact(format!(
"Invalid activation bits: {}",
self.quant.a_bits
)));
}
Ok(())
}
/// Convert to JSON string
pub fn to_json(&self) -> Result<String> {
serde_json::to_string_pretty(self).map_err(Into::into)
}
/// Parse from JSON string
pub fn from_json(json: &str) -> Result<Self> {
serde_json::from_str(json).map_err(Into::into)
}
}
/// I/O type specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IoSpec {
    /// Input token type (typically "u16")
    pub tokens: String,
    /// Output logit type (typically "i16" or "i32")
    pub logits: String,
    /// Top-K count (0 for full logits)
    pub topk: u16,
}
impl Default for IoSpec {
    /// Defaults: u16 tokens, i16 logits, top-16 output.
    fn default() -> Self {
        Self {
            tokens: "u16".into(),
            logits: "i16".into(),
            topk: 16,
        }
    }
}
/// Backend-specific configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendSpec {
    /// Backend kind ("fpga_pcie", "fpga_daemon", "native_sim", "wasm_sim")
    pub kind: String,
    /// Protocol version
    pub protocol: u16,
    /// Backend-specific options (optional in JSON via serde(default))
    #[serde(default)]
    pub options: BackendOptions,
}
impl Default for BackendSpec {
    /// Defaults to the CPU simulator backend speaking protocol v1.
    fn default() -> Self {
        Self {
            kind: "native_sim".into(),
            protocol: 1,
            options: BackendOptions::default(),
        }
    }
}
/// Backend-specific options
///
/// Every field carries `#[serde(default)]`, so each may be omitted from the
/// serialized manifest and falls back to the type's zero/false default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendOptions {
    /// Enable batch processing
    #[serde(default)]
    pub batch_enabled: bool,
    /// Maximum batch size
    #[serde(default)]
    pub max_batch: u16,
    /// Enable early exit
    #[serde(default)]
    pub early_exit: bool,
    /// Minimum coherence threshold for early exit
    #[serde(default)]
    pub early_exit_threshold: i16,
    /// FPGA clock frequency in MHz (for cycle estimation)
    #[serde(default)]
    pub clock_mhz: u16,
}
/// Test vector specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestSpec {
    /// Number of test vectors
    pub vectors: u32,
    /// Maximum absolute error allowed
    pub max_abs_err: i32,
    /// Whether test vectors must pass before activation
    #[serde(default = "default_true")]
    pub require_pass: bool,
}
// serde's `default = "..."` attribute requires a function path; this
// helper supplies the literal `true`.
fn default_true() -> bool {
    true
}
impl Default for TestSpec {
    /// No vectors, max abs error of 2, passing required (matching the
    /// serde default for `require_pass`).
    fn default() -> Self {
        Self {
            vectors: 0,
            max_abs_err: 2,
            require_pass: true,
        }
    }
}
/// Manifest builder for convenient construction
pub struct ManifestBuilder {
    // The manifest being assembled; finalized and validated by `build`.
    manifest: Manifest,
}
impl ManifestBuilder {
    /// Create a new builder with name and shape (INT8 quantization initially)
    pub fn new(name: impl Into<String>, shape: FixedShape) -> Self {
        Self {
            manifest: Manifest::new(name, shape, QuantSpec::int8()),
        }
    }
    /// Set quantization spec
    pub fn quant(mut self, quant: QuantSpec) -> Self {
        self.manifest.quant = quant;
        self
    }
    /// Set model hash
    pub fn model_hash(mut self, hash: impl Into<String>) -> Self {
        self.manifest.model_hash = hash.into();
        self
    }
    /// Set I/O spec
    pub fn io(mut self, io: IoSpec) -> Self {
        self.manifest.io = io;
        self
    }
    /// Set backend spec
    pub fn backend(mut self, backend: BackendSpec) -> Self {
        self.manifest.backend = backend;
        self
    }
    /// Set test spec
    pub fn tests(mut self, tests: TestSpec) -> Self {
        self.manifest.tests = tests;
        self
    }
    /// Enable top-K only output
    pub fn topk_only(mut self, k: u16) -> Self {
        self.manifest.io.topk = k;
        self
    }
    /// Enable early exit with the given coherence threshold
    pub fn early_exit(mut self, threshold: i16) -> Self {
        self.manifest.backend.options.early_exit = true;
        self.manifest.backend.options.early_exit_threshold = threshold;
        self
    }
    /// Build the manifest
    ///
    /// # Errors
    /// Propagates any inconsistency reported by `Manifest::validate`.
    pub fn build(self) -> Result<Manifest> {
        self.manifest.validate()?;
        Ok(self.manifest)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builder must thread quant, top-K, and early-exit options through.
    #[test]
    fn test_manifest_builder() {
        let manifest = ManifestBuilder::new("test", FixedShape::micro())
            .quant(QuantSpec::int4_int8())
            .topk_only(16)
            .early_exit(100)
            .build()
            .unwrap();
        assert_eq!(manifest.name, "test");
        assert_eq!(manifest.quant.w_bits, 4);
        assert_eq!(manifest.io.topk, 16);
        assert!(manifest.backend.options.early_exit);
    }
    // JSON serialization round trip (name spot-checked).
    #[test]
    fn test_manifest_json_roundtrip() {
        let manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        let json = manifest.to_json().unwrap();
        let parsed = Manifest::from_json(&json).unwrap();
        assert_eq!(manifest.name, parsed.name);
    }
    // An empty model name must fail validation; the fresh default passes.
    #[test]
    fn test_manifest_validation() {
        let mut manifest = Manifest::new("test", FixedShape::micro(), QuantSpec::int8());
        assert!(manifest.validate().is_ok());
        manifest.name = String::new();
        assert!(manifest.validate().is_err());
    }
}

View File

@@ -0,0 +1,242 @@
//! Model artifact format and handling
//!
//! Signed bundles with metadata, weights, and test vectors.
pub mod manifest;
pub mod pack;
pub mod verify;
pub use manifest::{BackendSpec, IoSpec, Manifest, TestSpec};
pub use pack::{pack_artifact, unpack_artifact};
pub use verify::{verify_artifact, verify_signature};
use crate::error::{Error, Result};
use crate::types::{FixedShape, ModelId, QuantSpec};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Complete model artifact
///
/// The signed, distributable bundle: manifest metadata plus all binary
/// payloads and golden test vectors. Serialized to the on-disk wire format
/// by `pack_artifact` / `unpack_artifact`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArtifact {
    /// Manifest with metadata
    pub manifest: Manifest,
    /// Quantized weights (binary blob)
    #[serde(with = "serde_bytes")]
    pub weights: Vec<u8>,
    /// Optional FPGA bitstream
    #[serde(with = "serde_bytes_option")]
    pub bitstream: Option<Vec<u8>>,
    /// Optional calibration data
    #[serde(with = "serde_bytes_option")]
    pub calibration: Option<Vec<u8>>,
    /// Test vectors for validation
    pub test_vectors: Vec<TestVector>,
    /// Ed25519 signature over manifest + file hashes
    /// (all-zero until the artifact is signed)
    #[serde(with = "serde_bytes")]
    pub signature: [u8; 64],
    /// Ed25519 public key
    /// (all-zero means "unsigned"; see `verify_artifact`)
    #[serde(with = "serde_bytes")]
    pub pubkey: [u8; 32],
}
/// Serde helper for Option<Vec<u8>>
///
/// `serde_bytes` does not apply through `Option` directly, so this module
/// adapts it: `Some(bytes)` is serialized as an efficient byte string,
/// `None` as a plain null/absent value.
mod serde_bytes_option {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    // Serialize the inner buffer via `serde_bytes::Bytes` to get the
    // compact byte-string representation instead of a u8 sequence.
    pub fn serialize<S: Serializer>(data: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
        match data {
            Some(bytes) => s.serialize_some(&serde_bytes::Bytes::new(bytes)),
            None => s.serialize_none(),
        }
    }
    // Deserialize through `ByteBuf`, then unwrap back to a plain Vec<u8>.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Vec<u8>>, D::Error> {
        let opt: Option<serde_bytes::ByteBuf> = Option::deserialize(d)?;
        Ok(opt.map(|b| b.into_vec()))
    }
}
/// Test vector for model validation
///
/// One golden input/output pair used to check a loaded model produces
/// the expected (quantized) logits within a tolerance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestVector {
    /// Input tokens
    pub tokens: Vec<u16>,
    /// Expected output logits (top-K or full)
    pub expected: Vec<i16>,
    /// Maximum absolute error allowed
    pub max_abs_err: i32,
}
impl ModelArtifact {
/// Create a new artifact (for building/packing)
pub fn new(
manifest: Manifest,
weights: Vec<u8>,
bitstream: Option<Vec<u8>>,
calibration: Option<Vec<u8>>,
test_vectors: Vec<TestVector>,
) -> Self {
Self {
manifest,
weights,
bitstream,
calibration,
test_vectors,
signature: [0u8; 64],
pubkey: [0u8; 32],
}
}
/// Compute model ID (SHA-256 of manifest + weights hash)
pub fn model_id(&self) -> ModelId {
let mut hasher = Sha256::new();
hasher.update(self.manifest.name.as_bytes());
hasher.update(&self.model_hash());
hasher.update(&self.quant_hash());
ModelId::new(hasher.finalize().into())
}
/// Compute hash of model weights
pub fn model_hash(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(&self.weights);
if let Some(ref bitstream) = self.bitstream {
hasher.update(bitstream);
}
hasher.finalize().into()
}
/// Compute hash of quantization parameters
pub fn quant_hash(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
let quant_json = serde_json::to_string(&self.manifest.quant).unwrap_or_default();
hasher.update(quant_json.as_bytes());
if let Some(ref calib) = self.calibration {
hasher.update(calib);
}
hasher.finalize().into()
}
/// Validate artifact integrity
pub fn validate(&self) -> Result<()> {
// Validate manifest
self.manifest.validate()?;
// Validate shape
self.manifest
.shape
.validate()
.map_err(|e| Error::InvalidArtifact(e))?;
// Check weights size is reasonable
let min_weight_size =
self.manifest.shape.embedding_params() / self.manifest.quant.weights_per_byte();
if self.weights.len() < min_weight_size {
return Err(Error::InvalidArtifact(format!(
"Weights too small: {} bytes, expected at least {} for embeddings",
self.weights.len(),
min_weight_size
)));
}
// Validate test vectors if strict mode
#[cfg(feature = "strict_verify")]
self.run_test_vectors()?;
Ok(())
}
/// Run test vectors for validation
#[cfg(feature = "strict_verify")]
pub fn run_test_vectors(&self) -> Result<()> {
// This would require running inference, which creates a circular dependency
// In practice, this is done by the backend after loading
Ok(())
}
/// Get the fixed shape
pub fn shape(&self) -> &FixedShape {
&self.manifest.shape
}
/// Get quantization spec
pub fn quant(&self) -> &QuantSpec {
&self.manifest.quant
}
/// Check if artifact has FPGA bitstream
pub fn has_bitstream(&self) -> bool {
self.bitstream.is_some()
}
/// Estimated memory footprint in bytes
pub fn memory_footprint(&self) -> usize {
self.weights.len()
+ self.bitstream.as_ref().map(|b| b.len()).unwrap_or(0)
+ self.calibration.as_ref().map(|c| c.len()).unwrap_or(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Minimal manifest matching the "micro" shape used by all tests below.
    fn create_test_manifest() -> Manifest {
        Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: IoSpec::default(),
            backend: BackendSpec::default(),
            tests: TestSpec::default(),
        }
    }
    // model_id() must be a pure function of the artifact contents.
    #[test]
    fn test_model_id_computation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![]);
        let id1 = artifact.model_id();
        let id2 = artifact.model_id();
        assert_eq!(id1, id2); // Deterministic
    }
    #[test]
    fn test_model_hash() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(manifest, vec![42u8; 4096 * 64], None, None, vec![]);
        let hash = artifact.model_hash();
        assert_ne!(hash, [0u8; 32]); // Non-zero hash
    }
    // Weights sized for the micro embedding table should validate cleanly.
    #[test]
    fn test_artifact_validation() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 4096 * 64], // Enough for micro embeddings
            None,
            None,
            vec![],
        );
        assert!(artifact.validate().is_ok());
    }
    // Undersized weights must fail the minimum-size check.
    #[test]
    fn test_artifact_too_small_weights() {
        let manifest = create_test_manifest();
        let artifact = ModelArtifact::new(
            manifest,
            vec![0u8; 100], // Too small
            None,
            None,
            vec![],
        );
        assert!(artifact.validate().is_err());
    }
}

View File

@@ -0,0 +1,304 @@
//! Artifact packing and unpacking
use std::io::{Read, Write};
use std::path::Path;
use crate::artifact::{ModelArtifact, TestVector};
use crate::error::{Error, Result};
/// Magic bytes for artifact file format
const ARTIFACT_MAGIC: &[u8; 4] = b"RVAT"; // RuVector ArTifact
/// Wire-format version; bump on any layout change.
const ARTIFACT_VERSION: u16 = 1;
// Security: Maximum size limits to prevent DoS via unbounded allocations
/// Maximum manifest size (1 MB)
const MAX_MANIFEST_SIZE: usize = 1024 * 1024;
/// Maximum weights size (1 GB)
const MAX_WEIGHTS_SIZE: usize = 1024 * 1024 * 1024;
/// Maximum bitstream/calibration size (100 MB)
const MAX_BLOB_SIZE: usize = 100 * 1024 * 1024;
/// Maximum number of test vectors
const MAX_TEST_VECTORS: usize = 10_000;
/// Maximum tokens per test vector
/// NOTE(review): the wire format stores the token count as a u16
/// (max 65_535), so this 65_536 bound can never be exceeded by a decoded
/// value — confirm the intended limit.
const MAX_TOKENS_PER_VECTOR: usize = 65_536;
/// Maximum expected values per test vector
const MAX_EXPECTED_PER_VECTOR: usize = 1_000_000;
/// Pack an artifact to bytes
pub fn pack_artifact(artifact: &ModelArtifact) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
// Write magic and version
buffer.extend_from_slice(ARTIFACT_MAGIC);
buffer.extend_from_slice(&ARTIFACT_VERSION.to_le_bytes());
// Write manifest as JSON with length prefix
let manifest_json = serde_json::to_string(&artifact.manifest)?;
let manifest_bytes = manifest_json.as_bytes();
buffer.extend_from_slice(&(manifest_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(manifest_bytes);
// Write weights with length prefix
buffer.extend_from_slice(&(artifact.weights.len() as u64).to_le_bytes());
buffer.extend_from_slice(&artifact.weights);
// Write optional bitstream
if let Some(ref bitstream) = artifact.bitstream {
buffer.push(1); // Present flag
buffer.extend_from_slice(&(bitstream.len() as u64).to_le_bytes());
buffer.extend_from_slice(bitstream);
} else {
buffer.push(0); // Not present
}
// Write optional calibration
if let Some(ref calibration) = artifact.calibration {
buffer.push(1);
buffer.extend_from_slice(&(calibration.len() as u64).to_le_bytes());
buffer.extend_from_slice(calibration);
} else {
buffer.push(0);
}
// Write test vectors
buffer.extend_from_slice(&(artifact.test_vectors.len() as u32).to_le_bytes());
for vector in &artifact.test_vectors {
// Write tokens
buffer.extend_from_slice(&(vector.tokens.len() as u16).to_le_bytes());
for &token in &vector.tokens {
buffer.extend_from_slice(&token.to_le_bytes());
}
// Write expected
buffer.extend_from_slice(&(vector.expected.len() as u32).to_le_bytes());
for &exp in &vector.expected {
buffer.extend_from_slice(&exp.to_le_bytes());
}
// Write max_abs_err
buffer.extend_from_slice(&vector.max_abs_err.to_le_bytes());
}
// Write signature and pubkey
buffer.extend_from_slice(&artifact.signature);
buffer.extend_from_slice(&artifact.pubkey);
Ok(buffer)
}
/// Unpack an artifact from bytes
///
/// Inverse of `pack_artifact`. Every variable-length field is guarded by a
/// `MAX_*` limit before allocation so a hostile input cannot force an
/// unbounded `Vec` allocation; truncated input surfaces as an I/O error
/// from `read_exact`.
///
/// # Errors
/// Returns `Error::InvalidArtifact` for bad magic, unsupported version, or
/// any limit violation; propagates JSON and I/O errors otherwise.
pub fn unpack_artifact(data: &[u8]) -> Result<ModelArtifact> {
    let mut cursor = std::io::Cursor::new(data);
    // Scratch buffer, sized for the largest fixed field (u64 lengths).
    let mut read_buf = [0u8; 8];
    // Read and verify magic
    cursor.read_exact(&mut read_buf[..4])?;
    if &read_buf[..4] != ARTIFACT_MAGIC {
        return Err(Error::InvalidArtifact("Invalid magic bytes".into()));
    }
    // Read version
    cursor.read_exact(&mut read_buf[..2])?;
    let version = u16::from_le_bytes([read_buf[0], read_buf[1]]);
    if version != ARTIFACT_VERSION {
        return Err(Error::InvalidArtifact(format!(
            "Unsupported version: {}",
            version
        )));
    }
    // Read manifest (u32 length prefix, then JSON payload)
    cursor.read_exact(&mut read_buf[..4])?;
    let manifest_len = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if manifest_len > MAX_MANIFEST_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Manifest size {} exceeds maximum {}",
            manifest_len, MAX_MANIFEST_SIZE
        )));
    }
    let mut manifest_bytes = vec![0u8; manifest_len];
    cursor.read_exact(&mut manifest_bytes)?;
    let manifest = serde_json::from_slice(&manifest_bytes)?;
    // Read weights (u64 length prefix)
    cursor.read_exact(&mut read_buf)?;
    let weights_len = u64::from_le_bytes(read_buf) as usize;
    if weights_len > MAX_WEIGHTS_SIZE {
        return Err(Error::InvalidArtifact(format!(
            "Weights size {} exceeds maximum {}",
            weights_len, MAX_WEIGHTS_SIZE
        )));
    }
    let mut weights = vec![0u8; weights_len];
    cursor.read_exact(&mut weights)?;
    // Read optional bitstream (1-byte presence flag, then length + data)
    cursor.read_exact(&mut read_buf[..1])?;
    let bitstream = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Bitstream size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };
    // Read optional calibration (same encoding as bitstream)
    cursor.read_exact(&mut read_buf[..1])?;
    let calibration = if read_buf[0] == 1 {
        cursor.read_exact(&mut read_buf)?;
        let len = u64::from_le_bytes(read_buf) as usize;
        if len > MAX_BLOB_SIZE {
            return Err(Error::InvalidArtifact(format!(
                "Calibration size {} exceeds maximum {}",
                len, MAX_BLOB_SIZE
            )));
        }
        let mut data = vec![0u8; len];
        cursor.read_exact(&mut data)?;
        Some(data)
    } else {
        None
    };
    // Read test vectors (u32 count, then each vector in sequence)
    cursor.read_exact(&mut read_buf[..4])?;
    let num_vectors = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
    if num_vectors > MAX_TEST_VECTORS {
        return Err(Error::InvalidArtifact(format!(
            "Test vector count {} exceeds maximum {}",
            num_vectors, MAX_TEST_VECTORS
        )));
    }
    let mut test_vectors = Vec::with_capacity(num_vectors);
    for _ in 0..num_vectors {
        // Read tokens (u16 count, then u16 LE values)
        cursor.read_exact(&mut read_buf[..2])?;
        let num_tokens = u16::from_le_bytes([read_buf[0], read_buf[1]]) as usize;
        if num_tokens > MAX_TOKENS_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Token count {} exceeds maximum {}",
                num_tokens, MAX_TOKENS_PER_VECTOR
            )));
        }
        let mut tokens = Vec::with_capacity(num_tokens);
        for _ in 0..num_tokens {
            cursor.read_exact(&mut read_buf[..2])?;
            tokens.push(u16::from_le_bytes([read_buf[0], read_buf[1]]));
        }
        // Read expected logits (u32 count, then i16 LE values)
        cursor.read_exact(&mut read_buf[..4])?;
        let num_expected = u32::from_le_bytes(read_buf[..4].try_into().unwrap()) as usize;
        if num_expected > MAX_EXPECTED_PER_VECTOR {
            return Err(Error::InvalidArtifact(format!(
                "Expected values count {} exceeds maximum {}",
                num_expected, MAX_EXPECTED_PER_VECTOR
            )));
        }
        let mut expected = Vec::with_capacity(num_expected);
        for _ in 0..num_expected {
            cursor.read_exact(&mut read_buf[..2])?;
            expected.push(i16::from_le_bytes([read_buf[0], read_buf[1]]));
        }
        // Read max_abs_err
        cursor.read_exact(&mut read_buf[..4])?;
        let max_abs_err = i32::from_le_bytes(read_buf[..4].try_into().unwrap());
        test_vectors.push(TestVector {
            tokens,
            expected,
            max_abs_err,
        });
    }
    // Read fixed-size signature and pubkey trailers
    let mut signature = [0u8; 64];
    cursor.read_exact(&mut signature)?;
    let mut pubkey = [0u8; 32];
    cursor.read_exact(&mut pubkey)?;
    Ok(ModelArtifact {
        manifest,
        weights,
        bitstream,
        calibration,
        test_vectors,
        signature,
        pubkey,
    })
}
/// Save artifact to file
///
/// Serializes the artifact with [`pack_artifact`] and writes the bytes to
/// `path` in a single call, creating or truncating the file.
pub fn save_artifact(artifact: &ModelArtifact, path: impl AsRef<Path>) -> Result<()> {
    std::fs::write(path, pack_artifact(artifact)?)?;
    Ok(())
}
/// Load artifact from file
///
/// Reads the whole file into memory, then decodes it with
/// [`unpack_artifact`].
pub fn load_artifact(path: impl AsRef<Path>) -> Result<ModelArtifact> {
    let bytes = std::fs::read(path.as_ref())?;
    unpack_artifact(&bytes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};
    // Artifact exercising every optional/variable field of the format:
    // a bitstream, no calibration, and one test vector.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_pack".into(),
            model_hash: "abc123".into(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: (0..5000).map(|i| (i % 256) as u8).collect(),
            bitstream: Some(vec![0xFF; 100]),
            calibration: None,
            test_vectors: vec![TestVector {
                tokens: vec![1, 2, 3],
                expected: vec![100, 200, 300],
                max_abs_err: 5,
            }],
            signature: [0x42u8; 64],
            pubkey: [0x24u8; 32],
        }
    }
    // pack -> unpack must preserve every field of the artifact.
    #[test]
    fn test_pack_unpack_roundtrip() {
        let original = create_test_artifact();
        let packed = pack_artifact(&original).unwrap();
        let unpacked = unpack_artifact(&packed).unwrap();
        assert_eq!(original.manifest.name, unpacked.manifest.name);
        assert_eq!(original.weights, unpacked.weights);
        assert_eq!(original.bitstream, unpacked.bitstream);
        assert_eq!(original.calibration, unpacked.calibration);
        assert_eq!(original.test_vectors.len(), unpacked.test_vectors.len());
        assert_eq!(original.signature, unpacked.signature);
        assert_eq!(original.pubkey, unpacked.pubkey);
    }
    // Wrong magic bytes must be rejected before any allocation happens.
    #[test]
    fn test_invalid_magic() {
        let data = b"XXXX0000";
        assert!(unpack_artifact(data).is_err());
    }
}

View File

@@ -0,0 +1,203 @@
//! Artifact verification and signature validation
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use sha2::{Digest, Sha256};
use crate::artifact::ModelArtifact;
use crate::error::{Error, Result};
/// Verify artifact signature
///
/// Recomputes the canonical signing message (see `compute_signing_message`)
/// and checks the artifact's embedded Ed25519 signature against its
/// embedded public key.
///
/// # Errors
/// Returns `Error::SignatureError` for a malformed public key or a failed
/// verification; on success the result is always `Ok(true)` (this function
/// never returns `Ok(false)`).
pub fn verify_signature(artifact: &ModelArtifact) -> Result<bool> {
    // Compute the message to verify (manifest hash + file hashes)
    let message = compute_signing_message(artifact);
    // Load public key
    let pubkey = VerifyingKey::from_bytes(&artifact.pubkey)
        .map_err(|e| Error::SignatureError(format!("Invalid public key: {}", e)))?;
    // Load signature
    let signature = Signature::from_bytes(&artifact.signature);
    // Verify
    pubkey
        .verify(&message, &signature)
        .map(|_| true)
        .map_err(|e| Error::SignatureError(format!("Verification failed: {}", e)))
}
/// Verify complete artifact integrity
///
/// Performs, in order: manifest validation, weights-hash consistency with
/// the manifest (skipped when `manifest.model_hash` is empty), signature
/// verification (skipped when the public key is all-zero, i.e. unsigned),
/// and a minimum-size check on the weights blob.
///
/// # Errors
/// Returns the first failing check's error.
pub fn verify_artifact(artifact: &ModelArtifact) -> Result<()> {
    // 1. Validate manifest
    artifact.manifest.validate()?;
    // 2. Verify model hash matches manifest (hex-encoded SHA-256)
    let computed_hash = hex::encode(artifact.model_hash());
    if !artifact.manifest.model_hash.is_empty() && computed_hash != artifact.manifest.model_hash {
        return Err(Error::InvalidArtifact(format!(
            "Model hash mismatch: expected {}, got {}",
            artifact.manifest.model_hash, computed_hash
        )));
    }
    // 3. Verify signature (if present; an all-zero pubkey means unsigned)
    if artifact.pubkey != [0u8; 32] {
        verify_signature(artifact)?;
    }
    // 4. Verify weights size against the embedding-table lower bound
    let expected_min =
        artifact.manifest.shape.embedding_params() / artifact.manifest.quant.weights_per_byte();
    if artifact.weights.len() < expected_min {
        return Err(Error::InvalidArtifact(format!(
            "Weights too small: {} < {}",
            artifact.weights.len(),
            expected_min
        )));
    }
    Ok(())
}
/// Compute the message that was signed
fn compute_signing_message(artifact: &ModelArtifact) -> Vec<u8> {
let mut hasher = Sha256::new();
// Hash manifest
let manifest_json = serde_json::to_string(&artifact.manifest).unwrap_or_default();
hasher.update(manifest_json.as_bytes());
// Hash weights
let weights_hash = artifact.model_hash();
hasher.update(&weights_hash);
// Hash quant params
let quant_hash = artifact.quant_hash();
hasher.update(&quant_hash);
// Hash bitstream if present
if let Some(ref bitstream) = artifact.bitstream {
let mut h = Sha256::new();
h.update(bitstream);
hasher.update(&h.finalize());
}
// Hash calibration if present
if let Some(ref calib) = artifact.calibration {
let mut h = Sha256::new();
h.update(calib);
hasher.update(&h.finalize());
}
hasher.finalize().to_vec()
}
/// Sign an artifact with Ed25519 private key
///
/// Computes the canonical signing message and stores both the detached
/// signature and the derived public key on the artifact, so
/// `verify_signature` can later validate it.
///
/// NOTE(review): gated on `feature = "sign"`, but no `sign` feature is
/// declared in this crate's Cargo.toml — as written this function is
/// never compiled; confirm the intended feature name.
#[cfg(feature = "sign")]
pub fn sign_artifact(artifact: &mut ModelArtifact, secret_key: &[u8; 32]) -> Result<()> {
    use ed25519_dalek::{Signer, SigningKey};
    let signing_key = SigningKey::from_bytes(secret_key);
    let message = compute_signing_message(artifact);
    let signature = signing_key.sign(&message);
    artifact.signature = signature.to_bytes();
    artifact.pubkey = signing_key.verifying_key().to_bytes();
    Ok(())
}
/// Verify test vectors against model output
///
/// Runs `infer_fn` on every stored test vector and checks the maximum
/// absolute deviation against the manifest-level `tests.max_abs_err`
/// tolerance.
///
/// NOTE(review): each `TestVector` also carries its own `max_abs_err`
/// field, which is ignored here in favor of the manifest-level threshold —
/// confirm which is authoritative.
/// NOTE(review): `zip` compares only up to the shorter of output/expected,
/// so a length mismatch is silently not detected.
///
/// # Errors
/// Propagates inference errors; returns `Error::TestVectorError` when the
/// observed error exceeds the tolerance.
pub fn verify_test_vectors(
    artifact: &ModelArtifact,
    infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
) -> Result<()> {
    let max_err = artifact.manifest.tests.max_abs_err;
    for vector in &artifact.test_vectors {
        let output = infer_fn(&vector.tokens)?;
        // Worst-case absolute deviation over the compared elements; widen
        // to i32 so the i16 subtraction cannot overflow.
        let actual_max_err = output
            .iter()
            .zip(&vector.expected)
            .map(|(&a, &b)| (a as i32 - b as i32).abs())
            .max()
            .unwrap_or(0);
        if actual_max_err > max_err {
            return Err(Error::TestVectorError {
                expected: max_err,
                actual: actual_max_err,
            });
        }
    }
    Ok(())
}
/// Generate test vectors for an artifact
///
/// Replaces any existing vectors with `count` freshly generated ones:
/// random token sequences (uniform over the manifest vocabulary, full
/// sequence length) paired with the output of `infer_fn`. Also updates
/// `manifest.tests.vectors` to match.
///
/// Note: uses a non-seeded `thread_rng`, so the generated vectors are not
/// reproducible across runs.
///
/// # Errors
/// Propagates any error returned by `infer_fn`.
pub fn generate_test_vectors(
    artifact: &mut ModelArtifact,
    infer_fn: impl Fn(&[u16]) -> Result<Vec<i16>>,
    count: usize,
) -> Result<()> {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let seq_len = artifact.manifest.shape.seq_len as usize;
    let vocab = artifact.manifest.shape.vocab as u16;
    artifact.test_vectors.clear();
    for _ in 0..count {
        // Generate random input
        let tokens: Vec<u16> = (0..seq_len).map(|_| rng.gen_range(0..vocab)).collect();
        // Run inference
        let expected = infer_fn(&tokens)?;
        artifact.test_vectors.push(crate::artifact::TestVector {
            tokens,
            expected,
            max_abs_err: artifact.manifest.tests.max_abs_err,
        });
    }
    artifact.manifest.tests.vectors = count as u32;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::types::{FixedShape, QuantSpec};
    // Unsigned artifact (zero pubkey) with an empty model_hash, so both
    // the hash and signature checks are skipped during verification.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test".into(),
            model_hash: String::new(),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact::new(manifest, vec![0u8; 4096 * 64], None, None, vec![])
    }
    #[test]
    fn test_verify_artifact() {
        let artifact = create_test_artifact();
        assert!(verify_artifact(&artifact).is_ok());
    }
    // The signing message is a single SHA-256 digest, hence 32 bytes.
    #[test]
    fn test_compute_signing_message() {
        let artifact = create_test_artifact();
        let msg = compute_signing_message(&artifact);
        assert_eq!(msg.len(), 32); // SHA-256 output
    }
}

View File

@@ -0,0 +1,566 @@
//! FPGA Daemon backend
//!
//! Communicates with a local daemon over Unix socket or TCP
//! to send inference requests to an FPGA accelerator.
use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::Path;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::artifact::ModelArtifact;
use crate::backend::{
commands, compute_topk, crc32, protocol, read_lock, validate_tokens, write_lock, BackendStats,
RequestFrame, ResponseFrame, TransformerBackend,
};
use crate::error::{Error, Result};
use crate::types::{
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
};
/// Connection type for daemon communication
///
/// Selects the transport used by `FpgaDaemonBackend::connect`.
#[derive(Debug, Clone)]
pub enum DaemonConnection {
    /// Unix domain socket path
    Unix(String),
    /// TCP address (host:port)
    Tcp(String),
}
impl DaemonConnection {
    /// Create a Unix socket connection
    pub fn unix(path: impl Into<String>) -> Self {
        Self::Unix(path.into())
    }
    /// Create a TCP connection (address string in `host:port` form)
    pub fn tcp(addr: impl Into<String>) -> Self {
        Self::Tcp(addr.into())
    }
    /// Default socket path used by the system daemon
    pub fn default_socket() -> Self {
        Self::Unix("/var/run/ruvector_fpga.sock".into())
    }
}
/// FPGA Daemon backend
///
/// Client-side backend that proxies inference to a local daemon process.
/// Keeps a local cache of loaded artifacts (for token/shape validation
/// without a round-trip) and aggregate statistics; both behind `RwLock`
/// so the backend can be shared across threads.
pub struct FpgaDaemonBackend {
    /// Connection configuration
    connection: DaemonConnection,
    /// Loaded models (cached metadata)
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: DaemonConfig,
}
/// Cached model metadata
struct ModelMetadata {
    // Full artifact retained for local validation (vocab/shape checks in infer).
    artifact: ModelArtifact,
    // Timestamp of local caching.
    // NOTE(review): written but never read in this file — possibly reserved
    // for future eviction logic; confirm before removing.
    loaded_at: Instant,
}
/// Configuration for daemon backend
///
/// Timeouts, retry policy, and response-format options applied to every
/// connection/request made by `FpgaDaemonBackend`.
#[derive(Debug, Clone)]
pub struct DaemonConfig {
    /// Connection timeout in milliseconds
    pub connect_timeout_ms: u64,
    /// Read timeout in milliseconds
    pub read_timeout_ms: u64,
    /// Write timeout in milliseconds
    pub write_timeout_ms: u64,
    /// Number of retry attempts
    pub retries: usize,
    /// Retry backoff multiplier
    /// (each retry delay is the previous delay times this factor)
    pub backoff_multiplier: f64,
    /// Return only top-K results
    pub topk_only: bool,
    /// Top-K count
    pub topk: u16,
}
impl Default for DaemonConfig {
    /// Defaults: 5s connect/write, 10s read, 3 retries with exponential
    /// (×2) backoff, top-16-only responses.
    fn default() -> Self {
        Self {
            connect_timeout_ms: 5000,
            read_timeout_ms: 10000,
            write_timeout_ms: 5000,
            retries: 3,
            backoff_multiplier: 2.0,
            topk_only: true,
            topk: 16,
        }
    }
}
impl FpgaDaemonBackend {
    /// Create a new daemon backend with Unix socket
    pub fn new(socket_path: impl AsRef<Path>) -> Self {
        Self::with_connection(
            DaemonConnection::unix(socket_path.as_ref().to_string_lossy()),
            DaemonConfig::default(),
        )
    }
    /// Create with custom connection and config
    pub fn with_connection(connection: DaemonConnection, config: DaemonConfig) -> Self {
        Self {
            connection,
            models: RwLock::new(HashMap::new()),
            stats: RwLock::new(BackendStats::default()),
            config,
        }
    }
    /// Connect to the daemon
    ///
    /// Opens a fresh connection per call (no pooling). Read/write timeouts
    /// are applied best-effort (`.ok()` — a failure to set them is ignored).
    /// On non-Unix platforms a Unix-socket connection returns
    /// `FeatureNotAvailable`.
    fn connect(&self) -> Result<Box<dyn ReadWrite>> {
        let timeout = Duration::from_millis(self.config.connect_timeout_ms);
        match &self.connection {
            DaemonConnection::Unix(path) => {
                #[cfg(unix)]
                {
                    use std::os::unix::net::UnixStream;
                    // Note: UnixStream::connect has no timeout parameter;
                    // `timeout` only applies to the TCP branch.
                    let stream = UnixStream::connect(path)
                        .map_err(|e| Error::daemon_connection(format!("Unix socket: {}", e)))?;
                    stream
                        .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                        .ok();
                    stream
                        .set_write_timeout(Some(Duration::from_millis(
                            self.config.write_timeout_ms,
                        )))
                        .ok();
                    Ok(Box::new(stream))
                }
                #[cfg(not(unix))]
                {
                    let _ = (path, timeout);
                    Err(Error::FeatureNotAvailable(
                        "Unix sockets not available on this platform".into(),
                    ))
                }
            }
            DaemonConnection::Tcp(addr) => {
                use std::net::TcpStream;
                let stream = TcpStream::connect_timeout(
                    &addr
                        .parse()
                        .map_err(|e| Error::daemon_connection(format!("Invalid address: {}", e)))?,
                    timeout,
                )
                .map_err(|e| Error::daemon_connection(format!("TCP: {}", e)))?;
                stream
                    .set_read_timeout(Some(Duration::from_millis(self.config.read_timeout_ms)))
                    .ok();
                stream
                    .set_write_timeout(Some(Duration::from_millis(self.config.write_timeout_ms)))
                    .ok();
                Ok(Box::new(stream))
            }
        }
    }
    /// Send inference request to daemon
    ///
    /// Request layout: header frame, tokens (u16 LE), attention-mask bytes,
    /// packed gate hint (coherence score + 2 flag bytes), CRC32 trailer.
    /// Response: 14-byte header, logits (i16 LE — (token_id, logit) pairs
    /// when `topk_only`, else the full vocab vector), optional CRC trailer.
    fn send_request(
        &self,
        stream: &mut dyn ReadWrite,
        req: &InferenceRequest,
    ) -> Result<(Vec<i16>, ResponseFrame)> {
        let shape = &req.shape;
        // Build request flags
        let mut flags = 0u16;
        if self.config.topk_only {
            flags |= protocol::flags::TOPK_ONLY;
        }
        // Create request frame
        let frame = RequestFrame::new(
            shape.seq_len,
            shape.d_model,
            shape.vocab,
            &req.model,
            flags,
            self.config.topk,
        );
        // Build payload (capacity is an estimate; Vec grows as needed)
        let mut payload = Vec::with_capacity(
            protocol::HEADER_SIZE + req.tokens.len() * 2 + req.attn_mask.len() + 8,
        );
        // Write header
        payload.extend_from_slice(&frame.to_bytes());
        // Write tokens (u16 little-endian)
        for &token in req.tokens {
            payload.extend_from_slice(&token.to_le_bytes());
        }
        // Write mask
        payload.extend_from_slice(req.attn_mask);
        // Write gate hint (packed)
        payload.extend_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
        payload.push(req.gate_hint.boundary_crossed as u8);
        payload.push(req.gate_hint.max_compute_class as u8);
        // Calculate and append checksum
        let checksum = crc32(&payload);
        payload.extend_from_slice(&checksum.to_le_bytes());
        // Send payload
        stream
            .write_all(&payload)
            .map_err(|e| Error::backend(format!("Write failed: {}", e)))?;
        stream
            .flush()
            .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
        // Read response header
        let mut response_header = [0u8; 14];
        stream
            .read_exact(&mut response_header)
            .map_err(|e| Error::backend(format!("Read header failed: {}", e)))?;
        let response = ResponseFrame::from_bytes(&response_header);
        // Copy packed fields to avoid alignment issues
        let status = { response.status };
        // Check status; map protocol error codes to typed errors
        match status {
            protocol::status::OK => {}
            protocol::status::MODEL_NOT_FOUND => {
                return Err(Error::ModelNotFound(req.model));
            }
            protocol::status::SHAPE_MISMATCH => {
                return Err(Error::ShapeMismatch {
                    expected: req.shape,
                    actual: req.shape, // Daemon should provide actual shape
                });
            }
            protocol::status::GATE_BLOCKED => {
                return Err(Error::GateBlocked {
                    reason: crate::types::SkipReason::PolicyDenied,
                });
            }
            _ => {
                return Err(Error::backend(format!("Daemon error: status {}", status)));
            }
        }
        // Read logits: topk_only responses interleave (token_id, logit),
        // hence 2 values per entry; otherwise one value per vocab slot.
        let logits_count = if self.config.topk_only {
            self.config.topk as usize * 2 // (token_id, logit) pairs
        } else {
            shape.vocab as usize
        };
        let mut logits_bytes = vec![0u8; logits_count * 2];
        stream
            .read_exact(&mut logits_bytes)
            .map_err(|e| Error::backend(format!("Read logits failed: {}", e)))?;
        // Parse logits
        let logits: Vec<i16> = logits_bytes
            .chunks(2)
            .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
            .collect();
        // Read and verify checksum
        // NOTE(review): the trailer is read but its value is never compared
        // against a recomputed CRC — confirm whether verification is intended.
        let mut checksum_bytes = [0u8; 4];
        stream.read_exact(&mut checksum_bytes).ok(); // Checksum is optional
        Ok((logits, response))
    }
    /// Send load model command to daemon
    ///
    /// Uploads the fully packed artifact; the daemon replies with a status
    /// byte followed by an optional length-prefixed error message.
    fn send_load_command(
        &self,
        stream: &mut dyn ReadWrite,
        artifact: &ModelArtifact,
    ) -> Result<()> {
        // Pack artifact
        let artifact_bytes = crate::artifact::pack::pack_artifact(artifact)?;
        // Build command packet:
        // [command: 1] [model_id: 32] [artifact_len: 4] [artifact_data: N] [checksum: 4]
        let mut payload = Vec::with_capacity(1 + 32 + 4 + artifact_bytes.len() + 4);
        // Command byte
        payload.push(commands::LOAD_MODEL);
        // Model ID (32 bytes)
        payload.extend_from_slice(artifact.model_id().as_bytes());
        // Artifact length (u32 LE)
        payload.extend_from_slice(&(artifact_bytes.len() as u32).to_le_bytes());
        // Artifact data
        payload.extend_from_slice(&artifact_bytes);
        // Checksum
        let checksum = crc32(&payload);
        payload.extend_from_slice(&checksum.to_le_bytes());
        // Send
        stream
            .write_all(&payload)
            .map_err(|e| Error::backend(format!("Write load command failed: {}", e)))?;
        stream
            .flush()
            .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
        // Read response: [status: 1] [message_len: 2] [message: N]
        let mut status = [0u8; 1];
        stream
            .read_exact(&mut status)
            .map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
        if status[0] != 0 {
            // Read error message (best-effort; message capped at 256 bytes)
            let mut msg_len = [0u8; 2];
            stream.read_exact(&mut msg_len).ok();
            let len = u16::from_le_bytes(msg_len) as usize;
            let mut msg = vec![0u8; len.min(256)];
            stream.read_exact(&mut msg).ok();
            let error_msg = String::from_utf8_lossy(&msg);
            return Err(Error::backend(format!(
                "Daemon rejected load: {}",
                error_msg
            )));
        }
        Ok(())
    }
    /// Send unload model command to daemon
    ///
    /// Small fixed-size command; the daemon replies with a single status
    /// byte (0 = success).
    fn send_unload_command(&self, stream: &mut dyn ReadWrite, model_id: ModelId) -> Result<()> {
        // Build command packet: [command: 1] [model_id: 32] [checksum: 4]
        let mut payload = Vec::with_capacity(1 + 32 + 4);
        payload.push(commands::UNLOAD_MODEL);
        payload.extend_from_slice(model_id.as_bytes());
        let checksum = crc32(&payload);
        payload.extend_from_slice(&checksum.to_le_bytes());
        // Send
        stream
            .write_all(&payload)
            .map_err(|e| Error::backend(format!("Write unload command failed: {}", e)))?;
        stream
            .flush()
            .map_err(|e| Error::backend(format!("Flush failed: {}", e)))?;
        // Read response status
        let mut status = [0u8; 1];
        stream
            .read_exact(&mut status)
            .map_err(|e| Error::backend(format!("Read status failed: {}", e)))?;
        if status[0] != 0 {
            return Err(Error::backend("Daemon rejected unload"));
        }
        Ok(())
    }
    /// Execute with retries
    ///
    /// Retries only errors reported as recoverable by `Error::is_recoverable`,
    /// sleeping with exponential backoff (initial 100ms, scaled by
    /// `config.backoff_multiplier`) between attempts. Non-recoverable errors
    /// are returned immediately.
    fn with_retries<T, F>(&self, mut f: F) -> Result<T>
    where
        F: FnMut() -> Result<T>,
    {
        let mut last_error = None;
        let mut delay = Duration::from_millis(100);
        for attempt in 0..=self.config.retries {
            match f() {
                Ok(result) => return Ok(result),
                Err(e) if e.is_recoverable() => {
                    last_error = Some(e);
                    if attempt < self.config.retries {
                        std::thread::sleep(delay);
                        delay = Duration::from_secs_f64(
                            delay.as_secs_f64() * self.config.backoff_multiplier,
                        );
                    }
                }
                Err(e) => return Err(e),
            }
        }
        Err(last_error.unwrap_or_else(|| Error::backend("Unknown error")))
    }
}
impl TransformerBackend for FpgaDaemonBackend {
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
// Validate artifact
artifact.validate()?;
let model_id = artifact.model_id();
// Send load command to daemon to preload the model
self.with_retries(|| {
let mut stream = self.connect()?;
self.send_load_command(stream.as_mut(), artifact)
})?;
// Cache metadata locally
{
let mut models = write_lock(&self.models, |m| {
m.insert(
model_id,
ModelMetadata {
artifact: artifact.clone(),
loaded_at: Instant::now(),
},
);
})?;
}
write_lock(&self.stats, |s| {
s.models_loaded += 1;
})?;
Ok(model_id)
}
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
let start = Instant::now();
// Validate request
req.validate()?;
// Check model is loaded locally and validate tokens
let model_metadata = read_lock(&self.models, |models| {
models.get(&req.model).map(|m| m.artifact.clone())
})?
.ok_or_else(|| Error::ModelNotFound(req.model))?;
// Validate tokens against vocabulary
validate_tokens(req.tokens, model_metadata.manifest.shape.vocab)?;
// Execute with retries
let (logits, response) = self.with_retries(|| {
let mut stream = self.connect()?;
self.send_request(stream.as_mut(), &req)
})?;
let latency_ns = start.elapsed().as_nanos() as u32;
// Parse response
let gate_decision = response.to_gate_decision();
// Build top-K if we got pairs
let (logits_q, topk) = if self.config.topk_only {
// logits contains (token_id, logit) pairs
let pairs: Vec<(u16, i16)> = logits
.chunks(2)
.filter_map(|chunk| {
if chunk.len() == 2 {
Some((chunk[0] as u16, chunk[1]))
} else {
None
}
})
.collect();
(vec![], Some(pairs))
} else {
// Full logits - use common compute_topk
let topk = compute_topk(&logits, 16);
(logits, Some(topk))
};
// Copy packed fields to avoid alignment issues
let resp_cycles = { response.cycles };
let resp_latency_ns = { response.latency_ns };
// Create witness
let witness = WitnessLog::new(
model_metadata.model_hash(),
model_metadata.quant_hash(),
BackendKind::FpgaDaemon,
resp_cycles,
latency_ns.min(resp_latency_ns.max(latency_ns)),
gate_decision,
);
// Update stats (with poison handling)
write_lock(&self.stats, |stats| {
stats.total_inferences += 1;
stats.total_cycles += resp_cycles as u64;
let n = stats.total_inferences;
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
match gate_decision {
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
GateDecision::Skipped { .. } => stats.skipped += 1,
_ => {}
}
})?;
Ok(InferenceResult::new(logits_q, topk, witness))
}
fn unload(&self, model: ModelId) -> Result<()> {
// Send unload command to daemon
self.with_retries(|| {
let mut stream = self.connect()?;
self.send_unload_command(stream.as_mut(), model)
})?;
// Remove from local cache
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
if removed {
write_lock(&self.stats, |s| {
s.models_loaded = s.models_loaded.saturating_sub(1);
})?;
Ok(())
} else {
Err(Error::ModelNotFound(model))
}
}
fn is_loaded(&self, model: ModelId) -> bool {
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
}
    /// Identify this backend variant (daemon-mediated FPGA access).
    fn kind(&self) -> BackendKind {
        BackendKind::FpgaDaemon
    }
fn stats(&self) -> BackendStats {
self.stats.read().unwrap().clone()
}
}
/// Trait combining Read and Write for stream abstraction
///
/// Lets the daemon code treat its transport uniformly (presumably a boxed
/// trait object over either a Unix socket or a TCP stream — the `connect`
/// site is outside this chunk; confirm there).
trait ReadWrite: Read + Write + Send {}
/// Blanket impl: anything readable, writable, and sendable qualifies.
impl<T: Read + Write + Send> ReadWrite for T {}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_daemon_connection_types() {
        // Each constructor must yield its matching connection variant.
        assert!(matches!(
            DaemonConnection::unix("/tmp/test.sock"),
            DaemonConnection::Unix(_)
        ));
        assert!(matches!(
            DaemonConnection::tcp("127.0.0.1:8080"),
            DaemonConnection::Tcp(_)
        ));
    }
    #[test]
    fn test_config_defaults() {
        // Documented defaults: 5s connect timeout, 3 retries, top-K-only on.
        let cfg = DaemonConfig::default();
        assert_eq!(cfg.connect_timeout_ms, 5000);
        assert_eq!(cfg.retries, 3);
        assert!(cfg.topk_only);
    }
}

View File

@@ -0,0 +1,645 @@
//! FPGA PCIe backend
//!
//! Direct memory-mapped access to FPGA accelerator via PCIe.
//! Uses DMA ring buffers for zero-copy, lock-free operation.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;
#[cfg(feature = "pcie")]
use memmap2::{MmapMut, MmapOptions};
use crate::artifact::ModelArtifact;
use crate::backend::{
compute_topk, protocol, read_lock, validate_tokens, write_lock, BackendStats,
TransformerBackend,
};
use crate::error::{Error, Result};
use crate::types::{
BackendKind, GateDecision, InferenceRequest, InferenceResult, ModelId, WitnessLog,
};
/// PCIe device configuration
///
/// Controls device location, BAR layout, ring-buffer geometry, and DMA
/// behavior. All sizes/offsets are in bytes.
#[derive(Debug, Clone)]
pub struct PcieConfig {
    /// Device path (e.g., /dev/ruvector0)
    pub device_path: String,
    /// BAR0 offset for control registers
    pub bar0_offset: usize,
    /// BAR1 offset for DMA buffers (request ring is mapped here; the
    /// response ring is mapped immediately after it)
    pub bar1_offset: usize,
    /// Number of request slots in ring buffer
    pub ring_slots: usize,
    /// Size of each request slot in bytes
    pub slot_size: usize,
    /// DMA timeout in milliseconds (spin-poll bound per transfer)
    pub dma_timeout_ms: u64,
    /// Enable batch mode (multiple requests per DMA burst)
    pub batch_mode: bool,
    /// Maximum requests per batch (only meaningful when `batch_mode` is set)
    pub batch_size: usize,
}
impl Default for PcieConfig {
fn default() -> Self {
Self {
device_path: "/dev/ruvector0".into(),
bar0_offset: 0,
bar1_offset: 0x10000,
ring_slots: 16,
slot_size: 64 * 1024, // 64KB per slot
dma_timeout_ms: 100,
batch_mode: false,
batch_size: 4,
}
}
}
/// Ring buffer slot state
///
/// Stored per-slot in an `AtomicU32`; the discriminant values here are the
/// exact integers compared/stored by `DmaRingBuffer`.
/// Lifecycle: Free -> Pending (claimed by host) -> Complete (device done) -> Free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum SlotState {
    /// Slot is available for a new request
    Free = 0,
    /// Host has claimed the slot; request in flight
    Pending = 1,
    /// Response is ready to be read
    Complete = 2,
    /// Reserved for device-side failures (never stored in this file)
    Error = 3,
}
/// DMA ring buffer for lock-free request/response handling
///
/// One request slot and one response slot per index; slot ownership is
/// tracked by per-slot atomic states, so the hot path needs no mutex.
/// The mmap'd regions exist only in `pcie` builds.
struct DmaRingBuffer {
    /// Memory-mapped request buffer
    #[cfg(feature = "pcie")]
    request_mmap: MmapMut,
    /// Memory-mapped response buffer
    #[cfg(feature = "pcie")]
    response_mmap: MmapMut,
    /// Per-slot states (values are `SlotState` discriminants)
    slot_states: Vec<AtomicU32>,
    /// Producer index (next slot to write); wraps modulo `num_slots`
    producer_idx: AtomicU32,
    /// Consumer index (next slot to read); bumped on release
    consumer_idx: AtomicU32,
    /// Number of slots
    num_slots: usize,
    /// Size per slot
    slot_size: usize,
}
impl DmaRingBuffer {
    /// Create a new DMA ring buffer (mock for non-PCIe builds)
    ///
    /// Always fails: without the `pcie` feature there is no device to map.
    #[cfg(not(feature = "pcie"))]
    fn new(_config: &PcieConfig) -> Result<Self> {
        Err(Error::FeatureNotAvailable(
            "PCIe support not compiled".into(),
        ))
    }
    /// Create a new DMA ring buffer
    ///
    /// Maps two contiguous BAR1 windows of `ring_slots * slot_size` bytes
    /// (requests first, responses directly after) and marks every slot Free.
    #[cfg(feature = "pcie")]
    fn new(config: &PcieConfig) -> Result<Self> {
        use std::fs::OpenOptions;
        // Open device
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&config.device_path)
            .map_err(|e| Error::PcieError(format!("Failed to open device: {}", e)))?;
        let total_size = config.ring_slots * config.slot_size;
        // Map request buffer (BAR1)
        // SAFETY: mapping a device file; soundness relies on the driver
        // exposing a valid window of at least `total_size` at this offset —
        // TODO confirm against the kernel driver.
        let request_mmap = unsafe {
            MmapOptions::new()
                .offset(config.bar1_offset as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map request buffer: {}", e)))?
        };
        // Map response buffer (BAR1 + offset)
        // SAFETY: same caveat as the request mapping above.
        let response_mmap = unsafe {
            MmapOptions::new()
                .offset((config.bar1_offset + total_size) as u64)
                .len(total_size)
                .map_mut(&file)
                .map_err(|e| Error::PcieError(format!("Failed to map response buffer: {}", e)))?
        };
        // Initialize slot states (all Free)
        let slot_states: Vec<AtomicU32> = (0..config.ring_slots)
            .map(|_| AtomicU32::new(SlotState::Free as u32))
            .collect();
        Ok(Self {
            request_mmap,
            response_mmap,
            slot_states,
            producer_idx: AtomicU32::new(0),
            consumer_idx: AtomicU32::new(0),
            num_slots: config.ring_slots,
            slot_size: config.slot_size,
        })
    }
    /// Acquire a slot for writing
    ///
    /// Returns `None` when the slot at the producer cursor is not Free or
    /// another thread won the CAS race; callers retry (see the spin loops
    /// in the backend).
    fn acquire_slot(&self) -> Option<usize> {
        let producer = self.producer_idx.load(Ordering::Acquire);
        let slot = producer as usize % self.num_slots;
        // Check if slot is free
        if self.slot_states[slot].load(Ordering::Acquire) == SlotState::Free as u32 {
            // Try to claim it; the CAS is the ownership handoff, so losing
            // the race simply yields None
            if self.slot_states[slot]
                .compare_exchange(
                    SlotState::Free as u32,
                    SlotState::Pending as u32,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                self.producer_idx
                    .store(producer.wrapping_add(1), Ordering::Release);
                return Some(slot);
            }
        }
        None
    }
    /// Release a slot after reading response
    fn release_slot(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Free as u32, Ordering::Release);
        self.consumer_idx.fetch_add(1, Ordering::AcqRel);
    }
    /// Check if a slot is complete
    fn is_complete(&self, slot: usize) -> bool {
        self.slot_states[slot].load(Ordering::Acquire) == SlotState::Complete as u32
    }
    /// Mark a slot as complete (called by FPGA via doorbell/interrupt;
    /// also used by the simulated DMA path in `upload_weights_dma`)
    fn mark_complete(&self, slot: usize) {
        self.slot_states[slot].store(SlotState::Complete as u32, Ordering::Release);
    }
    /// Get request buffer for a slot (host-writable window)
    #[cfg(feature = "pcie")]
    fn request_buffer(&mut self, slot: usize) -> &mut [u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &mut self.request_mmap[start..end]
    }
    /// Get response buffer for a slot (read-only view of device output)
    #[cfg(feature = "pcie")]
    fn response_buffer(&self, slot: usize) -> &[u8] {
        let start = slot * self.slot_size;
        let end = start + self.slot_size;
        &self.response_mmap[start..end]
    }
}
/// FPGA PCIe backend
///
/// Talks to the accelerator through memory-mapped DMA ring buffers and a
/// simple bump allocator over device DDR for model weights.
pub struct FpgaPcieBackend {
    /// Configuration
    config: PcieConfig,
    /// DMA ring buffer (`None` when built without the `pcie` feature)
    ring: Option<DmaRingBuffer>,
    /// Loaded models
    models: RwLock<HashMap<ModelId, ModelMetadata>>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Total cycles counter (mirrored into `stats.total_cycles`)
    total_cycles: AtomicU64,
    /// FPGA memory allocator state (next free offset in device DDR)
    fpga_mem_offset: AtomicU64,
    /// FPGA memory total size (2GB default)
    fpga_mem_size: u64,
}
/// Cached model metadata
struct ModelMetadata {
    /// Validated artifact (weights + manifest); cloned out for inference
    artifact: ModelArtifact,
    /// Slot in FPGA memory where model is loaded
    /// NOTE(review): assigned at load time but never read back in this file —
    /// confirm whether the hardware actually consumes it.
    fpga_slot: u32,
    /// Offset in FPGA DDR where weights are stored
    weights_offset: u64,
    /// Size of weights in bytes (used for deallocation bookkeeping)
    weights_size: usize,
}
/// FPGA DDR base offset for model weights — the bump allocator starts here,
/// leaving the first 256MB of DDR untouched (presumably reserved for device
/// use; confirm against the hardware memory map).
const FPGA_DDR_MODEL_BASE: u64 = 0x1000_0000; // 256MB offset
impl FpgaPcieBackend {
/// Create a new PCIe backend
pub fn new(config: PcieConfig) -> Result<Self> {
#[cfg(feature = "pcie")]
let ring = Some(DmaRingBuffer::new(&config)?);
#[cfg(not(feature = "pcie"))]
let ring = None;
Ok(Self {
config,
ring,
models: RwLock::new(HashMap::new()),
stats: RwLock::new(BackendStats::default()),
total_cycles: AtomicU64::new(0),
fpga_mem_offset: AtomicU64::new(FPGA_DDR_MODEL_BASE),
fpga_mem_size: 2 * 1024 * 1024 * 1024, // 2GB
})
}
/// Create with default configuration
pub fn default_device() -> Result<Self> {
Self::new(PcieConfig::default())
}
/// Write inference request to DMA buffer
#[cfg(feature = "pcie")]
fn write_request(
&self,
ring: &mut DmaRingBuffer,
slot: usize,
req: &InferenceRequest,
) -> Result<()> {
use crate::backend::{protocol, RequestFrame};
let buffer = ring.request_buffer(slot);
let shape = &req.shape;
// Write header
let frame = RequestFrame::new(shape.seq_len, shape.d_model, shape.vocab, &req.model, 0, 16);
let header = frame.to_bytes();
buffer[..protocol::HEADER_SIZE].copy_from_slice(&header);
let mut offset = protocol::HEADER_SIZE;
// Write tokens
for &token in req.tokens {
buffer[offset..offset + 2].copy_from_slice(&token.to_le_bytes());
offset += 2;
}
// Write mask
buffer[offset..offset + req.attn_mask.len()].copy_from_slice(req.attn_mask);
offset += req.attn_mask.len();
// Write gate hint
buffer[offset..offset + 2].copy_from_slice(&req.gate_hint.coherence_score_q.to_le_bytes());
offset += 2;
buffer[offset] = req.gate_hint.boundary_crossed as u8;
offset += 1;
buffer[offset] = req.gate_hint.max_compute_class as u8;
Ok(())
}
/// Read inference response from DMA buffer
#[cfg(feature = "pcie")]
fn read_response(
&self,
ring: &DmaRingBuffer,
slot: usize,
shape: &crate::types::FixedShape,
) -> Result<(Vec<i16>, u32, u32, GateDecision)> {
use crate::backend::ResponseFrame;
let buffer = ring.response_buffer(slot);
// Read response header
let response = ResponseFrame::from_bytes(&buffer[..14].try_into().unwrap());
// Check status
if response.status != 0 {
return Err(Error::backend(format!(
"FPGA error: status {}",
response.status
)));
}
// Read logits
let vocab = shape.vocab as usize;
let mut logits = Vec::with_capacity(vocab);
let mut offset = 14;
for _ in 0..vocab {
let value = i16::from_le_bytes([buffer[offset], buffer[offset + 1]]);
logits.push(value);
offset += 2;
}
Ok((
logits,
response.cycles,
response.latency_ns,
response.to_gate_decision(),
))
}
/// Ring doorbell to notify FPGA of pending request
#[cfg(feature = "pcie")]
fn ring_doorbell(&self, _slot: usize) {
// In a real implementation, this would write to a control register
// to notify the FPGA that a new request is available
}
/// Wait for response with polling
fn wait_for_response(&self, ring: &DmaRingBuffer, slot: usize, timeout_ms: u64) -> Result<()> {
let start = Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
while !ring.is_complete(slot) {
if start.elapsed() > timeout {
return Err(Error::Timeout { ms: timeout_ms });
}
std::hint::spin_loop();
}
Ok(())
}
/// Allocate FPGA DDR memory for model weights
fn allocate_fpga_memory(&self, size: usize) -> Result<u64> {
// Align to 4KB boundary for DMA efficiency
let aligned_size = (size + 0xFFF) & !0xFFF;
// Atomic allocation (simple bump allocator)
let offset = self
.fpga_mem_offset
.fetch_add(aligned_size as u64, Ordering::SeqCst);
// Check for overflow
if offset + aligned_size as u64 > self.fpga_mem_size {
// Roll back allocation
self.fpga_mem_offset
.fetch_sub(aligned_size as u64, Ordering::SeqCst);
return Err(Error::ResourceExhausted("FPGA DDR memory full".into()));
}
Ok(offset)
}
/// Upload model weights to FPGA DDR via DMA
#[cfg(feature = "pcie")]
fn upload_weights_dma(&self, weights: &[u8], fpga_offset: u64) -> Result<()> {
// DMA transfer configuration
const DMA_CHUNK_SIZE: usize = 64 * 1024; // 64KB per transfer
let ring = self
.ring
.as_ref()
.ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
// Transfer weights in chunks
let mut transferred = 0usize;
while transferred < weights.len() {
let chunk_size = DMA_CHUNK_SIZE.min(weights.len() - transferred);
// Acquire a DMA slot
let slot = loop {
if let Some(s) = ring.acquire_slot() {
break s;
}
std::hint::spin_loop();
};
// Write DMA command to slot (simplified protocol)
// In real hardware:
// - Write target FPGA DDR address
// - Write source offset in slot
// - Write transfer length
// - Ring doorbell
// For now, we simulate the DMA by marking complete
ring.mark_complete(slot);
// Wait for completion
self.wait_for_response(ring, slot, self.config.dma_timeout_ms)?;
// Release slot
ring.release_slot(slot);
transferred += chunk_size;
}
Ok(())
}
/// Free FPGA DDR memory (simplified - real impl would use proper allocator)
fn free_fpga_memory(&self, _offset: u64, _size: usize) {
// In a production system, this would:
// 1. Mark the memory region as free in an allocator
// 2. Potentially compact memory if fragmentation is high
// 3. Update hardware memory management unit
//
// For this implementation, we use a bump allocator without free.
// Memory is reclaimed when all models are unloaded.
}
/// Check if all models are unloaded and reset memory allocator
fn maybe_reset_allocator(&self) {
let models_empty = read_lock(&self.models, |m| m.is_empty()).unwrap_or(false);
if models_empty {
self.fpga_mem_offset
.store(FPGA_DDR_MODEL_BASE, Ordering::SeqCst);
}
}
}
impl TransformerBackend for FpgaPcieBackend {
    /// Validate the artifact, upload its weights to device DDR, and register
    /// it locally. Fails with `FeatureNotAvailable` in non-`pcie` builds.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = artifact;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }
        #[cfg(feature = "pcie")]
        {
            // Validate artifact
            artifact.validate()?;
            let model_id = artifact.model_id();
            let weights_size = artifact.weights.len();
            // Allocate FPGA DDR memory for weights
            let weights_offset = self.allocate_fpga_memory(weights_size)?;
            // Upload model weights to FPGA DDR via DMA
            if let Err(e) = self.upload_weights_dma(&artifact.weights, weights_offset) {
                // Roll back allocation on failure (no-op with the bump allocator)
                self.free_fpga_memory(weights_offset, weights_size);
                return Err(e);
            }
            // Get slot number for this model
            // NOTE(review): len()-based slot assignment can collide after an
            // unload or under concurrent loads — confirm the device keys on
            // model_id rather than on this slot.
            let fpga_slot = read_lock(&self.models, |m| m.len() as u32)?;
            // Store metadata
            write_lock(&self.models, |models| {
                models.insert(
                    model_id,
                    ModelMetadata {
                        artifact: artifact.clone(),
                        fpga_slot,
                        weights_offset,
                        weights_size,
                    },
                );
            })?;
            write_lock(&self.stats, |stats| {
                stats.models_loaded += 1;
            })?;
            Ok(model_id)
        }
    }
    /// Execute one inference over the DMA ring.
    ///
    /// Acquires a slot, waits for the device, reads the response, and always
    /// releases the slot — including on timeout or response errors.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        #[cfg(not(feature = "pcie"))]
        {
            let _ = req;
            return Err(Error::FeatureNotAvailable(
                "PCIe support not compiled".into(),
            ));
        }
        #[cfg(feature = "pcie")]
        {
            let start = Instant::now();
            // Validate request
            req.validate()?;
            // Get model metadata (clone out so the lock is not held long)
            let model_artifact = read_lock(&self.models, |models| {
                models.get(&req.model).map(|m| m.artifact.clone())
            })?
            .ok_or_else(|| Error::ModelNotFound(req.model))?;
            // Validate tokens against vocabulary
            validate_tokens(req.tokens, model_artifact.manifest.shape.vocab)?;
            // Get ring buffer
            let ring = self
                .ring
                .as_ref()
                .ok_or_else(|| Error::FeatureNotAvailable("Ring buffer not initialized".into()))?;
            // Acquire slot
            let slot = ring
                .acquire_slot()
                .ok_or_else(|| Error::ResourceExhausted("No DMA slots available".into()))?;
            // Write request (need mutable access - simplified for now)
            // In production, this would use proper interior mutability
            // self.write_request(ring, slot, &req)?;
            // Ring doorbell
            // self.ring_doorbell(slot);
            // Wait for the device, then read the response. The slot MUST be
            // released on every path: propagating either error with `?`
            // before the release (as the code previously did) leaked the slot
            // on timeout/parse failure, eventually starving the ring with
            // ResourceExhausted.
            let outcome = self
                .wait_for_response(ring, slot, self.config.dma_timeout_ms)
                .and_then(|_| self.read_response(ring, slot, &req.shape));
            ring.release_slot(slot);
            let (logits, cycles, fpga_latency_ns, gate_decision) = outcome?;
            let latency_ns = start.elapsed().as_nanos() as u32;
            // Compute top-K using common utility
            let topk = compute_topk(&logits, 16);
            // Create witness (records the tighter latency measurement)
            let witness = WitnessLog::new(
                model_artifact.model_hash(),
                model_artifact.quant_hash(),
                BackendKind::FpgaPcie,
                cycles,
                fpga_latency_ns.min(latency_ns),
                gate_decision,
            );
            // Update stats
            self.total_cycles
                .fetch_add(cycles as u64, Ordering::Relaxed);
            write_lock(&self.stats, |stats| {
                stats.total_inferences += 1;
                stats.total_cycles = self.total_cycles.load(Ordering::Relaxed);
                let n = stats.total_inferences;
                // Incremental running mean; n >= 1 here, so no division by zero
                stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
                match gate_decision {
                    GateDecision::EarlyExit { .. } => stats.early_exits += 1,
                    GateDecision::Skipped { .. } => stats.skipped += 1,
                    _ => {}
                }
            })?;
            Ok(InferenceResult::new(logits, Some(topk), witness))
        }
    }
    /// Unregister a model and release its (bump-allocated) DDR region.
    fn unload(&self, model: ModelId) -> Result<()> {
        // Remove from cache and get memory info for deallocation
        let removed = write_lock(&self.models, |models| {
            models
                .remove(&model)
                .map(|m| (m.weights_offset, m.weights_size))
        })?;
        if let Some((offset, size)) = removed {
            // Free FPGA DDR memory
            self.free_fpga_memory(offset, size);
            // Check if we can reset the allocator
            self.maybe_reset_allocator();
            write_lock(&self.stats, |stats| {
                stats.models_loaded = stats.models_loaded.saturating_sub(1);
            })?;
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }
    /// True when `model` is registered; a poisoned lock reads as "not loaded".
    fn is_loaded(&self, model: ModelId) -> bool {
        read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
    }
    /// Identify this backend variant (direct PCIe access).
    fn kind(&self) -> BackendKind {
        BackendKind::FpgaPcie
    }
    /// Snapshot of backend statistics; defaults on a poisoned lock.
    fn stats(&self) -> BackendStats {
        read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_pcie_config_default() {
        // Default ring geometry: 16 slots of 64KB.
        let cfg = PcieConfig::default();
        assert_eq!(cfg.ring_slots, 16);
        assert_eq!(cfg.slot_size, 65536);
    }
    #[test]
    fn test_slot_state_values() {
        // Discriminants are part of the host/device contract.
        for (state, raw) in [
            (SlotState::Free, 0u8),
            (SlotState::Pending, 1),
            (SlotState::Complete, 2),
        ] {
            assert_eq!(state as u8, raw);
        }
    }
}

View File

@@ -0,0 +1,428 @@
//! Backend implementations for FPGA Transformer
//!
//! All backends implement the `TransformerBackend` trait for uniform API.
use crate::artifact::ModelArtifact;
use crate::error::Result;
use crate::types::{InferenceRequest, InferenceResult, ModelId};
#[cfg(feature = "native_sim")]
pub mod native_sim;
#[cfg(feature = "daemon")]
pub mod fpga_daemon;
#[cfg(feature = "pcie")]
pub mod fpga_pcie;
#[cfg(feature = "wasm")]
pub mod wasm_sim;
/// Trait for transformer inference backends
///
/// All backends must be thread-safe (`Send + Sync`) and implement the same
/// API, so callers can swap daemon, PCIe, simulator, or WASM backends.
pub trait TransformerBackend: Send + Sync {
    /// Load a model artifact and return its ID
    ///
    /// The artifact is validated, test vectors are run, and
    /// the model is prepared for inference.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId>;
    /// Run inference on the given request
    ///
    /// The request must specify a model that has been loaded.
    /// Returns the inference result with witness log.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult>;
    /// Unload a model to free resources
    ///
    /// Implementations in this crate return `ModelNotFound` when the model
    /// was never loaded on this backend.
    fn unload(&self, model: ModelId) -> Result<()>;
    /// Check if a model is loaded
    fn is_loaded(&self, model: ModelId) -> bool;
    /// Get the backend kind
    fn kind(&self) -> crate::types::BackendKind;
    /// Get backend-specific statistics
    ///
    /// Default implementation returns zeroed stats for backends that track none.
    fn stats(&self) -> BackendStats {
        BackendStats::default()
    }
}
/// Backend statistics
///
/// Best-effort counters updated by the individual backends as they serve
/// requests.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Number of models currently loaded
    pub models_loaded: usize,
    /// Total inferences performed
    pub total_inferences: u64,
    /// Total cycles consumed (FPGA only)
    pub total_cycles: u64,
    /// Average latency in nanoseconds (incremental running mean)
    pub avg_latency_ns: u64,
    /// P99 latency in nanoseconds
    /// NOTE(review): never written by the backends in this module — confirm
    /// whether it is populated elsewhere or still TODO.
    pub p99_latency_ns: u64,
    /// Number of early exits
    pub early_exits: u64,
    /// Number of skipped inferences
    pub skipped: u64,
}
/// Protocol constants for daemon/PCIe communication
pub mod protocol {
    /// Magic number for frame validation (bytes spell "RVXF")
    pub const MAGIC: u32 = 0x5256_5846; // "RVXF" - RuVector FPGA
    /// Current protocol version
    pub const VERSION: u16 = 1;
    /// Frame header size in bytes (must match `RequestFrame::to_bytes`)
    pub const HEADER_SIZE: usize = 24;
    /// Maximum payload size
    pub const MAX_PAYLOAD: usize = 1024 * 1024; // 1MB
    /// Request flags (bitmask; OR together as needed)
    pub mod flags {
        /// Return only top-K predictions
        pub const TOPK_ONLY: u16 = 0x0001;
        /// Use LUT-based softmax
        pub const LUT_SOFTMAX: u16 = 0x0002;
        /// Enable early exit
        pub const EARLY_EXIT: u16 = 0x0004;
        /// Return detailed witness
        pub const WITNESS_DETAIL: u16 = 0x0008;
    }
    /// Response status codes (non-zero means failure)
    pub mod status {
        /// Success
        pub const OK: u16 = 0;
        /// Model not found
        pub const MODEL_NOT_FOUND: u16 = 1;
        /// Shape mismatch
        pub const SHAPE_MISMATCH: u16 = 2;
        /// Gate blocked
        pub const GATE_BLOCKED: u16 = 3;
        /// Internal error
        pub const INTERNAL_ERROR: u16 = 0xFFFF;
    }
}
/// Request frame for wire protocol
///
/// Fixed 24-byte `repr(C, packed)` layout (`protocol::HEADER_SIZE`),
/// serialized little-endian by `to_bytes`. Packed fields may be unaligned:
/// copy them out by value rather than taking references.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct RequestFrame {
    /// Magic number (MAGIC)
    pub magic: u32,
    /// Protocol version
    pub protocol: u16,
    /// Sequence length
    pub seq_len: u16,
    /// Model dimension
    pub d_model: u16,
    /// Vocabulary size (truncated to 16 bits on the wire)
    pub vocab: u16,
    /// Model ID (lower 32 bits)
    pub model_id_low: u32,
    /// Model ID (upper 32 bits)
    pub model_id_high: u32,
    /// Request flags (see `protocol::flags`)
    pub flags: u16,
    /// Top-K count (if TOPK_ONLY flag set)
    pub topk: u16,
}
impl RequestFrame {
/// Create a new request frame
pub fn new(
seq_len: u16,
d_model: u16,
vocab: u32,
model_id: &ModelId,
flags: u16,
topk: u16,
) -> Self {
let id_bytes = model_id.as_bytes();
let model_id_low = u32::from_le_bytes([id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3]]);
let model_id_high =
u32::from_le_bytes([id_bytes[4], id_bytes[5], id_bytes[6], id_bytes[7]]);
Self {
magic: protocol::MAGIC,
protocol: protocol::VERSION,
seq_len,
d_model,
vocab: (vocab & 0xFFFF) as u16,
model_id_low,
model_id_high,
flags,
topk,
}
}
/// Serialize to bytes
pub fn to_bytes(&self) -> [u8; protocol::HEADER_SIZE] {
let mut bytes = [0u8; protocol::HEADER_SIZE];
bytes[0..4].copy_from_slice(&self.magic.to_le_bytes());
bytes[4..6].copy_from_slice(&self.protocol.to_le_bytes());
bytes[6..8].copy_from_slice(&self.seq_len.to_le_bytes());
bytes[8..10].copy_from_slice(&self.d_model.to_le_bytes());
bytes[10..12].copy_from_slice(&self.vocab.to_le_bytes());
bytes[12..16].copy_from_slice(&self.model_id_low.to_le_bytes());
bytes[16..20].copy_from_slice(&self.model_id_high.to_le_bytes());
bytes[20..22].copy_from_slice(&self.flags.to_le_bytes());
bytes[22..24].copy_from_slice(&self.topk.to_le_bytes());
bytes
}
}
/// Response frame from wire protocol
///
/// Fixed 14-byte `repr(C, packed)` layout parsed by `from_bytes`. Packed
/// fields may be unaligned: copy them out by value, do not reference them.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ResponseFrame {
    /// Status code (see `protocol::status`; 0 = OK)
    pub status: u16,
    /// Latency in nanoseconds (as measured by the device/daemon)
    pub latency_ns: u32,
    /// Compute cycles
    pub cycles: u32,
    /// Gate decision (packed: 0 full, 1 early-exit, 2 skipped)
    pub gate_decision: u8,
    /// Exit layer (if early exit)
    pub exit_layer: u8,
    /// Skip reason (if skipped)
    pub skip_reason: u8,
    /// Reserved (padding to 14 bytes)
    pub reserved: u8,
}
impl ResponseFrame {
/// Parse from bytes
pub fn from_bytes(bytes: &[u8; 14]) -> Self {
Self {
status: u16::from_le_bytes([bytes[0], bytes[1]]),
latency_ns: u32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
cycles: u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
gate_decision: bytes[10],
exit_layer: bytes[11],
skip_reason: bytes[12],
reserved: bytes[13],
}
}
/// Convert gate decision to enum
pub fn to_gate_decision(&self) -> crate::types::GateDecision {
match self.gate_decision {
0 => crate::types::GateDecision::RanFull,
1 => crate::types::GateDecision::EarlyExit {
layer: self.exit_layer,
},
2 => crate::types::GateDecision::Skipped {
reason: match self.skip_reason {
0 => crate::types::SkipReason::LowCoherence,
1 => crate::types::SkipReason::PolicyDenied,
_ => crate::types::SkipReason::BudgetExceeded,
},
},
_ => crate::types::GateDecision::RanFull,
}
}
}
/// Calculate CRC32 checksum for frame validation
///
/// Bitwise (table-free) CRC-32/ISO-HDLC: reflected polynomial 0xEDB88320,
/// initial value 0xFFFFFFFF, final complement. Check value:
/// `crc32(b"123456789") == 0xCBF43926`. A production build could swap in the
/// `crc32fast` crate without changing results.
pub fn crc32(data: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB8_8320;
    let crc = data.iter().fold(0xFFFF_FFFFu32, |acc, &byte| {
        // Fold one byte into the register, then run 8 LSB-first shift steps.
        (0..8).fold(acc ^ u32::from(byte), |c, _| {
            if c & 1 == 1 {
                (c >> 1) ^ POLY
            } else {
                c >> 1
            }
        })
    });
    !crc
}
// ============================================================================
// Common utilities shared across backends
// ============================================================================
/// Compute top-K predictions from logits
/// Returns sorted (token_id, logit) pairs, descending by logit value.
///
/// `k == 0` or empty `logits` yields an empty vector. For small K over large
/// inputs a bounded min-heap selection (O(n log k)) is used; otherwise a full
/// sort. Ties may be ordered differently between the two paths.
#[inline]
pub fn compute_topk(logits: &[i16], k: usize) -> Vec<(u16, i16)> {
    // Guard k == 0 explicitly: the heap path below reads heap[0], which
    // previously panicked (index out of bounds) on an empty heap whenever
    // k == 0 and logits.len() > 100.
    if logits.is_empty() || k == 0 {
        return vec![];
    }
    // For small K over large inputs, partial selection beats a full sort
    if k <= 32 && logits.len() > 100 {
        // Bounded min-heap: heap[0] is always the smallest retained logit
        let mut heap: Vec<(i16, u16)> = Vec::with_capacity(k + 1);
        for (i, &v) in logits.iter().enumerate() {
            if heap.len() < k {
                heap.push((v, i as u16));
                if heap.len() == k {
                    // A sorted (ascending) array is a valid min-heap
                    heap.sort_by(|a, b| a.0.cmp(&b.0));
                }
            } else if v > heap[0].0 {
                // Replace the current minimum and sift down to restore
                // the min-heap property
                heap[0] = (v, i as u16);
                let mut idx = 0;
                while idx * 2 + 1 < heap.len() {
                    let left = idx * 2 + 1;
                    let right = idx * 2 + 2;
                    let mut smallest = idx;
                    if heap[left].0 < heap[smallest].0 {
                        smallest = left;
                    }
                    if right < heap.len() && heap[right].0 < heap[smallest].0 {
                        smallest = right;
                    }
                    if smallest == idx {
                        break;
                    }
                    heap.swap(idx, smallest);
                    idx = smallest;
                }
            }
        }
        // Order survivors descending and flip (logit, idx) -> (idx, logit)
        heap.sort_by(|a, b| b.0.cmp(&a.0));
        heap.into_iter().map(|(v, i)| (i, v)).collect()
    } else {
        // Full sort for small arrays (or large K)
        let mut indexed: Vec<(usize, i16)> = logits.iter().cloned().enumerate().collect();
        indexed.sort_by(|a, b| b.1.cmp(&a.1));
        indexed
            .into_iter()
            .take(k)
            .map(|(i, v)| (i as u16, v))
            .collect()
    }
}
/// Helper to safely read from RwLock, returning error on poison
pub fn read_lock<T, R>(
lock: &std::sync::RwLock<T>,
f: impl FnOnce(&T) -> R,
) -> crate::error::Result<R> {
lock.read()
.map(|guard| f(&*guard))
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (read)".into()))
}
/// Helper to safely write to RwLock, returning error on poison
pub fn write_lock<T, R>(
lock: &std::sync::RwLock<T>,
f: impl FnOnce(&mut T) -> R,
) -> crate::error::Result<R> {
lock.write()
.map(|mut guard| f(&mut *guard))
.map_err(|_| crate::error::Error::BackendError("Lock poisoned (write)".into()))
}
/// Validate token indices against vocabulary size
#[inline]
pub fn validate_tokens(tokens: &[u16], vocab_size: u32) -> crate::error::Result<()> {
for (i, &token) in tokens.iter().enumerate() {
if token as u32 >= vocab_size {
return Err(crate::error::Error::InvalidConfig(format!(
"Token {} at index {} exceeds vocabulary size {}",
token, i, vocab_size
)));
}
}
Ok(())
}
/// Build witness log from inference metadata
///
/// Thin convenience wrapper around `WitnessLog::new` so backends can
/// construct witnesses without importing the types module directly; all
/// arguments are forwarded unchanged.
pub fn build_witness(
    model_hash: [u8; 32],
    quant_hash: [u8; 32],
    backend: crate::types::BackendKind,
    cycles: u32,
    latency_ns: u32,
    gate_decision: crate::types::GateDecision,
) -> crate::types::WitnessLog {
    crate::types::WitnessLog::new(
        model_hash,
        quant_hash,
        backend,
        cycles,
        latency_ns,
        gate_decision,
    )
}
/// Command types for daemon protocol
///
/// Single-byte opcodes sent ahead of each daemon request.
pub mod commands {
    /// Load model command
    pub const LOAD_MODEL: u8 = 0x01;
    /// Unload model command
    pub const UNLOAD_MODEL: u8 = 0x02;
    /// Inference request command
    pub const INFER: u8 = 0x03;
    /// Ping/health check command
    pub const PING: u8 = 0x04;
    /// Get status command
    pub const STATUS: u8 = 0x05;
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_request_frame_roundtrip() {
        // Serialized header must be exactly HEADER_SIZE and lead with MAGIC.
        let id = ModelId::new([0x42u8; 32]);
        let bytes = RequestFrame::new(64, 256, 32000, &id, 0, 16).to_bytes();
        assert_eq!(bytes.len(), protocol::HEADER_SIZE);
        assert_eq!(&bytes[0..4], &protocol::MAGIC.to_le_bytes());
    }
    #[test]
    fn test_crc32() {
        // CRC must be deterministic for identical input.
        let payload = b"test data";
        assert_eq!(crc32(payload), crc32(payload));
    }
    #[test]
    fn test_compute_topk() {
        // Small-array path: full sort, descending by logit.
        let topk = compute_topk(&[100i16, 50, 300, 200, 150], 3);
        assert_eq!(topk, vec![(2, 300), (3, 200), (4, 150)]);
    }
    #[test]
    fn test_compute_topk_large() {
        // Large-array path: heap selection must still yield descending order.
        let logits: Vec<i16> = (0..1000).map(|i| (i * 7 % 500) as i16).collect();
        let topk = compute_topk(&logits, 10);
        assert_eq!(topk.len(), 10);
        assert!(topk.windows(2).all(|w| w[0].1 >= w[1].1));
    }
    #[test]
    fn test_validate_tokens() {
        assert!(validate_tokens(&[0, 1, 2], 100).is_ok());
        assert!(validate_tokens(&[99], 100).is_ok());
        assert!(validate_tokens(&[100], 100).is_err());
        assert!(validate_tokens(&[0, 50, 101], 100).is_err());
    }
}

View File

@@ -0,0 +1,544 @@
//! Native Rust simulator backend
//!
//! Provides a pure-Rust implementation of the transformer inference
//! for testing, development, and fallback when no FPGA is available.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use crate::artifact::ModelArtifact;
use crate::backend::{
compute_topk, read_lock, validate_tokens, write_lock, BackendStats, TransformerBackend,
};
use crate::error::{Error, Result};
use crate::gating::CoherenceGate;
use crate::quant::{dequantize_i8, quantize_i16, softmax_lut};
use crate::types::{
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
QuantSpec, SkipReason, WitnessLog,
};
/// Loaded model data for native simulation
///
/// Weights are held as f32 (dequantized for sim) so the simulator can use
/// plain float math.
struct LoadedModel {
    /// Model artifact (contains weights and config)
    artifact: ModelArtifact,
    /// Precomputed embedding matrix (dequantized for sim); indexed as
    /// `token * d_model` in `run_inference`, i.e. row-major vocab x d_model
    embeddings: Vec<f32>,
    /// Layer weights (simplified for simulation)
    layers: Vec<LayerWeights>,
    /// Output projection (hidden state -> vocab logits)
    output_proj: Vec<f32>,
}
/// Simplified layer weights for simulation
///
/// All matrices are flattened `Vec<f32>`. An empty `Vec` disables the
/// corresponding sub-block in `run_layer` (see the `is_empty` checks there).
struct LayerWeights {
    /// Attention Q projection
    wq: Vec<f32>,
    /// Attention K projection
    wk: Vec<f32>,
    /// Attention V projection
    wv: Vec<f32>,
    /// Attention output projection
    wo: Vec<f32>,
    /// FFN up projection
    w1: Vec<f32>,
    /// FFN down projection
    w2: Vec<f32>,
    /// Layer norm weights (pre-attention)
    ln1_weight: Vec<f32>,
    /// Layer norm weights (pre-FFN)
    ln2_weight: Vec<f32>,
}
/// Native simulator backend
///
/// Pure-Rust reference implementation, guarded by the same coherence gate
/// abstraction as the hardware backends.
pub struct NativeSimBackend {
    /// Loaded models keyed by model ID
    models: RwLock<HashMap<ModelId, Arc<LoadedModel>>>,
    /// Coherence gate consulted at preflight and between layers
    gate: Arc<dyn CoherenceGate>,
    /// Statistics
    stats: RwLock<BackendStats>,
    /// Configuration
    config: NativeSimConfig,
}
/// Configuration for native simulator
#[derive(Debug, Clone)]
pub struct NativeSimConfig {
    /// Maximum models to keep loaded
    pub max_models: usize,
    /// Enable detailed tracing
    pub trace: bool,
    /// Use LUT-based softmax
    pub lut_softmax: bool,
    /// Number of layers to simulate (0 = all layers in the model)
    pub max_layers: usize,
}
impl Default for NativeSimConfig {
fn default() -> Self {
Self {
max_models: 8,
trace: false,
lut_softmax: true,
max_layers: 0,
}
}
}
impl NativeSimBackend {
    /// Create a new native simulator backend
    ///
    /// Uses `NativeSimConfig::default()`; use `with_config` to customize
    /// limits, tracing, or softmax mode.
    pub fn new(gate: Arc<dyn CoherenceGate>) -> Self {
        Self::with_config(gate, NativeSimConfig::default())
    }
/// Create with custom configuration
pub fn with_config(gate: Arc<dyn CoherenceGate>, config: NativeSimConfig) -> Self {
Self {
models: RwLock::new(HashMap::new()),
gate,
stats: RwLock::new(BackendStats::default()),
config,
}
}
    /// Run the core transformer inference
    ///
    /// Flow: preflight gate -> embedding lookup -> per-layer loop with
    /// early-exit checkpoints -> output projection. Returns quantized logits
    /// plus the gate decision that produced them. When the preflight gate
    /// skips the request, a zeroed logits vector is returned.
    fn run_inference(
        &self,
        model: &LoadedModel,
        tokens: &[u16],
        _attn_mask: &[u8],
        gate_hint: &GateHint,
    ) -> Result<(Vec<i16>, GateDecision)> {
        let shape = &model.artifact.manifest.shape;
        let num_layers = model.layers.len();
        // Check preflight gate before doing any compute
        let preflight = self.gate.preflight(gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            return Ok((
                vec![0i16; shape.vocab as usize],
                GateDecision::Skipped { reason },
            ));
        }
        // Initialize hidden states from embeddings
        let d_model = shape.d_model as usize;
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];
        // Lookup embeddings. Out-of-range rows are silently left zeroed
        // (tokens are expected to have been validated against the vocab
        // before reaching this point).
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }
        // Run through layers; config.max_layers == 0 means "all layers"
        let max_layers = if self.config.max_layers > 0 {
            self.config.max_layers.min(num_layers)
        } else {
            num_layers
        };
        for layer_idx in 0..max_layers {
            let layer = &model.layers[layer_idx];
            // Check layer checkpoint for early exit
            let coherence_signal = self.compute_coherence_signal(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer_idx as u8, coherence_signal) {
                if let GateDecision::EarlyExit { layer } = decision {
                    // Early exit - compute output from current hidden state
                    let logits = self.compute_output(&hidden, &model.output_proj, shape);
                    return Ok((logits, GateDecision::EarlyExit { layer }));
                }
            }
            // Simplified attention + FFN (for simulation purposes)
            hidden = self.run_layer(&hidden, layer, shape);
        }
        // Compute output logits
        let logits = self.compute_output(&hidden, &model.output_proj, shape);
        Ok((logits, GateDecision::RanFull))
    }
/// Run a single transformer layer
fn run_layer(&self, hidden: &[f32], layer: &LayerWeights, shape: &FixedShape) -> Vec<f32> {
let d_model = shape.d_model as usize;
let seq_len = hidden.len() / d_model;
// Simplified layer computation
// In a real implementation, this would do full attention + FFN
let mut output = hidden.to_vec();
// Layer norm 1
for t in 0..seq_len {
let start = t * d_model;
let end = start + d_model;
layer_norm_inplace(&mut output[start..end], &layer.ln1_weight);
}
// Simplified attention (just apply output projection as placeholder)
// Real implementation would compute Q, K, V, attention scores, etc.
if !layer.wo.is_empty() {
let mut attn_out = vec![0.0f32; output.len()];
for t in 0..seq_len {
for i in 0..d_model {
let mut sum = 0.0f32;
for j in 0..d_model.min(layer.wo.len() / d_model) {
sum += output[t * d_model + j] * layer.wo[j * d_model + i];
}
attn_out[t * d_model + i] = sum;
}
}
// Residual connection
for i in 0..output.len() {
output[i] += attn_out[i];
}
}
// Layer norm 2
for t in 0..seq_len {
let start = t * d_model;
let end = start + d_model;
layer_norm_inplace(&mut output[start..end], &layer.ln2_weight);
}
// Simplified FFN (SwiGLU-like)
if !layer.w1.is_empty() && !layer.w2.is_empty() {
let ffn_dim = layer.w1.len() / d_model;
let mut ffn_out = vec![0.0f32; output.len()];
for t in 0..seq_len {
// Up projection
let mut up = vec![0.0f32; ffn_dim];
for i in 0..ffn_dim {
for j in 0..d_model {
up[i] += output[t * d_model + j] * layer.w1[j * ffn_dim + i];
}
// SiLU activation
up[i] = up[i] * sigmoid(up[i]);
}
// Down projection
for i in 0..d_model {
for j in 0..ffn_dim.min(layer.w2.len() / d_model) {
ffn_out[t * d_model + i] += up[j] * layer.w2[j * d_model + i];
}
}
}
// Residual connection
for i in 0..output.len() {
output[i] += ffn_out[i];
}
}
output
}
/// Compute output logits from hidden state
fn compute_output(&self, hidden: &[f32], output_proj: &[f32], shape: &FixedShape) -> Vec<i16> {
let d_model = shape.d_model as usize;
let vocab = shape.vocab as usize;
let seq_len = hidden.len() / d_model;
// Take last token's hidden state
let last_hidden = &hidden[(seq_len - 1) * d_model..];
// Compute logits
let mut logits = vec![0.0f32; vocab];
if output_proj.len() >= d_model * vocab {
for v in 0..vocab {
for d in 0..d_model {
logits[v] += last_hidden[d] * output_proj[d * vocab + v];
}
}
} else {
// Fallback: random logits for simulation when weights not available
for v in 0..vocab {
logits[v] = (v as f32 * 0.01).sin();
}
}
// Apply softmax (optional) and quantize
if self.config.lut_softmax {
softmax_lut(&mut logits);
} else {
softmax_f32(&mut logits);
}
// Quantize to i16
quantize_i16(&logits)
}
/// Compute coherence signal for early exit decision
fn compute_coherence_signal(&self, hidden: &[f32]) -> i16 {
// Simple coherence metric: variance of hidden states
let mean = hidden.iter().sum::<f32>() / hidden.len() as f32;
let variance = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden.len() as f32;
// Scale to Q8.8 fixed point
((variance * 256.0).clamp(-32768.0, 32767.0)) as i16
}
/// Prepare model from artifact (dequantize weights for simulation)
fn prepare_model(&self, artifact: &ModelArtifact) -> Result<LoadedModel> {
let shape = &artifact.manifest.shape;
let quant = &artifact.manifest.quant;
let d_model = shape.d_model as usize;
let vocab = shape.vocab as usize;
// Dequantize embeddings
let embedding_size = vocab * d_model;
let embeddings = if artifact.weights.len() >= embedding_size {
dequantize_i8(&artifact.weights[..embedding_size], quant)
} else {
// Generate random embeddings for testing
(0..embedding_size)
.map(|i| ((i as f32 * 0.001).sin() * 0.1))
.collect()
};
// Create simplified layer weights
let num_layers = 4; // Default for simulation
let layers: Vec<LayerWeights> = (0..num_layers)
.map(|_| LayerWeights {
wq: vec![0.01; d_model * d_model],
wk: vec![0.01; d_model * d_model],
wv: vec![0.01; d_model * d_model],
wo: vec![0.01; d_model * d_model],
w1: vec![0.01; d_model * 4 * d_model],
w2: vec![0.01; 4 * d_model * d_model],
ln1_weight: vec![1.0; d_model],
ln2_weight: vec![1.0; d_model],
})
.collect();
// Output projection
let output_proj = vec![0.01; d_model * vocab];
Ok(LoadedModel {
artifact: artifact.clone(),
embeddings,
layers,
output_proj,
})
}
}
impl TransformerBackend for NativeSimBackend {
fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
// Validate artifact
artifact.validate()?;
// Prepare model
let model = self.prepare_model(artifact)?;
let model_id = artifact.model_id();
// Check capacity (with poison handling)
let at_capacity = read_lock(&self.models, |models| {
models.len() >= self.config.max_models && !models.contains_key(&model_id)
})?;
if at_capacity {
return Err(Error::ResourceExhausted("Max models reached".into()));
}
// Store model
write_lock(&self.models, |models| {
models.insert(model_id, Arc::new(model));
})?;
// Update stats
write_lock(&self.stats, |stats| {
stats.models_loaded += 1;
})?;
Ok(model_id)
}
fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
let start = Instant::now();
// Validate request
req.validate()?;
// Get model (with poison handling)
let model = read_lock(&self.models, |models| models.get(&req.model).cloned())?
.ok_or_else(|| Error::ModelNotFound(req.model))?;
// Validate shape
if model.artifact.manifest.shape != req.shape {
return Err(Error::ShapeMismatch {
expected: model.artifact.manifest.shape,
actual: req.shape,
});
}
// Validate tokens against vocabulary
validate_tokens(req.tokens, model.artifact.manifest.shape.vocab)?;
// Run inference
let (logits_q, gate_decision) =
self.run_inference(&model, req.tokens, req.attn_mask, &req.gate_hint)?;
let latency_ns = start.elapsed().as_nanos() as u32;
// Compute top-K using common utility
let topk = compute_topk(&logits_q, 16);
// Create witness
let witness = WitnessLog::new(
model.artifact.model_hash(),
model.artifact.quant_hash(),
BackendKind::NativeSim,
0, // No cycles for simulator
latency_ns,
gate_decision,
);
// Update stats (with poison handling)
write_lock(&self.stats, |stats| {
stats.total_inferences += 1;
let n = stats.total_inferences;
stats.avg_latency_ns = (stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
match gate_decision {
GateDecision::EarlyExit { .. } => stats.early_exits += 1,
GateDecision::Skipped { .. } => stats.skipped += 1,
_ => {}
}
})?;
Ok(InferenceResult::new(logits_q, Some(topk), witness))
}
fn unload(&self, model: ModelId) -> Result<()> {
let removed = write_lock(&self.models, |models| models.remove(&model).is_some())?;
if removed {
write_lock(&self.stats, |stats| {
stats.models_loaded = stats.models_loaded.saturating_sub(1);
})?;
Ok(())
} else {
Err(Error::ModelNotFound(model))
}
}
fn is_loaded(&self, model: ModelId) -> bool {
read_lock(&self.models, |m| m.contains_key(&model)).unwrap_or(false)
}
fn kind(&self) -> BackendKind {
BackendKind::NativeSim
}
fn stats(&self) -> BackendStats {
read_lock(&self.stats, |s| s.clone()).unwrap_or_default()
}
}
// Helper functions
/// In-place layer normalization with epsilon 1e-5, scaled elementwise by
/// `weight`. Elements beyond `weight`'s length use a scale of 1.0.
fn layer_norm_inplace(x: &mut [f32], weight: &[f32]) {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let std = (variance + 1e-5).sqrt();
    let scales = weight.iter().copied().chain(std::iter::repeat(1.0));
    for (v, w) in x.iter_mut().zip(scales) {
        *v = (*v - mean) / std * w;
    }
}
/// Logistic sigmoid: maps any real input into (0, 1), with sigmoid(0) = 0.5.
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    1.0 / (1.0 + neg_exp)
}
/// Numerically stable in-place softmax: subtracts the maximum before
/// exponentiating so large logits cannot overflow. A non-positive sum
/// (only possible for empty input) leaves the slice unnormalized.
fn softmax_f32(x: &mut [f32]) {
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0f32;
    x.iter_mut().for_each(|v| {
        *v = (*v - max).exp();
        total += *v;
    });
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;
    /// Build a minimal artifact with the "micro" shape, int8 quantization,
    /// and zeroed weights/signature.
    /// NOTE(review): assumes `ModelArtifact::validate()` accepts the all-zero
    /// signature/pubkey in this configuration — TODO confirm.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "test_model".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64], // Minimal embedding weights
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }
    /// Loading registers the model and unloading removes it.
    #[test]
    fn test_native_sim_load_unload() {
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();
        assert!(backend.is_loaded(model_id));
        backend.unload(model_id).unwrap();
        assert!(!backend.is_loaded(model_id));
    }
    /// A full inference round-trip returns logits, top-K, and a witness
    /// tagged with this backend's kind.
    #[test]
    fn test_native_sim_inference() {
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model_id = backend.load(&artifact).unwrap();
        // 32 tokens / 32-byte mask match FixedShape::micro()'s seq_len.
        let tokens: Vec<u16> = (0..32).collect();
        let mask = vec![1u8; 32];
        let req = InferenceRequest::new(
            model_id,
            FixedShape::micro(),
            &tokens,
            &mask,
            GateHint::allow_all(),
        );
        let result = backend.infer(req).unwrap();
        assert!(!result.logits_q.is_empty());
        assert!(result.topk.is_some());
        assert_eq!(result.witness.backend, BackendKind::NativeSim);
    }
}

View File

@@ -0,0 +1,348 @@
//! WASM Simulator backend
//!
//! Pure Rust implementation that runs in WASM environments.
//! Uses RefCell for interior mutability since WASM is single-threaded.
#![cfg(feature = "wasm")]
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use crate::artifact::ModelArtifact;
use crate::backend::{compute_topk, validate_tokens, BackendStats, TransformerBackend};
use crate::error::{Error, Result};
use crate::gating::CoherenceGate;
use crate::quant::{dequantize_i8, quantize_i16};
use crate::types::{
BackendKind, FixedShape, GateDecision, GateHint, InferenceRequest, InferenceResult, ModelId,
QuantSpec, WitnessLog,
};
/// Loaded model for WASM simulation.
///
/// Holds everything `run_inference`/`compute_output` need without touching
/// the artifact bytes again.
struct WasmModel {
    /// Original model artifact (kept for hashes in witness logs)
    artifact: ModelArtifact,
    /// Prepacked embedding table (dequantized to f32 for computation),
    /// laid out row-major as vocab x d_model
    embeddings: Vec<f32>,
    /// Number of simulated layers
    num_layers: usize,
    /// Shape info cached from the manifest
    shape: FixedShape,
}
/// WASM simulator backend state (interior mutable for single-threaded WASM).
///
/// Grouped in one struct so a single `RefCell` guards both maps of models
/// and the running statistics.
struct WasmState {
    /// Loaded models keyed by their content-derived id
    models: HashMap<ModelId, WasmModel>,
    /// Statistics (inference counts, average latency, gate outcomes)
    stats: BackendStats,
}
/// WASM simulator backend
///
/// Uses RefCell for interior mutability since WASM is inherently single-threaded.
/// This allows the TransformerBackend trait to be implemented with &self methods.
/// Callers must not re-enter methods while a borrow is live (RefCell panics on
/// conflicting borrows rather than deadlocking).
pub struct WasmSimBackend {
    /// Interior mutable state (models + stats)
    state: RefCell<WasmState>,
    /// Coherence gate (immutable, shared via Rc — no atomics needed in WASM)
    gate: Rc<dyn CoherenceGate>,
}
impl WasmSimBackend {
    /// Create a new WASM simulator backend with empty state.
    pub fn new(gate: Rc<dyn CoherenceGate>) -> Self {
        Self {
            state: RefCell::new(WasmState {
                models: HashMap::new(),
                stats: BackendStats::default(),
            }),
            gate,
        }
    }
    /// Prepare model from artifact.
    ///
    /// Only the embedding table is dequantized from the artifact; if the
    /// weights are too small, a deterministic sin-based table is generated
    /// instead so tests remain reproducible.
    fn prepare_model(&self, artifact: &ModelArtifact) -> Result<WasmModel> {
        let shape = artifact.manifest.shape;
        let quant = &artifact.manifest.quant;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        // Dequantize embeddings
        let embedding_size = vocab * d_model;
        let embeddings = if artifact.weights.len() >= embedding_size {
            dequantize_i8(&artifact.weights[..embedding_size], quant)
        } else {
            // Generate deterministic embeddings for testing
            (0..embedding_size)
                .map(|i| ((i as f32 * 0.001).sin() * 0.1))
                .collect()
        };
        // Determine number of layers from artifact or default.
        // NOTE(review): early_exit toggling the layer COUNT (6 vs 4) looks
        // like a simulation heuristic, not a real model property — confirm.
        let num_layers = if artifact.manifest.backend.options.early_exit {
            6
        } else {
            4
        };
        Ok(WasmModel {
            artifact: artifact.clone(),
            embeddings,
            num_layers,
            shape,
        })
    }
    /// Run inference for WASM.
    ///
    /// Returns quantized logits plus the gate decision. A skipped preflight
    /// yields all-zero logits. Unlike the native sim there is no attention
    /// mask parameter at all.
    fn run_inference(
        &self,
        model: &WasmModel,
        tokens: &[u16],
        gate_hint: &GateHint,
    ) -> (Vec<i16>, GateDecision) {
        let shape = &model.shape;
        // Check preflight; any non-running decision short-circuits.
        let preflight = self.gate.preflight(gate_hint);
        if !preflight.did_run() {
            return (vec![0i16; shape.vocab as usize], preflight);
        }
        let vocab = shape.vocab as usize;
        let d_model = shape.d_model as usize;
        // Initialize hidden state from embeddings
        let seq_len = tokens.len();
        let mut hidden = vec![0.0f32; seq_len * d_model];
        // Lookup embeddings; out-of-vocab tokens are clamped to the last row.
        for (i, &token) in tokens.iter().enumerate() {
            let offset = (token as usize).min(vocab.saturating_sub(1)) * d_model;
            if offset + d_model <= model.embeddings.len() {
                hidden[i * d_model..(i + 1) * d_model]
                    .copy_from_slice(&model.embeddings[offset..offset + d_model]);
            }
        }
        // Run through simplified layers with early exit support
        for layer in 0..model.num_layers {
            // Simple layer computation (for WASM we keep it lightweight)
            // Apply simple transformation: a leaky-ReLU-like blend that keeps
            // 1% of the negative part.
            for t in 0..seq_len {
                let start = t * d_model;
                // Simple ReLU-like activation
                for i in 0..d_model {
                    hidden[start + i] =
                        hidden[start + i].max(0.0) * 0.99 + hidden[start + i] * 0.01;
                }
            }
            // Check for early exit AFTER the layer transform (the native sim
            // checks before; this is an intentional-looking divergence).
            let coherence_signal = compute_coherence(&hidden);
            if let Some(decision) = self.gate.checkpoint(layer as u8, coherence_signal) {
                let logits = self.compute_output(&hidden, model);
                return (logits, decision);
            }
        }
        // Compute output logits
        let logits = self.compute_output(&hidden, model);
        (logits, GateDecision::RanFull)
    }
    /// Compute output logits from hidden state.
    ///
    /// Uses weight tying: logits are the dot product of the last token's
    /// hidden state with each embedding row. All indexing is bounds-guarded,
    /// so undersized embedding tables simply leave trailing logits at zero.
    fn compute_output(&self, hidden: &[f32], model: &WasmModel) -> Vec<i16> {
        let shape = &model.shape;
        let d_model = shape.d_model as usize;
        let vocab = shape.vocab as usize;
        let seq_len = hidden.len() / d_model;
        // Take last token's hidden state (saturating so empty input is safe)
        let last_hidden = &hidden[(seq_len.saturating_sub(1)) * d_model..];
        // Compute logits via dot product with embedding matrix (transposed)
        let mut logits_f32 = vec![0.0f32; vocab];
        for v in 0..vocab.min(model.embeddings.len() / d_model) {
            let v_offset = v * d_model;
            let mut dot = 0.0f32;
            for d in 0..d_model.min(last_hidden.len()) {
                if v_offset + d < model.embeddings.len() {
                    dot += last_hidden[d] * model.embeddings[v_offset + d];
                }
            }
            logits_f32[v] = dot;
        }
        // Apply softmax and quantize
        softmax_inplace(&mut logits_f32);
        quantize_i16(&logits_f32)
    }
}
// Note: WASM is single-threaded, so these trait bounds are satisfied trivially
// by never actually being used across threads.
// SAFETY: WasmSimBackend contains RefCell and Rc, which are not thread-safe.
// These impls are sound ONLY under the wasm feature's single-threaded runtime
// (this module is `#![cfg(feature = "wasm")]`); moving this type across real
// OS threads would be undefined behavior.
unsafe impl Send for WasmSimBackend {}
// SAFETY: see the Send impl above — same single-threaded-WASM argument.
unsafe impl Sync for WasmSimBackend {}
impl TransformerBackend for WasmSimBackend {
    /// Validate and register a model artifact.
    ///
    /// NOTE(review): unlike the native sim there is no max-model capacity
    /// check here — loads grow the map without bound; confirm intended.
    fn load(&self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact
        artifact.validate()?;
        // Prepare model
        let model = self.prepare_model(artifact)?;
        let model_id = artifact.model_id();
        // Store in state
        let mut state = self.state.borrow_mut();
        state.models.insert(model_id, model);
        state.stats.models_loaded += 1;
        Ok(model_id)
    }
    /// Run one inference request.
    ///
    /// Latency is measured with `js_sys::Date::now()` (millisecond
    /// resolution), then expressed in nanoseconds for witness parity with
    /// the native backends.
    fn infer(&self, req: InferenceRequest) -> Result<InferenceResult> {
        let start = js_sys::Date::now();
        // Validate request
        req.validate()?;
        // Get model (immutable borrow)
        let state = self.state.borrow();
        let model = state
            .models
            .get(&req.model)
            .ok_or_else(|| Error::ModelNotFound(req.model))?;
        // Validate tokens
        validate_tokens(req.tokens, model.shape.vocab)?;
        // Run inference
        let (logits, gate_decision) = self.run_inference(model, req.tokens, &req.gate_hint);
        let latency_ns = ((js_sys::Date::now() - start) * 1_000_000.0) as u32;
        // Compute top-K
        let topk = compute_topk(&logits, 16);
        // Build witness
        let witness = WitnessLog::new(
            model.artifact.model_hash(),
            model.artifact.quant_hash(),
            BackendKind::WasmSim,
            0, // No cycles for WASM sim
            latency_ns,
            gate_decision,
        );
        drop(state); // Release borrow before mutable borrow (RefCell would panic otherwise)
        // Update stats; n >= 1 after the increment so the division is safe.
        {
            let mut state = self.state.borrow_mut();
            state.stats.total_inferences += 1;
            let n = state.stats.total_inferences;
            state.stats.avg_latency_ns =
                (state.stats.avg_latency_ns * (n - 1) + latency_ns as u64) / n;
            match gate_decision {
                GateDecision::EarlyExit { .. } => state.stats.early_exits += 1,
                GateDecision::Skipped { .. } => state.stats.skipped += 1,
                _ => {}
            }
        }
        Ok(InferenceResult::new(logits, Some(topk), witness))
    }
    /// Remove a loaded model; errors with `ModelNotFound` if absent.
    fn unload(&self, model: ModelId) -> Result<()> {
        let mut state = self.state.borrow_mut();
        if state.models.remove(&model).is_some() {
            state.stats.models_loaded = state.stats.models_loaded.saturating_sub(1);
            Ok(())
        } else {
            Err(Error::ModelNotFound(model))
        }
    }
    /// Whether the given model is currently resident.
    fn is_loaded(&self, model: ModelId) -> bool {
        self.state.borrow().models.contains_key(&model)
    }
    /// This backend's kind tag, recorded in witness logs.
    fn kind(&self) -> BackendKind {
        BackendKind::WasmSim
    }
    /// Snapshot of current statistics.
    fn stats(&self) -> BackendStats {
        self.state.borrow().stats.clone()
    }
}
/// Compute the coherence signal from a hidden state: the variance of the
/// activations, scaled into Q8.8 fixed point and saturated to the i16 range.
/// Empty input returns 0.
fn compute_coherence(hidden: &[f32]) -> i16 {
    if hidden.is_empty() {
        return 0;
    }
    let n = hidden.len() as f32;
    let mean = hidden.iter().sum::<f32>() / n;
    let variance: f32 = hidden
        .iter()
        .map(|x| (x - mean).powi(2))
        .sum::<f32>()
        / n;
    (variance * 256.0).clamp(-32768.0, 32767.0) as i16
}
/// In-place, numerically stable softmax (max-subtraction trick).
/// Empty slices are returned untouched; a non-positive exponent sum leaves
/// the slice unnormalized.
fn softmax_inplace(x: &mut [f32]) {
    if x.is_empty() {
        return;
    }
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    for v in x.iter_mut() {
        *v = (*v - max).exp();
    }
    let total: f32 = x.iter().sum();
    if total > 0.0 {
        for v in x.iter_mut() {
            *v /= total;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::artifact::Manifest;
    use crate::gating::DefaultCoherenceGate;
    /// Build a minimal micro-shape artifact with zeroed weights/signature.
    /// NOTE(review): assumes `validate()` accepts the zero signature here,
    /// matching the native_sim test fixture — TODO confirm.
    fn create_test_artifact() -> ModelArtifact {
        let manifest = Manifest {
            name: "wasm_test".into(),
            model_hash: "0".repeat(64),
            shape: FixedShape::micro(),
            quant: QuantSpec::int8(),
            io: Default::default(),
            backend: Default::default(),
            tests: Default::default(),
        };
        ModelArtifact {
            manifest,
            weights: vec![0u8; 4096 * 64],
            bitstream: None,
            calibration: None,
            test_vectors: vec![],
            signature: [0u8; 64],
            pubkey: [0u8; 32],
        }
    }
    /// prepare_model dequantizes embeddings and carries the manifest shape.
    #[test]
    fn test_wasm_sim_prepare_model() {
        let gate = Rc::new(DefaultCoherenceGate::new());
        let backend = WasmSimBackend::new(gate);
        let artifact = create_test_artifact();
        let model = backend.prepare_model(&artifact).unwrap();
        assert_eq!(model.shape.seq_len, 32);
        assert!(!model.embeddings.is_empty());
    }
}

View File

@@ -0,0 +1,136 @@
//! Error types for FPGA Transformer backend
use thiserror::Error;
/// Result type alias for FPGA Transformer operations, fixing the error
/// type to this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// FPGA Transformer error types.
///
/// Display strings come from the `thiserror` attributes; see
/// [`Error::is_recoverable`] / [`Error::is_config_error`] for coarse
/// classification of variants.
#[derive(Error, Debug)]
pub enum Error {
    /// Model artifact is invalid or corrupted
    #[error("Invalid artifact: {0}")]
    InvalidArtifact(String),
    /// Artifact signature verification failed
    #[error("Signature verification failed: {0}")]
    SignatureError(String),
    /// Test vectors failed validation
    #[error("Test vector validation failed: expected max error {expected}, got {actual}")]
    TestVectorError { expected: i32, actual: i32 },
    /// Model not found or not loaded
    #[error("Model not found: {0:?}")]
    ModelNotFound(crate::types::ModelId),
    /// Shape mismatch between request and model
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: crate::types::FixedShape,
        actual: crate::types::FixedShape,
    },
    /// Input length does not match expected sequence length
    #[error("Input length mismatch: expected {expected}, got {actual}")]
    InputLengthMismatch { expected: usize, actual: usize },
    /// Backend communication error
    #[error("Backend error: {0}")]
    BackendError(String),
    /// Daemon connection failed
    #[error("Daemon connection failed: {0}")]
    DaemonConnectionError(String),
    /// PCIe communication error
    #[error("PCIe error: {0}")]
    PcieError(String),
    /// DMA operation failed
    #[error("DMA error: {0}")]
    DmaError(String),
    /// Gating decision blocked inference
    #[error("Inference blocked by gate: {reason:?}")]
    GateBlocked { reason: crate::types::SkipReason },
    /// Quantization error
    #[error("Quantization error: {0}")]
    QuantizationError(String),
    /// Overflow during fixed-point computation; `location` is a static
    /// label identifying the computation site
    #[error("Fixed-point overflow at {location}")]
    FixedPointOverflow { location: &'static str },
    /// Invalid configuration
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// IO error (auto-converted via `?` from `std::io::Error`)
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// JSON parsing error (auto-converted via `?` from `serde_json::Error`)
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Checksum mismatch (values rendered as 8-digit hex)
    #[error("Checksum mismatch: expected {expected:08x}, got {actual:08x}")]
    ChecksumMismatch { expected: u32, actual: u32 },
    /// Protocol version mismatch
    #[error("Protocol version mismatch: expected {expected}, got {actual}")]
    ProtocolMismatch { expected: u16, actual: u16 },
    /// Timeout waiting for response
    #[error("Timeout after {ms}ms")]
    Timeout { ms: u64 },
    /// Resource exhausted (memory, slots, etc.)
    #[error("Resource exhausted: {0}")]
    ResourceExhausted(String),
    /// Feature not available in this build
    #[error("Feature not available: {0}")]
    FeatureNotAvailable(String),
}
impl Error {
/// Create a new InvalidArtifact error
pub fn invalid_artifact(msg: impl Into<String>) -> Self {
Self::InvalidArtifact(msg.into())
}
/// Create a new BackendError
pub fn backend(msg: impl Into<String>) -> Self {
Self::BackendError(msg.into())
}
/// Create a new DaemonConnectionError
pub fn daemon_connection(msg: impl Into<String>) -> Self {
Self::DaemonConnectionError(msg.into())
}
/// Check if this error is recoverable
pub fn is_recoverable(&self) -> bool {
matches!(
self,
Error::Timeout { .. }
| Error::DaemonConnectionError(_)
| Error::BackendError(_)
| Error::GateBlocked { .. }
)
}
/// Check if this error indicates a configuration problem
pub fn is_config_error(&self) -> bool {
matches!(
self,
Error::InvalidConfig(_)
| Error::ShapeMismatch { .. }
| Error::InputLengthMismatch { .. }
| Error::FeatureNotAvailable(_)
)
}
}

View File

@@ -0,0 +1,302 @@
//! C ABI bindings for FFI integration
//!
//! Provides a stable C interface for linking from other languages.
use std::ffi::{c_char, c_int, c_void, CStr};
use std::ptr;
use std::sync::Arc;
use crate::backend::native_sim::NativeSimBackend;
use crate::backend::TransformerBackend;
use crate::gating::DefaultCoherenceGate;
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
/// Opaque engine handle exposed across the C ABI.
///
/// Callers only ever see `*mut FpgaEngine`; create with
/// `fpga_engine_create`, release with `fpga_engine_destroy`.
pub struct FpgaEngine {
    // Boxed trait object so the concrete backend can vary without
    // changing the C-visible layout.
    backend: Box<dyn TransformerBackend>,
}
/// Result code returned across the C ABI (stable integer values).
#[repr(C)]
pub enum FpgaResult {
    Ok = 0,
    InvalidArgument = 1,
    ModelNotFound = 2,
    InferenceFailed = 3,
    AllocationFailed = 4,
    InvalidArtifact = 5,
}
/// Inference result structure returned by `fpga_infer`.
///
/// The `logits` and `topk` buffers are heap-allocated by Rust and MUST be
/// released by passing the whole struct to `fpga_result_free`; freeing them
/// with the C runtime's `free()` is undefined behavior.
#[repr(C)]
pub struct FpgaInferenceResult {
    /// Status code
    pub status: FpgaResult,
    /// Logits (caller must free with fpga_result_free); may be NULL on
    /// allocation failure even when `logits_len` is non-zero
    pub logits: *mut i16,
    /// Number of logits
    pub logits_len: usize,
    /// Top-K results laid out as interleaved (token_id, logit) u32 pairs,
    /// i.e. 2 * topk_len u32 values
    pub topk: *mut u32,
    /// Number of top-K pairs
    pub topk_len: usize,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Compute cycles
    pub cycles: u32,
    /// Gate decision (0=full, 1=early_exit, 2=skipped)
    pub gate_decision: u8,
    /// Exit layer (meaningful only when gate_decision == 1)
    pub exit_layer: u8,
}
/// Create a new FPGA engine with native simulator backend
///
/// Returns a handle that must be freed with `fpga_engine_destroy`
#[no_mangle]
pub extern "C" fn fpga_engine_create() -> *mut FpgaEngine {
let gate = Arc::new(DefaultCoherenceGate::new());
let backend = Box::new(NativeSimBackend::new(gate));
let engine = Box::new(FpgaEngine { backend });
Box::into_raw(engine)
}
/// Destroy an FPGA engine. NULL is tolerated and ignored; passing the same
/// non-NULL handle twice is undefined behavior (double free).
#[no_mangle]
pub extern "C" fn fpga_engine_destroy(engine: *mut FpgaEngine) {
    if engine.is_null() {
        return;
    }
    // SAFETY: non-null handles originate from Box::into_raw in
    // fpga_engine_create and are reconstituted and dropped here once.
    unsafe { drop(Box::from_raw(engine)) };
}
/// Load a model artifact
///
/// On success writes the 32-byte model ID into `model_id_out` and returns
/// `FpgaResult::Ok`.
///
/// # Safety (caller contract — this fn is not marked `unsafe`, so these are
/// trusted, not checked)
/// - `artifact_bytes` must point to `artifact_len` readable bytes
/// - `model_id_out` must point to at least 32 writable bytes
#[no_mangle]
pub extern "C" fn fpga_load_artifact(
    engine: *mut FpgaEngine,
    artifact_bytes: *const u8,
    artifact_len: usize,
    model_id_out: *mut u8,
) -> FpgaResult {
    if engine.is_null() || artifact_bytes.is_null() || model_id_out.is_null() {
        return FpgaResult::InvalidArgument;
    }
    let engine = unsafe { &mut *engine };
    let artifact_slice = unsafe { std::slice::from_raw_parts(artifact_bytes, artifact_len) };
    // Deserialize + verify the artifact container before handing it to the
    // backend; any parse failure maps to InvalidArtifact.
    let artifact = match crate::artifact::unpack_artifact(artifact_slice) {
        Ok(a) => a,
        Err(_) => return FpgaResult::InvalidArtifact,
    };
    match engine.backend.load(&artifact) {
        Ok(model_id) => {
            unsafe {
                ptr::copy_nonoverlapping(model_id.as_bytes().as_ptr(), model_id_out, 32);
            }
            FpgaResult::Ok
        }
        // NOTE(review): all load failures (including capacity exhaustion)
        // collapse to InvalidArtifact here — confirm acceptable for the ABI.
        Err(_) => FpgaResult::InvalidArtifact,
    }
}
/// Run inference
///
/// Result must be freed with `fpga_result_free`.
///
/// # Safety (caller contract — trusted, not checked)
/// - `model_id` must point to 32 readable bytes
/// - `tokens` / `mask` must point to `tokens_len` u16s / `mask_len` bytes
///
/// The C API always uses the "micro" fixed shape. A NULL argument returns
/// `InvalidArgument`; backend failures return `InferenceFailed`.
#[no_mangle]
pub extern "C" fn fpga_infer(
    engine: *mut FpgaEngine,
    model_id: *const u8,
    tokens: *const u16,
    tokens_len: usize,
    mask: *const u8,
    mask_len: usize,
    coherence_score: i16,
    boundary_crossed: bool,
    max_compute_class: u8,
) -> FpgaInferenceResult {
    // Template for every failure path: empty buffers, "skipped" decision.
    let error_result = || FpgaInferenceResult {
        status: FpgaResult::InvalidArgument,
        logits: ptr::null_mut(),
        logits_len: 0,
        topk: ptr::null_mut(),
        topk_len: 0,
        latency_ns: 0,
        cycles: 0,
        gate_decision: 2,
        exit_layer: 0,
    };
    if engine.is_null() || model_id.is_null() || tokens.is_null() || mask.is_null() {
        return error_result();
    }
    let engine = unsafe { &mut *engine };
    // Parse model ID (fixed 32 bytes)
    let id_slice = unsafe { std::slice::from_raw_parts(model_id, 32) };
    let mut id_bytes = [0u8; 32];
    id_bytes.copy_from_slice(id_slice);
    let model = ModelId::new(id_bytes);
    // Parse tokens and mask
    let tokens_slice = unsafe { std::slice::from_raw_parts(tokens, tokens_len) };
    let mask_slice = unsafe { std::slice::from_raw_parts(mask, mask_len) };
    // Build shape (micro for C API)
    let shape = FixedShape::micro();
    // Build gate hint; unknown compute-class bytes fall back to Deliberative.
    let compute_class =
        ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
    let gate_hint = GateHint::new(coherence_score, boundary_crossed, compute_class);
    // Create request
    let req = InferenceRequest::new(model, shape, tokens_slice, mask_slice, gate_hint);
    // Run inference
    match engine.backend.infer(req) {
        Ok(result) => {
            // Allocate logits with checked allocation (prevents panic on overflow).
            // Layout must match fpga_result_free's dealloc exactly:
            // Layout::array::<i16>(logits_len).
            let logits_len = result.logits_q.len();
            let logits = if logits_len > 0 {
                match std::alloc::Layout::array::<i16>(logits_len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut i16 };
                        if !ptr.is_null() {
                            unsafe {
                                ptr::copy_nonoverlapping(result.logits_q.as_ptr(), ptr, logits_len);
                            }
                        }
                        ptr
                    }
                    _ => ptr::null_mut(), // Return null on allocation failure
                }
            } else {
                ptr::null_mut()
            };
            // Allocate top-K with checked allocation; interleaved
            // (token, logit) u32 pairs. Note: negative i16 logits
            // sign-extend through `as u32`; consumers must cast back.
            let (topk, topk_len) = if let Some(ref tk) = result.topk {
                let len = tk.len() * 2; // (token, logit) pairs
                match std::alloc::Layout::array::<u32>(len) {
                    Ok(layout) if layout.size() > 0 => {
                        let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
                        if !ptr.is_null() {
                            for (i, (token, logit)) in tk.iter().enumerate() {
                                unsafe {
                                    *ptr.add(i * 2) = *token as u32;
                                    *ptr.add(i * 2 + 1) = *logit as u32;
                                }
                            }
                        }
                        (ptr, tk.len())
                    }
                    _ => (ptr::null_mut(), 0), // Return null on allocation failure
                }
            } else {
                (ptr::null_mut(), 0)
            };
            // Encode gate decision as (code, exit_layer) for the flat struct
            let (gate_decision, exit_layer) = match result.witness.gate_decision {
                crate::types::GateDecision::RanFull => (0, 0),
                crate::types::GateDecision::EarlyExit { layer } => (1, layer),
                crate::types::GateDecision::Skipped { .. } => (2, 0),
            };
            FpgaInferenceResult {
                status: FpgaResult::Ok,
                logits,
                logits_len,
                topk,
                topk_len,
                latency_ns: result.witness.latency_ns,
                cycles: result.witness.cycles,
                gate_decision,
                exit_layer,
            }
        }
        Err(_) => {
            let mut result = error_result();
            result.status = FpgaResult::InferenceFailed;
            result
        }
    }
}
/// Free inference result buffers (logits and top-K) allocated by `fpga_infer`.
///
/// Idempotent with respect to the pointers: they are nulled after freeing,
/// and NULL / zero-length buffers are skipped. The deallocation layouts
/// mirror the allocation layouts in `fpga_infer` exactly — keep the two
/// functions in sync.
#[no_mangle]
pub extern "C" fn fpga_result_free(result: *mut FpgaInferenceResult) {
    if result.is_null() {
        return;
    }
    unsafe {
        let r = &mut *result;
        if !r.logits.is_null() && r.logits_len > 0 {
            std::alloc::dealloc(
                r.logits as *mut u8,
                // unwrap is safe: the same Layout::array succeeded at
                // allocation time for this length.
                std::alloc::Layout::array::<i16>(r.logits_len).unwrap(),
            );
            r.logits = ptr::null_mut();
        }
        if !r.topk.is_null() && r.topk_len > 0 {
            std::alloc::dealloc(
                r.topk as *mut u8,
                // topk_len counts pairs; the buffer holds 2 u32s per pair.
                std::alloc::Layout::array::<u32>(r.topk_len * 2).unwrap(),
            );
            r.topk = ptr::null_mut();
        }
    }
}
/// Unload a model by its 32-byte ID. Returns `ModelNotFound` if it was not
/// loaded; `InvalidArgument` on NULL pointers. The caller must ensure
/// `model_id` points to 32 readable bytes.
#[no_mangle]
pub extern "C" fn fpga_unload(engine: *mut FpgaEngine, model_id: *const u8) -> FpgaResult {
    if engine.is_null() || model_id.is_null() {
        return FpgaResult::InvalidArgument;
    }
    let engine = unsafe { &mut *engine };
    let mut id_bytes = [0u8; 32];
    id_bytes.copy_from_slice(unsafe { std::slice::from_raw_parts(model_id, 32) });
    if engine.backend.unload(ModelId::new(id_bytes)).is_ok() {
        FpgaResult::Ok
    } else {
        FpgaResult::ModelNotFound
    }
}
/// Check if a model with the given 32-byte ID is loaded. NULL pointers
/// yield `false`. The caller must ensure `model_id` points to 32 readable
/// bytes.
#[no_mangle]
pub extern "C" fn fpga_is_loaded(engine: *const FpgaEngine, model_id: *const u8) -> bool {
    if engine.is_null() || model_id.is_null() {
        return false;
    }
    let engine = unsafe { &*engine };
    let mut id_bytes = [0u8; 32];
    id_bytes.copy_from_slice(unsafe { std::slice::from_raw_parts(model_id, 32) });
    engine.backend.is_loaded(ModelId::new(id_bytes))
}
/// Get version string as a NUL-terminated C string with static lifetime;
/// the caller must NOT free it.
#[no_mangle]
pub extern "C" fn fpga_version() -> *const c_char {
    // C-string literal (stable since Rust 1.77, the crate's declared MSRV):
    // the compiler guarantees NUL termination and no interior NULs, which the
    // previous hand-written b"0.1.0\0" byte string could not enforce.
    c"0.1.0".as_ptr()
}

View File

@@ -0,0 +1,8 @@
//! Foreign function interfaces for FPGA Transformer
//!
//! Provides C ABI and WASM bindings.
/// wasm-bindgen bindings for browser/Node.js (requires the `wasm` feature).
#[cfg(feature = "wasm")]
pub mod wasm_bindgen;
/// Stable C ABI for linking from other languages (always compiled).
pub mod c_abi;

View File

@@ -0,0 +1,282 @@
//! WASM bindings via wasm-bindgen
//!
//! Provides the same API shape for browser and Node.js environments.
#![cfg(feature = "wasm")]
use js_sys::{Array, Int16Array, Object, Reflect, Uint16Array, Uint8Array};
use wasm_bindgen::prelude::*;
use crate::artifact::{unpack_artifact, ModelArtifact};
use crate::backend::native_sim::{NativeSimBackend, NativeSimConfig};
use crate::backend::TransformerBackend;
use crate::gating::DefaultCoherenceGate;
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
use std::sync::Arc;
/// WASM Engine for transformer inference, exported to JS via wasm-bindgen.
#[wasm_bindgen]
pub struct WasmEngine {
    // Native simulator reused under WASM (pure Rust, no FPGA access)
    backend: NativeSimBackend,
    // IDs of every model loaded through this engine, in load order
    loaded_models: Vec<ModelId>,
    // Witness of the most recent inference, exposed via getWitness()
    last_witness: Option<crate::types::WitnessLog>,
}
#[wasm_bindgen]
impl WasmEngine {
    /// Create a new WASM engine.
    ///
    /// Uses the native simulator with a reduced 4-model capacity and LUT
    /// softmax; the default coherence gate governs early exit/skip.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        // Use permissive config for WASM
        let config = NativeSimConfig {
            max_models: 4,
            trace: false,
            lut_softmax: true,
            max_layers: 0,
        };
        let gate = Arc::new(DefaultCoherenceGate::new());
        let backend = NativeSimBackend::with_config(gate, config);
        Self {
            backend,
            loaded_models: Vec::new(),
            last_witness: None,
        }
    }
    /// Load a model artifact from bytes
    ///
    /// Returns the 32-byte model ID as a Uint8Array on success; unpack or
    /// load failures surface as JS exceptions with a descriptive message.
    #[wasm_bindgen(js_name = loadArtifact)]
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<Uint8Array, JsValue> {
        let artifact = unpack_artifact(artifact_bytes)
            .map_err(|e| JsValue::from_str(&format!("Failed to unpack artifact: {}", e)))?;
        let model_id = self
            .backend
            .load(&artifact)
            .map_err(|e| JsValue::from_str(&format!("Failed to load model: {}", e)))?;
        // Track the id so getLoadedModels() can report it later
        self.loaded_models.push(model_id);
        // Return model ID as Uint8Array
        let id_array = Uint8Array::new_with_length(32);
        id_array.copy_from(model_id.as_bytes());
        Ok(id_array)
    }
    /// Run inference
    ///
    /// Returns a JS object `{ logits: Int16Array, topk?: [[token, logit],…],
    /// witness: { backend, cycles, latency_ns, gate_decision } }`.
    /// The engine always uses the "micro" fixed shape, so `tokens.len()`
    /// must equal its `seq_len`. The witness of each call is cached and
    /// retrievable via `getWitness()`.
    #[wasm_bindgen]
    pub fn infer(
        &mut self,
        model_id: &[u8],
        tokens: &[u16],
        mask: &[u8],
        coherence_score_q: i16,
        boundary_crossed: bool,
        max_compute_class: u8,
    ) -> Result<JsValue, JsValue> {
        // Parse model ID
        if model_id.len() != 32 {
            return Err(JsValue::from_str("Model ID must be 32 bytes"));
        }
        let mut id_bytes = [0u8; 32];
        id_bytes.copy_from_slice(model_id);
        let model = ModelId::new(id_bytes);
        // Get shape from loaded model
        // For WASM, we use micro shape by default
        let shape = FixedShape::micro();
        // Validate input length up front for a clearer JS-side error
        if tokens.len() != shape.seq_len as usize {
            return Err(JsValue::from_str(&format!(
                "Token length mismatch: expected {}, got {}",
                shape.seq_len,
                tokens.len()
            )));
        }
        // Build gate hint; unknown class bytes fall back to Deliberative
        let compute_class =
            ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
        let gate_hint = GateHint::new(coherence_score_q, boundary_crossed, compute_class);
        // Create request
        let req = InferenceRequest::new(model, shape, tokens, mask, gate_hint);
        // Run inference
        let result = self
            .backend
            .infer(req)
            .map_err(|e| JsValue::from_str(&format!("Inference failed: {}", e)))?;
        // Store witness for later retrieval via getWitness()
        self.last_witness = Some(result.witness.clone());
        // Build result object
        let obj = Object::new();
        // Add logits
        let logits = Int16Array::new_with_length(result.logits_q.len() as u32);
        logits.copy_from(&result.logits_q);
        Reflect::set(&obj, &"logits".into(), &logits)?;
        // Add top-K if available, as an array of [token, logit] pairs
        if let Some(topk) = &result.topk {
            let topk_array = Array::new();
            for (token, logit) in topk {
                let pair = Array::new();
                pair.push(&JsValue::from(*token));
                pair.push(&JsValue::from(*logit));
                topk_array.push(&pair);
            }
            Reflect::set(&obj, &"topk".into(), &topk_array)?;
        }
        // Add witness info (enums rendered via Debug formatting)
        let witness = Object::new();
        Reflect::set(
            &witness,
            &"backend".into(),
            &format!("{:?}", result.witness.backend).into(),
        )?;
        Reflect::set(
            &witness,
            &"cycles".into(),
            &JsValue::from(result.witness.cycles),
        )?;
        Reflect::set(
            &witness,
            &"latency_ns".into(),
            &JsValue::from(result.witness.latency_ns),
        )?;
        Reflect::set(
            &witness,
            &"gate_decision".into(),
            &format!("{:?}", result.witness.gate_decision).into(),
        )?;
        Reflect::set(&obj, &"witness".into(), &witness)?;
        Ok(obj.into())
    }
/// Get the last witness log as JSON
#[wasm_bindgen(js_name = getWitness)]
pub fn get_witness(&self) -> Result<JsValue, JsValue> {
    // No inference has run yet -> null, mirroring an absent value in JS.
    let Some(witness) = &self.last_witness else {
        return Ok(JsValue::NULL);
    };
    serde_json::to_string(witness)
        .map(|json| JsValue::from_str(&json))
        .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
}
/// Get list of loaded model IDs
#[wasm_bindgen(js_name = getLoadedModels)]
pub fn get_loaded_models(&self) -> Array {
    // Each model id becomes its own 32-byte Uint8Array entry in the result.
    self.loaded_models
        .iter()
        .map(|id| {
            let bytes = Uint8Array::new_with_length(32);
            bytes.copy_from(id.as_bytes());
            bytes
        })
        .collect::<Array>()
}
/// Unload a model
#[wasm_bindgen]
pub fn unload(&mut self, model_id: &[u8]) -> Result<(), JsValue> {
    // Model ids are fixed 32-byte values; reject anything else up front.
    let id_bytes: [u8; 32] = model_id
        .try_into()
        .map_err(|_| JsValue::from_str("Model ID must be 32 bytes"))?;
    let model = ModelId::new(id_bytes);
    self.backend
        .unload(model)
        .map_err(|e| JsValue::from_str(&format!("Unload failed: {}", e)))?;
    // Drop our bookkeeping entry as well.
    self.loaded_models.retain(|loaded| *loaded != model);
    Ok(())
}
/// Get backend statistics
#[wasm_bindgen(js_name = getStats)]
pub fn get_stats(&self) -> Result<JsValue, JsValue> {
    let stats = self.backend.stats();
    let obj = Object::new();
    // Small helper: write one field onto the result object.
    let put = |key: &str, val: JsValue| Reflect::set(&obj, &key.into(), &val).map(|_| ());
    // Counts that fit in u32 are passed as-is; wider counters go through f64
    // because JS numbers cannot represent a full u64 exactly.
    put("models_loaded", JsValue::from(stats.models_loaded as u32))?;
    put("total_inferences", JsValue::from(stats.total_inferences as f64))?;
    put("avg_latency_ns", JsValue::from(stats.avg_latency_ns as f64))?;
    put("early_exits", JsValue::from(stats.early_exits as f64))?;
    put("skipped", JsValue::from(stats.skipped as f64))?;
    Ok(obj.into())
}
}
impl Default for WasmEngine {
    /// Delegates to [`WasmEngine::new`], which uses the permissive
    /// native-simulator configuration.
    fn default() -> Self {
        Self::new()
    }
}
/// Utility function to create a micro shape configuration
#[wasm_bindgen(js_name = microShape)]
pub fn micro_shape() -> Result<JsValue, JsValue> {
    let shape = FixedShape::micro();
    let obj = Object::new();
    // Mirror every FixedShape field onto a plain JS object.
    for (key, value) in [
        ("seq_len", JsValue::from(shape.seq_len)),
        ("d_model", JsValue::from(shape.d_model)),
        ("heads", JsValue::from(shape.heads)),
        ("d_head", JsValue::from(shape.d_head)),
        ("vocab", JsValue::from(shape.vocab)),
    ] {
        Reflect::set(&obj, &key.into(), &value)?;
    }
    Ok(obj.into())
}
/// Utility function to validate an artifact without loading
#[wasm_bindgen(js_name = validateArtifact)]
pub fn validate_artifact(artifact_bytes: &[u8]) -> Result<JsValue, JsValue> {
    // Unpacking checks the container format; validate() checks the contents.
    let artifact = unpack_artifact(artifact_bytes)
        .map_err(|e| JsValue::from_str(&format!("Invalid artifact: {}", e)))?;
    if let Err(e) = artifact.validate() {
        return Err(JsValue::from_str(&format!("Validation failed: {}", e)));
    }
    // Report the artifact's manifest name plus an explicit valid flag.
    let obj = Object::new();
    Reflect::set(&obj, &"name".into(), &artifact.manifest.name.into())?;
    Reflect::set(&obj, &"valid".into(), &JsValue::TRUE)?;
    Ok(obj.into())
}

View File

@@ -0,0 +1,301 @@
//! Coherence-based gating for inference control
use crate::types::{ComputeClass, GateDecision, GateHint, SkipReason};
use crate::witness::WitnessLog;
/// Trait for coherence-based gating
///
/// Implementations decide whether an inference may run, whether it may exit
/// early at a layer boundary, and whether its results may be written back.
/// `Send + Sync` is required so a gate can be shared across threads
/// (e.g. behind an `Arc<dyn CoherenceGate>`).
pub trait CoherenceGate: Send + Sync {
    /// Preflight check before inference
    ///
    /// Returns a gate decision based on coherence signals.
    fn preflight(&self, hint: &GateHint) -> GateDecision;

    /// Layer checkpoint for early exit decisions
    ///
    /// Called after each layer to determine if early exit is appropriate.
    /// `signal_q` is a fixed-point coherence signal.
    /// Returns Some(decision) to exit early, None to continue.
    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision>;

    /// Check if write is allowed based on witness
    ///
    /// Used to gate state changes in memory systems.
    fn allow_write(&self, witness: &WitnessLog) -> bool;
}
/// Configuration for coherence gate
///
/// Coherence thresholds are fixed-point Q8.8 values (256 == 1.0), as the
/// `Default` impl's field comments indicate.
#[derive(Debug, Clone)]
pub struct CoherenceConfig {
    /// Minimum coherence score to run (Q8.8)
    pub min_coherence: i16,
    /// Coherence threshold for early exit (Q8.8)
    pub early_exit_threshold: i16,
    /// Enable early exit
    pub early_exit_enabled: bool,
    /// Minimum layers before early exit
    pub min_layers: u8,
    /// Require stable coherence for writes
    pub require_stable_for_write: bool,
    /// Minimum coherence for writes (Q8.8)
    ///
    /// NOTE(review): not consulted by the gate impls visible in this file —
    /// confirm whether other callers read it.
    pub min_write_coherence: i16,
}
impl Default for CoherenceConfig {
    /// Permissive defaults: run almost everything, exit early only on a very
    /// strong signal, and persist results only after a full run.
    fn default() -> Self {
        Self {
            min_coherence: -256,       // -1.0 in Q8.8, very permissive
            min_layers: 2,
            early_exit_enabled: true,
            early_exit_threshold: 512, // 2.0 in Q8.8
            min_write_coherence: 0,    // Require non-negative coherence
            require_stable_for_write: true,
        }
    }
}
impl CoherenceConfig {
    /// Create a strict configuration: non-negative coherence required to run,
    /// a lower early-exit bar after at least 4 layers, and 0.5 (128 in Q8.8)
    /// minimum coherence for writes.
    pub fn strict() -> Self {
        Self {
            min_write_coherence: 128,
            require_stable_for_write: true,
            min_layers: 4,
            early_exit_enabled: true,
            early_exit_threshold: 256,
            min_coherence: 0,
        }
    }

    /// Create a permissive configuration that never blocks and never exits
    /// early (always allows).
    pub fn permissive() -> Self {
        Self {
            min_write_coherence: i16::MIN,
            require_stable_for_write: false,
            min_layers: 0,
            early_exit_enabled: false,
            early_exit_threshold: i16::MAX,
            min_coherence: i16::MIN,
        }
    }
}
/// Default coherence gate implementation
///
/// Applies the thresholds from a [`CoherenceConfig`]: a coherence floor at
/// preflight, an optional per-layer early-exit threshold, and a
/// full-run-only write rule when `require_stable_for_write` is set.
pub struct DefaultCoherenceGate {
    /// Thresholds driving every gate decision.
    config: CoherenceConfig,
}
impl DefaultCoherenceGate {
    /// Create with default config
    pub fn new() -> Self {
        Self::with_config(CoherenceConfig::default())
    }

    /// Create with custom config
    pub fn with_config(config: CoherenceConfig) -> Self {
        Self { config }
    }

    /// Check whether the requested compute class may run at the hinted
    /// coherence level. Reflex always runs (fast path); Associative needs
    /// half the configured minimum; Deliberative needs the full minimum.
    fn check_compute_class(&self, hint: &GateHint) -> bool {
        let required = match hint.max_compute_class {
            ComputeClass::Reflex => return true,
            ComputeClass::Associative => self.config.min_coherence / 2,
            ComputeClass::Deliberative => self.config.min_coherence,
        };
        hint.coherence_score_q >= required
    }
}
impl Default for DefaultCoherenceGate {
    /// Equivalent to [`DefaultCoherenceGate::new`] (default config).
    fn default() -> Self {
        Self::new()
    }
}
impl CoherenceGate for DefaultCoherenceGate {
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        // Below the coherence floor: refuse outright.
        if hint.coherence_score_q < self.config.min_coherence {
            GateDecision::Skipped {
                reason: SkipReason::LowCoherence,
            }
        } else if !self.check_compute_class(hint) {
            // Requested compute class is too expensive at this coherence.
            GateDecision::Skipped {
                reason: SkipReason::BudgetExceeded,
            }
        } else {
            GateDecision::RanFull
        }
    }

    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        // Early exit must be enabled, past the warm-up layers, and the
        // coherence signal must clear the configured threshold.
        let eligible = self.config.early_exit_enabled
            && layer >= self.config.min_layers
            && signal_q >= self.config.early_exit_threshold;
        eligible.then_some(GateDecision::EarlyExit { layer })
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        if !witness.gate_decision.did_run() {
            // A skipped inference never justifies a state change.
            false
        } else if self.config.require_stable_for_write {
            // Only a complete forward pass counts as "stable".
            matches!(witness.gate_decision, GateDecision::RanFull)
        } else {
            true
        }
    }
}
/// Mincut-aware coherence gate
///
/// Uses mincut signals to make more informed gating decisions.
pub struct MincutCoherenceGate {
    /// Underlying default gate providing the config and base checks.
    base: DefaultCoherenceGate,
    /// Minimum lambda (mincut value) for inference
    ///
    /// NOTE(review): not referenced by the trait impl visible in this file —
    /// confirm whether callers use it or it is dead configuration.
    pub min_lambda: i16,
    /// Lambda threshold for early exit
    pub lambda_exit_threshold: i16,
}
impl MincutCoherenceGate {
    /// Create a new mincut-aware gate wrapping a default gate built from
    /// `config`.
    pub fn new(config: CoherenceConfig, min_lambda: i16, lambda_exit_threshold: i16) -> Self {
        let base = DefaultCoherenceGate::with_config(config);
        Self {
            min_lambda,
            lambda_exit_threshold,
            base,
        }
    }
}
impl CoherenceGate for MincutCoherenceGate {
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        // Use base coherence check first.
        let base_decision = self.base.preflight(hint);
        if !base_decision.did_run() {
            return base_decision;
        }
        // Additional mincut check: a crossed boundary combined with negative
        // coherence indicates an unstable state — skip.
        if hint.boundary_crossed && hint.coherence_score_q < 0 {
            return GateDecision::Skipped {
                reason: SkipReason::LowCoherence,
            };
        }
        GateDecision::RanFull
    }

    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        // Fix: respect the early-exit switch from the base configuration.
        // The previous implementation skipped this check, so it could
        // early-exit even when `early_exit_enabled` was false (e.g. with
        // CoherenceConfig::permissive()), unlike DefaultCoherenceGate.
        if !self.base.config.early_exit_enabled {
            return None;
        }
        // High lambda (mincut signal) suggests a stable state, so halve the
        // exit threshold; otherwise use the configured one.
        let adjusted_threshold = if signal_q > self.lambda_exit_threshold {
            self.base.config.early_exit_threshold / 2
        } else {
            self.base.config.early_exit_threshold
        };
        if layer >= self.base.config.min_layers && signal_q >= adjusted_threshold {
            return Some(GateDecision::EarlyExit { layer });
        }
        None
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        // Defer to the base gate's write policy.
        self.base.allow_write(witness)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_gate_preflight() {
        let gate = DefaultCoherenceGate::new();
        // High coherence (256 = 1.0 in Q8.8) should pass
        let hint = GateHint::new(256, false, ComputeClass::Deliberative);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
        // Low coherence (below the default -256 floor) should fail
        let hint = GateHint::new(-512, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }

    #[test]
    fn test_early_exit_checkpoint() {
        let gate = DefaultCoherenceGate::new();
        // Layer 0 - too early (default min_layers is 2)
        assert!(gate.checkpoint(0, 1000).is_none());
        // Layer 4 with high signal (above the 512 threshold) - should exit
        let decision = gate.checkpoint(4, 1000);
        assert!(matches!(
            decision,
            Some(GateDecision::EarlyExit { layer: 4 })
        ));
    }

    #[test]
    fn test_write_gating() {
        let gate = DefaultCoherenceGate::new();
        // Full run should allow writes
        // NOTE(review): relies on WitnessLog::empty() defaulting its
        // gate_decision to RanFull — confirm in the witness module.
        let witness = crate::witness::WitnessLog::empty();
        assert!(gate.allow_write(&witness));
        // Skipped should not allow writes
        let mut skipped_witness = crate::witness::WitnessLog::empty();
        skipped_witness.gate_decision = GateDecision::Skipped {
            reason: SkipReason::LowCoherence,
        };
        assert!(!gate.allow_write(&skipped_witness));
    }

    #[test]
    fn test_strict_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::strict());
        // Strict should require positive (non-negative) coherence
        let hint = GateHint::new(-1, false, ComputeClass::Deliberative);
        assert!(matches!(
            gate.preflight(&hint),
            GateDecision::Skipped { .. }
        ));
    }

    #[test]
    fn test_permissive_config() {
        let gate = DefaultCoherenceGate::with_config(CoherenceConfig::permissive());
        // Permissive should allow anything, even the worst-case hint
        let hint = GateHint::new(i16::MIN, true, ComputeClass::Reflex);
        assert!(matches!(gate.preflight(&hint), GateDecision::RanFull));
    }
}

View File

@@ -0,0 +1,95 @@
//! Gating subsystem for coherence-based inference control
//!
//! Provides preflight and postflight gates that integrate mincut signals
//! and write policies for memory safety.
pub mod coherence_gate;
pub mod policy_gate;
pub use coherence_gate::{CoherenceConfig, CoherenceGate, DefaultCoherenceGate};
pub use policy_gate::{DefaultPolicyGate, PolicyGate, WritePolicy};
use crate::types::{GateDecision, GateHint, SkipReason};
use crate::witness::WitnessLog;
/// Combined gate that checks both coherence and policy
///
/// Policy is consulted first in `preflight`; writes require both gates to
/// agree (see `allow_write`).
pub struct CombinedGate {
    /// Coherence-signal based gate.
    coherence: Box<dyn CoherenceGate>,
    /// Resource/access policy gate.
    policy: Box<dyn PolicyGate>,
}
impl CombinedGate {
    /// Create a new combined gate
    pub fn new(coherence: Box<dyn CoherenceGate>, policy: Box<dyn PolicyGate>) -> Self {
        Self { coherence, policy }
    }

    /// Create with default implementations
    pub fn default_gates() -> Self {
        let coherence: Box<dyn CoherenceGate> = Box::new(DefaultCoherenceGate::new());
        let policy: Box<dyn PolicyGate> = Box::new(DefaultPolicyGate::new());
        Self::new(coherence, policy)
    }

    /// Preflight check before inference.
    ///
    /// Policy is consulted first; only if it permits does the coherence gate
    /// get to decide.
    pub fn preflight(&self, hint: &GateHint) -> GateDecision {
        if self.policy.allow_inference(hint) {
            self.coherence.preflight(hint)
        } else {
            GateDecision::Skipped {
                reason: SkipReason::PolicyDenied,
            }
        }
    }

    /// Checkpoint during inference; only the coherence gate participates.
    pub fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        self.coherence.checkpoint(layer, signal_q)
    }

    /// Check if write is allowed after inference: both gates must agree.
    pub fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.coherence.allow_write(witness) && self.policy.allow_write(witness)
    }
}
impl CoherenceGate for CombinedGate {
    // Delegate to the inherent methods; inherent impls take precedence over
    // same-named trait methods, so these calls do not recurse.
    fn preflight(&self, hint: &GateHint) -> GateDecision {
        Self::preflight(self, hint)
    }

    fn checkpoint(&self, layer: u8, signal_q: i16) -> Option<GateDecision> {
        Self::checkpoint(self, layer, signal_q)
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        Self::allow_write(self, witness)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combined_gate_preflight() {
        let gate = CombinedGate::default_gates();
        // Allow-all hint should pass both the policy and coherence gates
        let decision = gate.preflight(&GateHint::allow_all());
        assert!(matches!(decision, GateDecision::RanFull));
    }

    #[test]
    fn test_combined_gate_low_coherence() {
        let gate = CombinedGate::default_gates();
        // Very low coherence (below the default -256 floor) should skip
        let hint = GateHint::new(-1000, false, crate::types::ComputeClass::Reflex);
        let decision = gate.preflight(&hint);
        assert!(matches!(decision, GateDecision::Skipped { .. }));
    }
}

View File

@@ -0,0 +1,305 @@
//! Policy-based gating for access control and resource management
use crate::types::{ComputeClass, GateHint};
use crate::witness::WitnessLog;
/// Trait for policy-based gating
///
/// Complements coherence gating with resource/access rules: compute budgets,
/// write eligibility, and usage accounting. Implementations must be
/// `Send + Sync` so they can be shared across threads.
pub trait PolicyGate: Send + Sync {
    /// Check if inference is allowed
    fn allow_inference(&self, hint: &GateHint) -> bool;

    /// Check if write is allowed after inference
    fn allow_write(&self, witness: &WitnessLog) -> bool;

    /// Get remaining compute budget
    ///
    /// `None` means the budget is unlimited.
    fn remaining_budget(&self) -> Option<u64>;

    /// Record compute usage
    fn record_usage(&self, cycles: u32);
}
/// Write policy configuration
#[derive(Debug, Clone)]
pub struct WritePolicy {
    /// Allow writes after early exit
    pub allow_early_exit_writes: bool,
    /// Maximum latency (ns) for write eligibility
    pub max_latency_ns: u32,
    /// Require specific backend
    pub required_backend: Option<crate::types::BackendKind>,
    /// Minimum compute class for writes
    ///
    /// NOTE(review): not consulted by `DefaultPolicyGate::allow_write` in
    /// this file — confirm whether another consumer enforces it.
    pub min_compute_class: ComputeClass,
}
impl Default for WritePolicy {
fn default() -> Self {
Self {
allow_early_exit_writes: false,
max_latency_ns: u32::MAX,
required_backend: None,
min_compute_class: ComputeClass::Reflex,
}
}
}
impl WritePolicy {
    /// Create a strict write policy: no early-exit writes, a 10 ms latency
    /// ceiling, and deliberative-class inference required.
    pub fn strict() -> Self {
        Self {
            min_compute_class: ComputeClass::Deliberative,
            max_latency_ns: 10_000_000, // 10ms
            allow_early_exit_writes: false,
            required_backend: None,
        }
    }

    /// Create a permissive write policy that accepts everything.
    pub fn permissive() -> Self {
        Self {
            min_compute_class: ComputeClass::Reflex,
            max_latency_ns: u32::MAX,
            allow_early_exit_writes: true,
            required_backend: None,
        }
    }

    /// Require FPGA backend for writes (builder-style modifier).
    pub fn require_fpga(mut self) -> Self {
        self.required_backend = Some(crate::types::BackendKind::FpgaPcie);
        self
    }
}
/// Default policy gate implementation
///
/// Combines a [`WritePolicy`] with a simple cycle budget tracked via atomics
/// so it can be shared across threads without locking.
pub struct DefaultPolicyGate {
    /// Rules for when results may be written back.
    write_policy: WritePolicy,
    /// Compute budget (total cycles allowed, 0 = unlimited)
    budget_cycles: std::sync::atomic::AtomicU64,
    /// Used cycles
    used_cycles: std::sync::atomic::AtomicU64,
}
impl DefaultPolicyGate {
    /// Create with default policy
    pub fn new() -> Self {
        Self::with_policy(WritePolicy::default())
    }

    /// Create with custom write policy; budget starts unlimited (0) and
    /// usage starts at zero.
    pub fn with_policy(write_policy: WritePolicy) -> Self {
        use std::sync::atomic::AtomicU64;
        Self {
            write_policy,
            budget_cycles: AtomicU64::new(0),
            used_cycles: AtomicU64::new(0),
        }
    }

    /// Set compute budget (0 = unlimited).
    pub fn set_budget(&self, cycles: u64) {
        use std::sync::atomic::Ordering;
        self.budget_cycles.store(cycles, Ordering::SeqCst);
    }

    /// Reset used cycles back to zero.
    pub fn reset_usage(&self) {
        use std::sync::atomic::Ordering;
        self.used_cycles.store(0, Ordering::SeqCst);
    }
}
impl Default for DefaultPolicyGate {
    /// Equivalent to [`DefaultPolicyGate::new`] (default write policy,
    /// unlimited budget).
    fn default() -> Self {
        Self::new()
    }
}
impl PolicyGate for DefaultPolicyGate {
    fn allow_inference(&self, hint: &GateHint) -> bool {
        use std::sync::atomic::Ordering;
        // A budget of zero means "unlimited"; otherwise deny once exhausted.
        let budget = self.budget_cycles.load(Ordering::SeqCst);
        if budget > 0 && self.used_cycles.load(Ordering::SeqCst) >= budget {
            return false;
        }
        // Reflex is the lowest class, so this always holds today; kept as a
        // hook for stricter class policies.
        hint.max_compute_class >= ComputeClass::Reflex
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        // A skipped inference never earns write access.
        if !witness.gate_decision.did_run() {
            return false;
        }
        // Early exits are writable only when the policy opts in.
        let early_exit =
            matches!(witness.gate_decision, crate::types::GateDecision::EarlyExit { .. });
        if early_exit && !self.write_policy.allow_early_exit_writes {
            return false;
        }
        // Enforce the latency ceiling.
        if witness.latency_ns > self.write_policy.max_latency_ns {
            return false;
        }
        // Optional backend requirement (e.g. "real FPGA only").
        match self.write_policy.required_backend {
            Some(required) => witness.backend == required,
            None => true,
        }
    }

    fn remaining_budget(&self) -> Option<u64> {
        use std::sync::atomic::Ordering;
        let budget = self.budget_cycles.load(Ordering::SeqCst);
        (budget != 0).then(|| budget.saturating_sub(self.used_cycles.load(Ordering::SeqCst)))
    }

    fn record_usage(&self, cycles: u32) {
        use std::sync::atomic::Ordering;
        self.used_cycles.fetch_add(u64::from(cycles), Ordering::SeqCst);
    }
}
/// Rate-limited policy gate
///
/// Adds a fixed one-second-window rate limit on top of [`DefaultPolicyGate`]:
/// the counter resets each time a full second has elapsed since
/// `window_start`.
pub struct RateLimitedPolicyGate {
    /// Budget and write-policy checks are delegated here.
    base: DefaultPolicyGate,
    /// Maximum inferences per second
    max_inferences_per_sec: u32,
    /// Inference count in current window
    inference_count: std::sync::atomic::AtomicU32,
    /// Window start time
    window_start: std::sync::RwLock<std::time::Instant>,
}
impl RateLimitedPolicyGate {
/// Create with rate limit
pub fn new(max_inferences_per_sec: u32, write_policy: WritePolicy) -> Self {
Self {
base: DefaultPolicyGate::with_policy(write_policy),
max_inferences_per_sec,
inference_count: std::sync::atomic::AtomicU32::new(0),
window_start: std::sync::RwLock::new(std::time::Instant::now()),
}
}
/// Check and update rate limit
fn check_rate_limit(&self) -> bool {
let now = std::time::Instant::now();
// Check if we need to reset the window
{
let window_start = self.window_start.read().unwrap();
if now.duration_since(*window_start).as_secs() >= 1 {
drop(window_start);
let mut window_start = self.window_start.write().unwrap();
*window_start = now;
self.inference_count
.store(0, std::sync::atomic::Ordering::SeqCst);
}
}
// Check count
let count = self
.inference_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
count < self.max_inferences_per_sec
}
}
impl PolicyGate for RateLimitedPolicyGate {
    fn allow_inference(&self, hint: &GateHint) -> bool {
        // The rate limiter gets first say; only then consult the base policy.
        self.check_rate_limit() && self.base.allow_inference(hint)
    }

    fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.base.allow_write(witness)
    }

    fn remaining_budget(&self) -> Option<u64> {
        self.base.remaining_budget()
    }

    fn record_usage(&self, cycles: u32) {
        self.base.record_usage(cycles);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy_allows_inference() {
        // With an unlimited budget (the default) inference is always allowed.
        let gate = DefaultPolicyGate::new();
        let hint = GateHint::allow_all();
        assert!(gate.allow_inference(&hint));
    }

    #[test]
    fn test_budget_limiting() {
        let gate = DefaultPolicyGate::new();
        gate.set_budget(1000);
        let hint = GateHint::allow_all();
        // Should allow initially (no usage recorded yet)
        assert!(gate.allow_inference(&hint));
        // Record usage exceeding budget
        gate.record_usage(1500);
        // Should deny now
        assert!(!gate.allow_inference(&hint));
        // Reset and check again
        gate.reset_usage();
        assert!(gate.allow_inference(&hint));
    }

    #[test]
    fn test_write_policy_early_exit() {
        let gate = DefaultPolicyGate::with_policy(WritePolicy::default());
        let mut witness = crate::witness::WitnessLog::empty();
        witness.gate_decision = crate::types::GateDecision::EarlyExit { layer: 3 };
        // Default policy denies early exit writes
        assert!(!gate.allow_write(&witness));
        // Permissive policy allows them
        let permissive = DefaultPolicyGate::with_policy(WritePolicy::permissive());
        assert!(permissive.allow_write(&witness));
    }

    #[test]
    fn test_write_policy_latency() {
        // A 1 µs latency ceiling: 500 ns passes, 2000 ns is rejected.
        let mut policy = WritePolicy::default();
        policy.max_latency_ns = 1000;
        let gate = DefaultPolicyGate::with_policy(policy);
        let mut witness = crate::witness::WitnessLog::empty();
        witness.latency_ns = 500;
        assert!(gate.allow_write(&witness));
        witness.latency_ns = 2000;
        assert!(!gate.allow_write(&witness));
    }
}

View File

@@ -0,0 +1,331 @@
//! # FPGA Transformer Backend
//!
//! Ultra low latency transformer inference with FPGA acceleration,
//! coherence gating, and deterministic execution.
//!
//! ## Features
//!
//! - **Deterministic latency paths**: Fixed shape inference with bounded timing
//! - **Quantization first design**: Explicit INT4/INT8 quantization with reproducible math
//! - **Zero allocation hot path**: No heap allocations during inference
//! - **Coherence gating**: Mincut-integrated gate decisions
//! - **Multiple backends**: FPGA PCIe, FPGA Daemon, Native Sim, WASM Sim
//! - **Witness logging**: Auditable inference with ReasoningBank integration
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use ruvector_fpga_transformer::{Engine, artifact::ModelArtifact};
//! use ruvector_fpga_transformer::backend::native_sim::NativeSimBackend;
//! use ruvector_fpga_transformer::gating::DefaultCoherenceGate;
//! use ruvector_fpga_transformer::types::{InferenceRequest, GateHint, FixedShape};
//! use std::sync::Arc;
//!
//! // Create backend and gate
//! let gate = Arc::new(DefaultCoherenceGate::new());
//! let backend = NativeSimBackend::new(gate.clone());
//!
//! // Create engine
//! let mut engine = Engine::new(Box::new(backend), gate);
//!
//! // Load artifact (from file or bytes)
//! // let model_id = engine.load_artifact(&artifact_bytes)?;
//!
//! // Run inference
//! // let result = engine.infer(request)?;
//! ```
//!
//! ## Backend Selection
//!
//! The crate supports multiple backends selected at runtime:
//!
//! - `FpgaPcie`: Direct PCIe access to FPGA (requires `pcie` feature)
//! - `FpgaDaemon`: Communication via local daemon (requires `daemon` feature)
//! - `NativeSim`: Pure Rust simulator (requires `native_sim` feature)
//! - `WasmSim`: WASM-compatible simulator (requires `wasm` feature)
//!
//! ## Artifact Format
//!
//! Models are packaged as signed artifacts containing:
//! - Manifest with shape and quantization metadata
//! - Quantized weights
//! - Optional FPGA bitstream
//! - Test vectors for validation
//! - Ed25519 signature
#![warn(missing_docs)]
#![cfg_attr(feature = "wasm", allow(unused_imports))]
pub mod artifact;
pub mod backend;
pub mod error;
pub mod ffi;
pub mod gating;
pub mod quant;
pub mod types;
pub mod witness;
pub use artifact::ModelArtifact;
pub use backend::TransformerBackend;
pub use error::{Error, Result};
pub use gating::CoherenceGate;
pub use types::{
BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
InferenceResult, Layout, ModelId, QuantSpec, QuantizedTensor, SkipReason, WitnessLog,
};
use std::sync::Arc;
/// Crate version
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Main engine for FPGA transformer inference
///
/// The engine combines a backend (FPGA, simulator, etc.) with a coherence gate
/// for controlled inference execution. It also keeps per-model metadata and
/// running statistics across `infer` calls.
pub struct Engine {
    /// Backend for inference execution
    backend: Box<dyn TransformerBackend>,
    /// Coherence gate for decision making
    gate: Arc<dyn CoherenceGate>,
    /// Loaded models, keyed by their backend-assigned id
    models: std::collections::HashMap<ModelId, ModelInfo>,
    /// Inference statistics
    stats: EngineStats,
}
/// Information about a loaded model
///
/// `shape` and `quant` are copied out of the artifact's manifest at load time
/// for quick lookup.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model artifact
    pub artifact: ModelArtifact,
    /// Shape configuration
    pub shape: FixedShape,
    /// Quantization spec
    pub quant: QuantSpec,
}
/// Engine statistics
///
/// `successful` counts both full runs and early exits; `total_latency_ns`
/// accumulates witness latency for all backend results.
#[derive(Debug, Default, Clone)]
pub struct EngineStats {
    /// Total inferences attempted (including gate-skipped ones)
    pub total_inferences: u64,
    /// Successful inferences (full runs + early exits)
    pub successful: u64,
    /// Skipped inferences
    pub skipped: u64,
    /// Early exits (subset of `successful`)
    pub early_exits: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
}
impl Engine {
    /// Create a new engine with the specified backend and gate
    pub fn new(backend: Box<dyn TransformerBackend>, gate: Arc<dyn CoherenceGate>) -> Self {
        Self {
            backend,
            gate,
            models: std::collections::HashMap::new(),
            stats: EngineStats::default(),
        }
    }

    /// Create with default native simulator backend
    ///
    /// The same default gate instance is shared between the backend and the
    /// engine.
    #[cfg(feature = "native_sim")]
    pub fn native_sim() -> Self {
        let gate = Arc::new(gating::DefaultCoherenceGate::new());
        let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
        Self::new(backend, gate)
    }

    /// Load a model artifact from bytes
    ///
    /// Unpacks the container and delegates to [`Engine::load`].
    pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<ModelId> {
        let artifact = artifact::unpack_artifact(artifact_bytes)?;
        self.load(&artifact)
    }

    /// Load a model artifact
    ///
    /// Validates the artifact, hands it to the backend, and records its
    /// shape/quantization metadata for later lookup.
    pub fn load(&mut self, artifact: &ModelArtifact) -> Result<ModelId> {
        // Validate artifact
        artifact.validate()?;
        // Load into backend
        let model_id = self.backend.load(artifact)?;
        // Store info
        // NOTE(review): the artifact (including its payload) is cloned here
        // for bookkeeping; for large models this duplicates memory — confirm
        // this is acceptable.
        self.models.insert(
            model_id,
            ModelInfo {
                artifact: artifact.clone(),
                shape: artifact.manifest.shape,
                quant: artifact.manifest.quant,
            },
        );
        Ok(model_id)
    }

    /// Run inference
    ///
    /// Applies the preflight gate first; a skipped decision is surfaced as
    /// `Error::GateBlocked` rather than an empty result. Statistics are
    /// updated from the witness the backend returns.
    pub fn infer(&mut self, req: InferenceRequest) -> Result<InferenceResult> {
        self.stats.total_inferences += 1;
        // Check preflight gate
        let preflight = self.gate.preflight(&req.gate_hint);
        if let GateDecision::Skipped { reason } = preflight {
            self.stats.skipped += 1;
            return Err(Error::GateBlocked { reason });
        }
        // Run inference
        let result = self.backend.infer(req)?;
        // Update stats
        self.stats.total_latency_ns += result.witness.latency_ns as u64;
        match result.witness.gate_decision {
            GateDecision::RanFull => self.stats.successful += 1,
            GateDecision::EarlyExit { .. } => {
                // Early exits count as successes too.
                self.stats.successful += 1;
                self.stats.early_exits += 1;
            }
            GateDecision::Skipped { .. } => self.stats.skipped += 1,
        }
        Ok(result)
    }

    /// Unload a model from the backend and drop its cached metadata.
    pub fn unload(&mut self, model: ModelId) -> Result<()> {
        self.backend.unload(model)?;
        self.models.remove(&model);
        Ok(())
    }

    /// Get model shape
    ///
    /// Errors with `ModelNotFound` if the model is not loaded.
    pub fn shape(&self, model: ModelId) -> Result<FixedShape> {
        self.models
            .get(&model)
            .map(|info| info.shape)
            .ok_or_else(|| Error::ModelNotFound(model))
    }

    /// Get model info
    pub fn model_info(&self, model: ModelId) -> Option<&ModelInfo> {
        self.models.get(&model)
    }

    /// Check if model is loaded
    pub fn is_loaded(&self, model: ModelId) -> bool {
        self.models.contains_key(&model)
    }

    /// Get list of loaded models
    pub fn loaded_models(&self) -> Vec<ModelId> {
        self.models.keys().copied().collect()
    }

    /// Get engine statistics
    pub fn stats(&self) -> &EngineStats {
        &self.stats
    }

    /// Get backend statistics
    pub fn backend_stats(&self) -> backend::BackendStats {
        self.backend.stats()
    }

    /// Get backend kind
    pub fn backend_kind(&self) -> BackendKind {
        self.backend.kind()
    }

    /// Check if write is allowed based on witness (delegates to the gate)
    pub fn allow_write(&self, witness: &WitnessLog) -> bool {
        self.gate.allow_write(witness)
    }

    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.stats = EngineStats::default();
    }
}
impl EngineStats {
    /// Average latency per successful inference, in nanoseconds.
    /// Returns 0.0 when nothing has succeeded yet.
    pub fn avg_latency_ns(&self) -> f64 {
        match self.successful {
            0 => 0.0,
            n => self.total_latency_ns as f64 / n as f64,
        }
    }

    /// Average latency per successful inference, in milliseconds.
    pub fn avg_latency_ms(&self) -> f64 {
        self.avg_latency_ns() / 1_000_000.0
    }

    /// Fraction of all inferences that succeeded.
    /// Defined as 1.0 before any inference has been attempted.
    pub fn success_rate(&self) -> f64 {
        match self.total_inferences {
            0 => 1.0,
            n => self.successful as f64 / n as f64,
        }
    }

    /// Fraction of successful inferences that took the early-exit path.
    pub fn early_exit_rate(&self) -> f64 {
        match self.successful {
            0 => 0.0,
            n => self.early_exits as f64 / n as f64,
        }
    }
}
/// Prelude for convenient imports
///
/// `use ruvector_fpga_transformer::prelude::*;` brings the engine, the core
/// traits, and the common request/result types into scope.
pub mod prelude {
    pub use crate::{
        artifact::ModelArtifact,
        backend::TransformerBackend,
        gating::CoherenceGate,
        types::{
            BackendKind, ComputeClass, FixedShape, GateDecision, GateHint, InferenceRequest,
            InferenceResult, ModelId, QuantSpec, SkipReason, WitnessLog,
        },
        Engine, Error, Result,
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let gate = Arc::new(gating::DefaultCoherenceGate::new());
        #[cfg(feature = "native_sim")]
        {
            // A freshly-built engine starts with no models loaded.
            let backend = Box::new(backend::native_sim::NativeSimBackend::new(gate.clone()));
            let engine = Engine::new(backend, gate);
            assert!(engine.loaded_models().is_empty());
        }
    }

    #[test]
    fn test_engine_stats() {
        // 80 successes over 100 attempts, 20 early exits, 8 ms total latency.
        let stats = EngineStats {
            total_inferences: 100,
            successful: 80,
            skipped: 10,
            early_exits: 20,
            total_latency_ns: 8_000_000,
        };
        // success: 80/100, early-exit: 20/80, avg: 8_000_000/80.
        assert!((stats.success_rate() - 0.8).abs() < 0.01);
        assert!((stats.early_exit_rate() - 0.25).abs() < 0.01);
        assert!((stats.avg_latency_ns() - 100_000.0).abs() < 1.0);
    }
}

View File

@@ -0,0 +1,303 @@
//! Calibration data for quantization
use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Calibration data for a model
///
/// Aggregates per-layer statistics plus global input statistics, produced by
/// running calibration samples through the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationData {
    /// Layer-wise activation statistics
    pub layers: Vec<LayerCalibration>,
    /// Global input statistics
    pub input_stats: ActivationStats,
    /// Number of calibration samples used
    pub num_samples: usize,
    /// Calibration method used
    pub method: CalibrationMethod,
}
/// Per-layer calibration data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerCalibration {
    /// Layer index
    pub layer_idx: usize,
    /// Layer name
    pub name: String,
    /// Activation statistics after this layer
    pub activation_stats: ActivationStats,
    /// Weight statistics for this layer
    pub weight_stats: WeightStats,
    /// Optimal scale for activations (fixed-point Q16.16)
    pub act_scale: i32,
    /// Optimal scale for weights (fixed-point Q16.16)
    pub weight_scale: i32,
}
/// Activation statistics
///
/// All fields default to zero via `Default`; `update` interprets an all-zero
/// struct as "no data seen yet".
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ActivationStats {
    /// Minimum value seen
    pub min: f32,
    /// Maximum value seen
    pub max: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Histogram bins (for entropy calibration)
    #[serde(default)]
    pub histogram: Vec<u32>,
    /// Histogram bin edges
    #[serde(default)]
    pub bin_edges: Vec<f32>,
}
impl ActivationStats {
    /// Create empty stats
    pub fn new() -> Self {
        Self::default()
    }

    /// Update stats with a batch of values
    ///
    /// Fixes two defects in the previous min/max tracking:
    /// - `max` started at 0.0 and only ever grew, so all-negative data
    ///   reported a max of 0.0;
    /// - `min` used `self.min == 0.0` as an "uninitialized" sentinel, which
    ///   clobbered a genuine minimum of 0.0 on the next batch.
    ///
    /// NOTE(review): mean/std are still replaced by the *last* batch's values
    /// rather than merged online (matching the original behavior); callers
    /// doing multi-batch calibration should be aware.
    pub fn update(&mut self, values: &[f32]) {
        if values.is_empty() {
            return;
        }
        let batch_min = values.iter().copied().fold(f32::INFINITY, f32::min);
        let batch_max = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // A freshly-defaulted struct has every scalar field at zero; treat
        // that as "no data seen yet" and take the batch extremes directly.
        // (This heuristic can only misfire if all prior data was exactly 0.0,
        // in which case min/max stay correct anyway.)
        let uninitialized =
            self.min == 0.0 && self.max == 0.0 && self.mean == 0.0 && self.std == 0.0;
        if uninitialized {
            self.min = batch_min;
            self.max = batch_max;
        } else {
            self.min = self.min.min(batch_min);
            self.max = self.max.max(batch_max);
        }
        // Per-batch mean and (population) standard deviation.
        let n = values.len() as f32;
        let batch_mean = values.iter().sum::<f32>() / n;
        let batch_var = values.iter().map(|v| (v - batch_mean).powi(2)).sum::<f32>() / n;
        // Simple update (not an online algorithm): last batch wins.
        self.mean = batch_mean;
        self.std = batch_var.sqrt();
    }

    /// Compute optimal scale for symmetric quantization to `bits` bits
    ///
    /// Returns `max(|min|, |max|) / (2^(bits-1) - 1)`.
    pub fn optimal_scale(&self, bits: u8) -> f32 {
        let max_range = self.max.abs().max(self.min.abs());
        let qmax = (1 << (bits - 1)) as f32 - 1.0;
        max_range / qmax
    }
}
/// Weight statistics
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WeightStats {
    /// Min weight value
    pub min: f32,
    /// Max weight value
    pub max: f32,
    /// Sparsity (fraction of zeros, where |w| < 1e-6 counts as zero)
    pub sparsity: f32,
}
impl WeightStats {
    /// Compute statistics from a weight tensor.
    ///
    /// An empty slice yields the all-zero `Default`; previously it produced
    /// `min = +inf`, `max = -inf` and a NaN sparsity from the 0/0 division.
    pub fn from_weights(weights: &[f32]) -> Self {
        if weights.is_empty() {
            return Self::default();
        }
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut zeros = 0usize;
        for &w in weights {
            if w < min {
                min = w;
            }
            if w > max {
                max = w;
            }
            // Anything below 1e-6 in magnitude counts as a structural zero.
            if w.abs() < 1e-6 {
                zeros += 1;
            }
        }
        Self {
            min,
            max,
            sparsity: zeros as f32 / weights.len() as f32,
        }
    }
}
/// Calibration method
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationMethod {
    /// Use min/max of observed values
    MinMax,
    /// Use percentile clipping, expressed in tenths of a percent
    /// (e.g. 999 = 99.9%)
    Percentile(u32), // 999 = 99.9%
    /// Entropy-based calibration (KL divergence)
    Entropy,
    /// Mean-squared error minimization
    Mse,
}
impl Default for CalibrationMethod {
    /// Plain min/max observation is the default strategy.
    fn default() -> Self {
        Self::MinMax
    }
}
impl CalibrationData {
    /// Create empty calibration data for the given method.
    pub fn new(method: CalibrationMethod) -> Self {
        Self {
            method,
            layers: Vec::new(),
            input_stats: ActivationStats::new(),
            num_samples: 0,
        }
    }
    /// Append the calibration record for one layer.
    pub fn add_layer(&mut self, calib: LayerCalibration) {
        self.layers.push(calib);
    }
    /// Serialize to a JSON byte buffer.
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        let encoded = serde_json::to_vec(self)?;
        Ok(encoded)
    }
    /// Deserialize from a JSON byte buffer.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        let decoded = serde_json::from_slice(data)?;
        Ok(decoded)
    }
}
/// Calibrate a model by collecting activation statistics.
///
/// Runs `run_inference` over every calibration input, accumulating per-layer
/// activation statistics, then derives an 8-bit activation scale per layer.
/// Weight statistics are left at their defaults and the weight scale at 1.0
/// (65536 in Q16.16).
pub fn calibrate_model<F>(
    run_inference: F,
    calibration_inputs: &[Vec<u16>],
    num_layers: usize,
    method: CalibrationMethod,
) -> Result<CalibrationData>
where
    F: Fn(&[u16]) -> Result<Vec<Vec<f32>>>, // Returns activations per layer
{
    let mut calibration = CalibrationData::new(method);
    // One running statistics record per layer.
    let mut layer_stats = vec![ActivationStats::new(); num_layers];
    // Calibration passes: any activations beyond `num_layers` are ignored.
    for input in calibration_inputs {
        let activations = run_inference(input)?;
        for (stats, layer_act) in layer_stats.iter_mut().zip(activations.iter()) {
            stats.update(layer_act);
        }
        calibration.num_samples += 1;
    }
    // Convert the gathered statistics into per-layer calibration records.
    for (layer_idx, stats) in layer_stats.into_iter().enumerate() {
        let base_scale = stats.optimal_scale(8);
        // Percentile mode tightens the range slightly; the other methods
        // currently fall back to the plain min/max scale.
        let act_scale = if matches!(method, CalibrationMethod::Percentile(_)) {
            base_scale * 0.99
        } else {
            base_scale
        };
        calibration.add_layer(LayerCalibration {
            layer_idx,
            name: format!("layer_{}", layer_idx),
            activation_stats: stats,
            weight_stats: WeightStats::default(),
            act_scale: (act_scale * 65536.0) as i32,
            weight_scale: 65536, // Default 1.0
        });
    }
    Ok(calibration)
}
/// Apply percentile clipping to calibration.
///
/// Given a histogram in `stats`, returns the (low, high) value bounds that
/// keep the central `percentile` fraction of the mass (clipping
/// `(1 - percentile) / 2` from each tail). Falls back to the raw min/max when
/// no histogram data is present.
pub fn apply_percentile(stats: &ActivationStats, percentile: f32) -> (f32, f32) {
    // Need a histogram plus at least two edges to clip anything.
    if stats.histogram.is_empty() || stats.bin_edges.len() < 2 {
        return (stats.min, stats.max);
    }
    let total: u32 = stats.histogram.iter().sum();
    let target_low = (total as f32 * (1.0 - percentile) / 2.0) as u32;
    let target_high = (total as f32 * (1.0 + percentile) / 2.0) as u32;
    let mut cumsum = 0u32;
    // The previous `low_idx == 0` sentinel kept overwriting the low index
    // until a non-zero bin index was seen, so a lower target satisfied at bin
    // 0 was wrongly advanced to bin 1; use an Option to latch the first hit.
    let mut low_idx: Option<usize> = None;
    let mut high_idx = stats.histogram.len() - 1;
    for (i, &count) in stats.histogram.iter().enumerate() {
        cumsum += count;
        // First bin whose cumulative count reaches the lower target.
        if low_idx.is_none() && cumsum >= target_low {
            low_idx = Some(i);
        }
        if cumsum >= target_high {
            high_idx = i;
            break;
        }
    }
    let low_idx = low_idx.unwrap_or(0);
    // Clamp into the edge array: callers may supply fewer edges than bins + 1,
    // which previously made `bin_edges[high_idx + 1]` panic.
    let last_edge = stats.bin_edges.len() - 1;
    (
        stats.bin_edges[low_idx.min(last_edge)],
        stats.bin_edges[(high_idx + 1).min(last_edge)],
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    // min/max/mean tracking over one batch of positive values.
    #[test]
    fn test_activation_stats_update() {
        let mut stats = ActivationStats::new();
        stats.update(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.mean - 3.0).abs() < 0.01);
    }
    // Symmetric 8-bit scale for a [-1, 1] range.
    #[test]
    fn test_optimal_scale() {
        let mut stats = ActivationStats::new();
        stats.min = -1.0;
        stats.max = 1.0;
        let scale = stats.optimal_scale(8);
        // For 8-bit, qmax = 127, so scale should be 1.0/127 ≈ 0.00787
        assert!((scale - 1.0 / 127.0).abs() < 0.001);
    }
    // Sparsity counts near-zero weights (|w| < 1e-6): 2 of 6 here.
    #[test]
    fn test_weight_stats() {
        let weights = vec![0.0, 0.1, -0.1, 0.5, -0.5, 0.0];
        let stats = WeightStats::from_weights(&weights);
        assert_eq!(stats.min, -0.5);
        assert_eq!(stats.max, 0.5);
        assert!((stats.sparsity - 2.0 / 6.0).abs() < 0.01);
    }
    // JSON round-trip preserves the calibration method.
    #[test]
    fn test_calibration_serialization() {
        let calib = CalibrationData::new(CalibrationMethod::MinMax);
        let bytes = calib.to_bytes().unwrap();
        let restored = CalibrationData::from_bytes(&bytes).unwrap();
        assert_eq!(calib.method, restored.method);
    }
}

View File

@@ -0,0 +1,336 @@
//! Lookup table implementations for fixed-point operations
//!
//! Provides LUT-based exp, log, and softmax for deterministic computation.
/// LUT-based exponential function
/// Input: Q8.8 fixed point
/// Output: Q0.16 fixed point [0, 1)
///
/// NOTE(review): the input ranges quoted here, on `EXP_LUT` below, and inside
/// `generate_exp_lut` ([-16,16), [-8,8) and [-4,4) respectively) disagree —
/// confirm the intended domain before relying on absolute LUT values.
const EXP_LUT_SIZE: usize = 256;
const EXP_LUT_SHIFT: i32 = 8; // Q8.8 input
/// Precomputed exp LUT (see the range note above)
static EXP_LUT: [u16; EXP_LUT_SIZE] = generate_exp_lut();
/// Generate exp LUT at compile time.
///
/// NOTE(review): the division by 32.0 maps indices to x in [-4, 4), not the
/// [-8, 8) stated on the static, and entries are normalized by 1 + exp(4), so
/// each entry is a scaled (not absolute) exponential. Downstream code only
/// relies on positivity and monotonicity.
const fn generate_exp_lut() -> [u16; EXP_LUT_SIZE] {
    let mut lut = [0u16; EXP_LUT_SIZE];
    let mut i = 0;
    while i < EXP_LUT_SIZE {
        // Center the index so x_q spans [-128, 128).
        let x_q = (i as i32) - 128;
        // Scale to get reasonable exp range
        let x_f = (x_q as f64) / 32.0; // x in [-4, 4)
        // Compute exp and scale to Q0.16
        let exp_val = const_exp(x_f);
        let scaled = exp_val / (1.0 + const_exp(4.0)); // Normalize
        // Convert to u16 with saturation at both ends.
        lut[i] = if scaled > 1.0 {
            65535
        } else if scaled < 0.0 {
            0
        } else {
            (scaled * 65535.0) as u16
        };
        i += 1;
    }
    lut
}
/// Const-compatible exp approximation via a 5th-order Taylor series:
/// exp(x) ≈ 1 + x + x²/2 + x³/6 + x⁴/24 + x⁵/120 (accurate near zero only).
const fn const_exp(x: f64) -> f64 {
    let p2 = x * x;
    let p3 = p2 * x;
    let p4 = p3 * x;
    let p5 = p4 * x;
    1.0 + x + p2 / 2.0 + p3 / 6.0 + p4 / 24.0 + p5 / 120.0
}
/// LUT-based exponential
/// Input: i16 in Q8.8 format
/// Output: u16 in Q0.16 format
///
/// NOTE(review): the table is generated over x in [-4, 4) (see
/// `generate_exp_lut`) while this mapping treats one table step as a full
/// Q8.8 unit, so the effective input scaling does not match the stated Q8.8
/// domain — monotonicity holds, absolute values are suspect; confirm intent.
#[inline]
pub fn exp_lut(x: i16) -> u16 {
    // Clamp to LUT range
    let clamped = x.clamp(-128 * 256, 127 * 256);
    // Scale to LUT index
    let idx = ((clamped >> EXP_LUT_SHIFT) + 128) as usize;
    EXP_LUT[idx.min(EXP_LUT_SIZE - 1)]
}
/// Log LUT for Q0.16 input: LOG_LUT[i] ≈ ln(i/256) in Q8.8.
static LOG_LUT: [i16; 256] = generate_log_lut();
const fn generate_log_lut() -> [i16; 256] {
    let mut lut = [0i16; 256];
    // Start at 1: entry 0 is patched below as the log(0) sentinel.
    let mut i = 1;
    while i < 256 {
        // Input is scaled by 256, so x = i/256 in [0.004, 1)
        let x = (i as f64) / 256.0;
        // log(x) in Q8.8 format
        let log_val = const_ln(x);
        lut[i] = (log_val * 256.0) as i16;
        i += 1;
    }
    lut[0] = i16::MIN; // log(0) = -inf, use min value
    lut
}
/// Const-compatible natural log via the atanh series:
/// ln(x) = 2·atanh((x-1)/(x+1)) ≈ 2·(y + y³/3 + y⁵/5 + y⁷/7 + y⁹/9).
/// Non-positive input yields negative infinity.
const fn const_ln(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    let y = (x - 1.0) / (x + 1.0);
    let ysq = y * y;
    let t3 = y * ysq;
    let t5 = t3 * ysq;
    let t7 = t5 * ysq;
    let t9 = t7 * ysq;
    2.0 * (y + t3 / 3.0 + t5 / 5.0 + t7 / 7.0 + t9 / 9.0)
}
/// LUT-based natural log
/// Input: u16 in Q0.16 format (0 to 65535 = 0.0 to ~1.0)
/// Output: i16 in Q8.8 format; zero input returns `i16::MIN` (log(0) = -inf).
#[inline]
pub fn log_lut(x: u16) -> i16 {
    if x == 0 {
        return i16::MIN;
    }
    // Top byte selects one of the 256 precomputed buckets.
    let bucket = (x >> 8) as usize;
    LOG_LUT[bucket.min(255)]
}
/// Softmax using LUT-based exp
/// Operates in-place on Q8.8 logits, outputs Q0.16 probabilities
///
/// NOTE(review): results are written back into the `i16` slice as raw Q0.16
/// bit patterns, so probabilities >= 0.5 (>= 32768) wrap to negative `i16`
/// values. Consumers presumably reinterpret the bits as `u16` — confirm.
pub fn softmax_lut_q(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }
    // Find max for numerical stability
    let max = *logits.iter().max().unwrap_or(&0);
    // Compute exp(x - max) using LUT
    let mut sum: u32 = 0;
    let mut exp_values: Vec<u16> = Vec::with_capacity(logits.len());
    for &logit in logits.iter() {
        // saturating_sub avoids i16 overflow when the logit spread is large.
        let shifted = logit.saturating_sub(max);
        let exp_val = exp_lut(shifted);
        exp_values.push(exp_val);
        sum += exp_val as u32;
    }
    // Normalize (guard the degenerate all-zero-exp case).
    if sum == 0 {
        sum = 1;
    }
    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = ((exp_values[i] as u64 * 65535) / sum as u64) as u16;
        *logit = prob as i16;
    }
}
/// Softmax on f32 values using the standard max-shifted formulation
/// (for compatibility with the fixed-point variants). Operates in place;
/// an empty slice is a no-op.
pub fn softmax_lut(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }
    // Subtracting the max keeps exp() from overflowing.
    let max = logits.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
    let mut sum = 0.0f32;
    for value in logits.iter_mut() {
        *value = (*value - max).exp();
        sum += *value;
    }
    // Normalize to a probability distribution (skip if the sum collapsed).
    if sum > 0.0 {
        logits.iter_mut().for_each(|v| *v /= sum);
    }
}
/// Piecewise linear softmax approximation
/// More accurate than LUT but still deterministic
///
/// Inputs are `i16` logits in Q8.8; outputs are Q0.16 probability bit
/// patterns (values >= 0.5 wrap into the negative `i16` range, matching
/// `softmax_lut_q`).
pub fn softmax_pwl(logits: &mut [i16]) {
    if logits.is_empty() {
        return;
    }
    let max = *logits.iter().max().unwrap_or(&0);
    // Piecewise linear exp approximation
    // exp(x) ≈ 1 + x for x near 0
    // exp(x) ≈ 2^(x/ln2) for larger x
    let mut sum: i64 = 0;
    let mut exp_values: Vec<i32> = Vec::with_capacity(logits.len());
    for &logit in logits.iter() {
        // Widen before subtracting: `logit - max` in i16 overflows (and
        // panics in debug builds) when the logit spread exceeds the i16
        // range, e.g. logits containing both i16::MIN and i16::MAX.
        let x = logit as i32 - max as i32; // x <= 0
        // Piecewise approximation (in Q8.8)
        let exp_val = if x >= -256 {
            // x in [-1, 0] -> linear: 1 + x
            (256 + x).max(0)
        } else if x >= -2048 {
            // x in [-8, -1] -> quadratic decay toward zero
            let shifted = (x + 2048) >> 3; // Scale to 0-256 range
            (shifted * shifted / 256).max(1)
        } else {
            // x < -8 -> essentially zero
            1
        };
        exp_values.push(exp_val);
        sum += exp_val as i64;
    }
    // Normalize to Q0.16 (sum is always >= len here, but keep the guard).
    if sum == 0 {
        sum = 1;
    }
    for (i, logit) in logits.iter_mut().enumerate() {
        let prob = (exp_values[i] as i64 * 65535 / sum) as i16;
        *logit = prob;
    }
}
/// GELU approximation using LUT
/// GELU(x) ≈ x * sigmoid(1.702 * x) (tanh-free sigmoid approximation)
///
/// Input/output: i16 in Q8.8 format.
pub fn gelu_lut(x: i16) -> i16 {
    // 1.702 in Q8.8 is 435/256 ≈ 1.699. Clamp the widened product: for
    // |x| above ~127.9 (Q8.8) the previous plain `as i16` cast wrapped and
    // flipped the sign fed into the sigmoid.
    let scaled = ((x as i32 * 435) >> 8).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    let sigmoid_val = sigmoid_lut(scaled);
    // Q8.8 * Q0.16 = Q8.24; shift by 16 to return to Q8.8. The previous
    // `>> 15` doubled the result and disagreed with `silu_lut`'s rescale.
    ((x as i32 * sigmoid_val as i32) >> 16) as i16
}
/// Sigmoid LUT: index i corresponds to x = (i - 128) / 16, i.e. x in [-8, 8)
/// with 16 entries per unit; values are sigmoid(x) scaled to Q0.16.
///
/// NOTE(review): `const_exp` is a 5th-order Taylor series that diverges for
/// |x| beyond roughly 2, so the tails of this table are inaccurate (entries
/// near the positive end can even collapse to 0 via the saturating cast);
/// confirm before relying on extreme inputs.
static SIGMOID_LUT: [u16; 256] = generate_sigmoid_lut();
const fn generate_sigmoid_lut() -> [u16; 256] {
    let mut lut = [0u16; 256];
    let mut i = 0;
    while i < 256 {
        // Map index to x in [-8, 8)
        let x = ((i as i32) - 128) as f64 / 16.0;
        // sigmoid(x) = 1 / (1 + exp(-x))
        let sig = 1.0 / (1.0 + const_exp(-x));
        // Out-of-range approximations saturate through the `as u16` cast.
        lut[i] = (sig * 65535.0) as u16;
        i += 1;
    }
    lut
}
/// LUT-based sigmoid
/// Input: i16 in Q8.8 format
/// Output: u16 in Q0.16 format
///
/// NOTE(review): the table maps index -> x = (idx - 128) / 16, which for a
/// true Q8.8 input would suggest `x >> 4`; the existing `>> 5` (kept here to
/// avoid disturbing downstream GELU/SiLU numerics) effectively evaluates
/// sigmoid(x/2) — confirm the intended input scaling.
#[inline]
pub fn sigmoid_lut(x: i16) -> u16 {
    // Clamp the index in i32: previously a very negative x produced a
    // negative index that wrapped through `as usize`, and `.min(255)` then
    // selected the TOP of the table — returning ~1.0 where ~0.0 was expected.
    let idx = (((x as i32) >> 5) + 128).clamp(0, 255) as usize;
    SIGMOID_LUT[idx]
}
/// SiLU (Swish) activation via the sigmoid LUT: SiLU(x) = x * sigmoid(x).
/// Input/output: i16 in Q8.8 (Q8.8 * Q0.16 >> 16 = Q8.8).
#[inline]
pub fn silu_lut(x: i16) -> i16 {
    let gate = sigmoid_lut(x) as i32;
    ((x as i32 * gate) >> 16) as i16
}
#[cfg(test)]
mod tests {
    use super::*;
    // exp LUT: only positivity and monotonicity are guaranteed (the table's
    // absolute scaling is nonstandard — see generate_exp_lut).
    #[test]
    fn test_exp_lut() {
        // exp(0) should return a non-zero value
        let result = exp_lut(0);
        assert!(result > 0, "exp(0) should be positive");
        // exp is monotonically increasing
        let result_neg = exp_lut(-256); // -1.0 in Q8.8
        let result_zero = exp_lut(0);
        let result_pos = exp_lut(256); // 1.0 in Q8.8
        assert!(
            result_neg <= result_zero,
            "exp should be monotonically increasing"
        );
        assert!(
            result_zero <= result_pos,
            "exp should be monotonically increasing"
        );
    }
    // sigmoid LUT: midpoint value and strict monotonicity over moderate inputs.
    #[test]
    fn test_sigmoid_lut() {
        // sigmoid(0) = 0.5
        let result = sigmoid_lut(0);
        let expected = 32768u16; // 0.5 in Q0.16
        assert!(
            (result as i32 - expected as i32).abs() < 5000,
            "sigmoid(0) ≈ 0.5"
        );
        // sigmoid is monotonically increasing
        let result_neg = sigmoid_lut(-1024);
        let result_zero = sigmoid_lut(0);
        let result_pos = sigmoid_lut(1024);
        assert!(
            result_neg < result_zero,
            "sigmoid should be monotonically increasing"
        );
        assert!(
            result_zero < result_pos,
            "sigmoid should be monotonically increasing"
        );
    }
    // f32 softmax: sums to 1 and preserves the ordering of the logits.
    #[test]
    fn test_softmax_lut() {
        let mut logits = vec![1.0f32, 2.0, 3.0, 4.0];
        softmax_lut(&mut logits);
        // Sum should be 1.0
        let sum: f32 = logits.iter().sum();
        assert!((sum - 1.0).abs() < 0.01);
        // Should be increasing
        for i in 1..logits.len() {
            assert!(logits[i] > logits[i - 1]);
        }
    }
    // GELU: near-zero at 0 and sign-preserving for small inputs.
    #[test]
    fn test_gelu_lut() {
        // GELU(0) should be approximately 0
        assert!(gelu_lut(0).abs() < 100);
        // GELU maintains sign
        let neg_result = gelu_lut(-256);
        assert!(neg_result <= 0, "GELU of negative should be non-positive");
        // GELU of positive values should be positive
        let pos_result = gelu_lut(256);
        assert!(pos_result > 0, "GELU of positive should be positive");
    }
}

View File

@@ -0,0 +1,238 @@
//! Quantization subsystem
//!
//! Explicit, reproducible quantization for weights and activations.
pub mod calib;
pub mod lut;
pub mod qformat;
pub use calib::{calibrate_model, CalibrationData};
pub use lut::{exp_lut, log_lut, softmax_lut};
pub use qformat::{dequantize_i16, dequantize_i8, quantize_i16, quantize_i8};
use crate::types::QuantSpec;
/// Fixed-point Q15 format (1.15)
/// Range: [-1.0, 1.0 - 2^-15]
/// Resolution: 2^-15 ≈ 3.05e-5
pub type Q15 = i16;
/// Fixed-point Q16.16 format
/// Range: [-32768.0, 32767.999...]
/// Resolution: 2^-16 ≈ 1.53e-5
pub type Q16_16 = i32;
/// Convert f32 to Q15, clamping into the representable [-1.0, 1.0) range.
#[inline]
pub fn f32_to_q15(x: f32) -> Q15 {
    (x.clamp(-1.0, 1.0 - f32::EPSILON) * 32768.0) as Q15
}
/// Convert Q15 to f32
#[inline]
pub fn q15_to_f32(x: Q15) -> f32 {
    x as f32 / 32768.0
}
/// Convert f32 to Q16.16 (truncating toward zero via the cast)
#[inline]
pub fn f32_to_q16_16(x: f32) -> Q16_16 {
    (x * 65536.0) as Q16_16
}
/// Convert Q16.16 to f32
#[inline]
pub fn q16_16_to_f32(x: Q16_16) -> f32 {
    x as f32 / 65536.0
}
/// Fixed-point multiplication Q15 * Q15 -> Q15
#[inline]
pub fn q15_mul(a: Q15, b: Q15) -> Q15 {
    // +0x4000 rounds the Q30 product to nearest before the >>15 rescale.
    let product = (a as i32 * b as i32 + 0x4000) >> 15;
    product.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}
/// Fixed-point dot product with a plain i32 accumulator.
/// Note: For very large vectors (>65536 elements), use q15_dot_saturating
/// to avoid accumulator overflow.
#[inline]
pub fn q15_dot(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum()
}
/// Saturating fixed-point dot product (prevents overflow for large vectors)
#[inline]
pub fn q15_dot_saturating(a: &[Q15], b: &[Q15]) -> i32 {
    a.iter().zip(b.iter()).fold(0i32, |acc, (&x, &y)| {
        acc.saturating_add((x as i32).saturating_mul(y as i32))
    })
}
/// Fixed-point dot product normalized back to Q15 by an arithmetic right
/// shift of `shift` bits, with round-to-nearest.
#[inline]
pub fn q15_dot_normalized(a: &[Q15], b: &[Q15], shift: u8) -> Q15 {
    let sum = q15_dot(a, b);
    // A shift of 0 previously evaluated `1 << (shift - 1)`, which underflows
    // the u8 and panics in debug builds; return the unshifted sum instead.
    let shifted = if shift == 0 {
        sum
    } else {
        (sum + (1 << (shift - 1))) >> shift
    };
    shifted.clamp(i16::MIN as i32, i16::MAX as i32) as Q15
}
/// Quantization context for a layer
///
/// Bundles the weight spec with the per-layer I/O scales and accumulator
/// parameters needed to run one quantized layer deterministically.
#[derive(Debug, Clone)]
pub struct QuantContext {
    /// Weight quantization spec
    pub weight_spec: QuantSpec,
    /// Input scale (Q16.16)
    pub input_scale: Q16_16,
    /// Output scale (Q16.16)
    pub output_scale: Q16_16,
    /// Accumulator bit width
    pub acc_bits: u8,
    /// Right shift applied to the accumulator for normalization
    pub norm_shift: u8,
}
impl QuantContext {
    /// Build a context from a quantization spec: both I/O scales are taken
    /// from the spec, with a 32-bit accumulator and a Q15 normalization shift.
    pub fn from_spec(spec: &QuantSpec) -> Self {
        Self {
            weight_spec: *spec,
            input_scale: spec.scale_q,
            output_scale: spec.scale_q,
            acc_bits: 32,
            norm_shift: 15,
        }
    }
    /// Compute the accumulator width needed to avoid overflow: each product
    /// is `input_bits + weight_bits` wide, summing `vector_len` terms adds
    /// ceil(log2(vector_len)) bits, plus one sign bit.
    pub fn required_acc_bits(input_bits: u8, weight_bits: u8, vector_len: usize) -> u8 {
        let product_bits = input_bits + weight_bits;
        let reduction_bits = (vector_len as f64).log2().ceil() as u8;
        product_bits + reduction_bits + 1 // +1 for sign
    }
}
/// Packing utilities for sub-byte quantization
pub mod packing {
    /// Pack two signed 4-bit values into one byte (a = low nibble, b = high).
    #[inline]
    pub fn pack_int4(a: i8, b: i8) -> u8 {
        let lo = (a as u8) & 0x0F;
        let hi = (b as u8) & 0x0F;
        lo | (hi << 4)
    }
    /// Unpack a byte into two signed 4-bit values.
    ///
    /// Sign extension is done by shifting each nibble into the top of an i8
    /// and arithmetic-shifting back down.
    #[inline]
    pub fn unpack_int4(packed: u8) -> (i8, i8) {
        let a = ((packed << 4) as i8) >> 4;
        let b = (packed as i8) >> 4;
        (a, b)
    }
    /// Pack four signed 2-bit values into one byte (a in bits 0-1 ... d in 6-7).
    #[inline]
    pub fn pack_int2(a: i8, b: i8, c: i8, d: i8) -> u8 {
        let mut out = 0u8;
        for (slot, &v) in [a, b, c, d].iter().enumerate() {
            out |= ((v as u8) & 0x03) << (slot * 2);
        }
        out
    }
    /// Unpack a byte into four signed 2-bit values (sign-extended).
    #[inline]
    pub fn unpack_int2(packed: u8) -> (i8, i8, i8, i8) {
        // Move the 2-bit field into bits 6-7 and arithmetic-shift back.
        let field = |shift: u8| (((packed >> shift) << 6) as i8) >> 6;
        (field(0), field(2), field(4), field(6))
    }
    /// Pack eight booleans into one byte, bit i = bits[i].
    #[inline]
    pub fn pack_binary(bits: &[bool; 8]) -> u8 {
        let mut byte = 0u8;
        for (i, &bit) in bits.iter().enumerate() {
            if bit {
                byte |= 1 << i;
            }
        }
        byte
    }
    /// Unpack one byte into eight booleans, bits[i] = bit i.
    #[inline]
    pub fn unpack_binary(packed: u8) -> [bool; 8] {
        std::array::from_fn(|i| packed & (1u8 << i) != 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Q15 round-trip accuracy within one part in a thousand.
    #[test]
    fn test_q15_conversion() {
        assert_eq!(f32_to_q15(0.0), 0);
        assert_eq!(f32_to_q15(0.5), 16384);
        assert_eq!(f32_to_q15(-0.5), -16384);
        let x = 0.123f32;
        let q = f32_to_q15(x);
        let back = q15_to_f32(q);
        assert!((x - back).abs() < 0.001);
    }
    // 0.5 * 0.5 = 0.25 in fixed point.
    #[test]
    fn test_q15_mul() {
        let a = f32_to_q15(0.5);
        let b = f32_to_q15(0.5);
        let c = q15_mul(a, b);
        let result = q15_to_f32(c);
        assert!((result - 0.25).abs() < 0.01);
    }
    // 4-bit pack/unpack round-trip, including a negative value.
    #[test]
    fn test_packing_int4() {
        let (a, b) = (5i8, -3i8);
        let packed = packing::pack_int4(a, b);
        let (ua, ub) = packing::unpack_int4(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
    }
    // 2-bit pack/unpack round-trip across the full [-2, 1] range.
    #[test]
    fn test_packing_int2() {
        let (a, b, c, d) = (1i8, -1i8, 0i8, -2i8);
        let packed = packing::pack_int2(a, b, c, d);
        let (ua, ub, uc, ud) = packing::unpack_int2(packed);
        assert_eq!(a, ua);
        assert_eq!(b, ub);
        assert_eq!(c, uc);
        // -2 in 2-bit is 10 binary, which unpacks to -2 (sign extended)
        assert_eq!(-2i8, ud);
    }
    // Bit-level pack/unpack round-trip.
    #[test]
    fn test_packing_binary() {
        let bits = [true, false, true, true, false, false, true, false];
        let packed = packing::pack_binary(&bits);
        let unpacked = packing::unpack_binary(packed);
        assert_eq!(bits, unpacked);
    }
}

View File

@@ -0,0 +1,237 @@
//! Quantization format operations
use crate::types::QuantSpec;
/// Quantize f32 values to i8 using the spec's Q16.16 scale and zero point:
/// q = clamp(round((v - zero) / scale), -128, 127).
pub fn quantize_i8(values: &[f32], spec: &QuantSpec) -> Vec<i8> {
    // Recover the f32 scale/zero from their Q16.16 encodings.
    let scale = spec.scale_q as f32 / 65536.0;
    let zero = spec.zero_q as f32 / 65536.0;
    values
        .iter()
        .map(|&v| ((v - zero) / scale).round().clamp(-128.0, 127.0) as i8)
        .collect()
}
/// Quantize f32 values to i16 by linearly mapping the observed [min, max]
/// range onto the full [-32768, 32767] range. A constant (or empty) input
/// quantizes to all zeros / an empty vector.
pub fn quantize_i16(values: &[f32]) -> Vec<i16> {
    // Single pass for both extremes.
    let (min, max) = values
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    // Degenerate range: every value maps to zero.
    if (max - min).abs() < f32::EPSILON {
        return vec![0i16; values.len()];
    }
    let scale = (max - min) / 65535.0;
    values
        .iter()
        .map(|&v| {
            ((v - min) / scale - 32768.0)
                .round()
                .clamp(-32768.0, 32767.0) as i16
        })
        .collect()
}
/// Dequantize i8 values to f32 using the spec's Q16.16 scale and zero point.
///
/// NOTE(review): the slice is `&[u8]` and each byte is reinterpreted as `i8`
/// before scaling — confirm callers pack signed values this way.
pub fn dequantize_i8(values: &[u8], spec: &QuantSpec) -> Vec<f32> {
    let scale = spec.scale_q as f32 / 65536.0;
    let zero = spec.zero_q as f32 / 65536.0;
    values
        .iter()
        .map(|&raw| (raw as i8) as f32 * scale + zero)
        .collect()
}
/// Dequantize i16 values to f32 via the affine map v = q * scale + zero.
pub fn dequantize_i16(values: &[i16], scale: f32, zero: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(values.len());
    for &q in values {
        out.push(q as f32 * scale + zero);
    }
    out
}
/// Symmetric quantization (zero point = 0): scale = max(|v|) / 127, output
/// clamped to [-127, 127]. All-zero (or empty) input returns zeros with a
/// unit scale.
pub fn quantize_symmetric_i8(values: &[f32]) -> (Vec<i8>, f32) {
    let abs_max = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if abs_max < f32::EPSILON {
        return (vec![0i8; values.len()], 1.0);
    }
    let scale = abs_max / 127.0;
    let quantized = values
        .iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale)
}
/// Asymmetric quantization using the full unsigned 8-bit range:
/// scale = (max - min) / 255, zero_point = round(-min / scale). A constant
/// input returns zeros with unit scale and zero offset.
pub fn quantize_asymmetric_i8(values: &[f32]) -> (Vec<u8>, f32, i32) {
    let min = values.iter().fold(f32::INFINITY, |m, &v| m.min(v));
    let max = values.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v));
    if (max - min).abs() < f32::EPSILON {
        return (vec![0u8; values.len()], 1.0, 0);
    }
    let scale = (max - min) / 255.0;
    let zero_point = (-min / scale).round() as i32;
    let quantized = values
        .iter()
        .map(|&v| ((v / scale).round() as i32 + zero_point).clamp(0, 255) as u8)
        .collect();
    (quantized, scale, zero_point)
}
/// Per-channel quantization for weights.
///
/// Treats `weights` as a row-major `[out_channels, in_features]` matrix and
/// quantizes each output channel symmetrically with its own scale. Returns
/// the quantized values plus one scale per channel.
pub fn quantize_per_channel_i8(weights: &[f32], out_channels: usize) -> (Vec<i8>, Vec<f32>) {
    // Guard: `out_channels == 0` previously panicked on division by zero.
    if out_channels == 0 {
        return (Vec::new(), Vec::new());
    }
    let in_features = weights.len() / out_channels;
    let mut quantized = Vec::with_capacity(weights.len());
    let mut scales = Vec::with_capacity(out_channels);
    for c in 0..out_channels {
        let start = c * in_features;
        let end = start + in_features;
        let channel_weights = &weights[start..end];
        let (q, scale) = quantize_symmetric_i8(channel_weights);
        quantized.extend(q);
        scales.push(scale);
    }
    (quantized, scales)
}
/// Blocked quantization for hardware efficiency.
///
/// Splits `values` into contiguous blocks of `block_size` (the final block
/// may be shorter) and quantizes each block symmetrically. Returns the
/// quantized values plus per-block scales and zero points (always 0, since
/// the quantization is symmetric).
pub fn quantize_blocked_i8(values: &[f32], block_size: usize) -> (Vec<i8>, Vec<f32>, Vec<i8>) {
    // Guard: `block_size == 0` previously panicked on division by zero.
    if block_size == 0 {
        return (Vec::new(), Vec::new(), Vec::new());
    }
    let num_blocks = (values.len() + block_size - 1) / block_size; // ceiling division
    let mut quantized = Vec::with_capacity(values.len());
    let mut scales = Vec::with_capacity(num_blocks);
    let mut zeros = Vec::with_capacity(num_blocks);
    for block_idx in 0..num_blocks {
        let start = block_idx * block_size;
        let end = (start + block_size).min(values.len());
        let block = &values[start..end];
        let (q, scale) = quantize_symmetric_i8(block);
        quantized.extend(q);
        scales.push(scale);
        zeros.push(0i8);
    }
    (quantized, scales, zeros)
}
/// Matrix quantization for GEMM
///
/// Row-major `rows x cols` matrix with per-row (per-output-channel)
/// symmetric quantization metadata.
#[derive(Debug, Clone)]
pub struct QuantizedMatrix {
    /// Quantized values (row-major, rows * cols elements)
    pub data: Vec<i8>,
    /// Rows
    pub rows: usize,
    /// Columns
    pub cols: usize,
    /// Per-row scales (for per-channel quantization)
    pub scales: Vec<f32>,
    /// Per-row zero points (currently always 0 — symmetric quantization)
    pub zeros: Vec<i8>,
}
impl QuantizedMatrix {
    /// Quantize a row-major matrix with per-row scaling.
    /// Panics if `data.len() != rows * cols`.
    pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols);
        let (quantized, scales) = quantize_per_channel_i8(data, rows);
        Self {
            rows,
            cols,
            data: quantized,
            scales,
            zeros: vec![0i8; rows],
        }
    }
    /// Borrow one row of quantized values. Panics if `idx >= rows`.
    pub fn row(&self, idx: usize) -> &[i8] {
        let offset = idx * self.cols;
        &self.data[offset..offset + self.cols]
    }
    /// Dequantize the full matrix back to row-major f32.
    pub fn to_f32(&self) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.rows * self.cols);
        for r in 0..self.rows {
            let scale = self.scales[r];
            let zero = self.zeros[r] as f32;
            result.extend(self.row(r).iter().map(|&q| (q as f32 - zero) * scale));
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): this import appears unused by the tests below.
    use crate::types::QuantSpec;
    // Symmetric round-trip error bounded by one quantization step.
    #[test]
    fn test_quantize_symmetric() {
        let values = vec![1.0, -1.0, 0.5, -0.5, 0.0];
        let (quantized, scale) = quantize_symmetric_i8(&values);
        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = q as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }
    // Asymmetric round-trip error with an explicit zero point.
    #[test]
    fn test_quantize_asymmetric() {
        let values = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let (quantized, scale, zero) = quantize_asymmetric_i8(&values);
        // Dequantize and check
        for (i, &q) in quantized.iter().enumerate() {
            let dequant = (q as i32 - zero) as f32 * scale;
            assert!((dequant - values[i]).abs() < 0.1);
        }
    }
    // 8x8 matrix round-trip with per-row scales.
    #[test]
    fn test_quantized_matrix() {
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1 - 3.2).collect();
        let matrix = QuantizedMatrix::from_f32(&data, 8, 8);
        assert_eq!(matrix.rows, 8);
        assert_eq!(matrix.cols, 8);
        assert_eq!(matrix.scales.len(), 8);
        let dequantized = matrix.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.2);
        }
    }
}

View File

@@ -0,0 +1,624 @@
//! Core types for FPGA Transformer backend
//!
//! All types are designed for deterministic, allocation-free inference
//! with explicit quantization metadata.
use serde::{Deserialize, Serialize};
/// Unique identifier for a loaded model (SHA-256 hash)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelId(pub [u8; 32]);
impl ModelId {
    /// Wrap a raw 32-byte digest.
    pub const fn new(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
    /// All-zero id, handy for tests.
    pub const fn zero() -> Self {
        Self([0u8; 32])
    }
    /// Borrow the raw digest bytes.
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }
    /// Render as a 64-character lowercase hex string.
    pub fn to_hex(&self) -> String {
        let mut out = String::with_capacity(64);
        for byte in self.0.iter() {
            out.push_str(&format!("{:02x}", byte));
        }
        out
    }
    /// Parse a 64-character hex string; returns `None` on any length or
    /// digit error.
    pub fn from_hex(s: &str) -> Option<Self> {
        if s.len() != 64 {
            return None;
        }
        let mut bytes = [0u8; 32];
        for (i, slot) in bytes.iter_mut().enumerate() {
            // `get` returns None for non-ASCII boundaries, matching the
            // failure mode of the utf8 round-trip it replaces.
            let pair = s.get(i * 2..i * 2 + 2)?;
            *slot = u8::from_str_radix(pair, 16).ok()?;
        }
        Some(Self(bytes))
    }
}
impl std::fmt::Display for ModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.to_hex())
    }
}
/// Fixed shape specification for transformer inference
///
/// All dimensions are compile-time or model-time constants.
/// This enables zero-allocation inference paths.
/// Invariant (checked by `validate`): d_model == heads * d_head.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct FixedShape {
    /// Maximum sequence length
    pub seq_len: u16,
    /// Model/hidden dimension
    pub d_model: u16,
    /// Number of attention heads
    pub heads: u8,
    /// Dimension per head
    pub d_head: u16,
    /// Vocabulary size
    pub vocab: u32,
}
impl FixedShape {
/// Create a new FixedShape
pub const fn new(seq_len: u16, d_model: u16, heads: u8, d_head: u16, vocab: u32) -> Self {
Self {
seq_len,
d_model,
heads,
d_head,
vocab,
}
}
/// Micro configuration for edge/WASM deployment
pub const fn micro() -> Self {
Self {
seq_len: 32,
d_model: 64,
heads: 4,
d_head: 16,
vocab: 4096,
}
}
/// Small configuration for embedded
pub const fn small() -> Self {
Self {
seq_len: 64,
d_model: 128,
heads: 4,
d_head: 32,
vocab: 8192,
}
}
/// Baseline configuration
pub const fn baseline() -> Self {
Self {
seq_len: 128,
d_model: 256,
heads: 8,
d_head: 32,
vocab: 32000,
}
}
/// Calculate total parameters for embedding layer
pub const fn embedding_params(&self) -> usize {
self.vocab as usize * self.d_model as usize
}
/// Calculate parameters per attention layer
pub const fn attention_params(&self) -> usize {
// Q, K, V projections + output projection
4 * (self.d_model as usize * self.d_model as usize)
}
/// Calculate parameters per FFN layer (assuming 4x expansion)
pub const fn ffn_params(&self) -> usize {
2 * (self.d_model as usize * 4 * self.d_model as usize)
}
/// Validate shape consistency
pub fn validate(&self) -> Result<(), String> {
if self.d_model as usize != self.heads as usize * self.d_head as usize {
return Err(format!(
"d_model ({}) must equal heads ({}) * d_head ({})",
self.d_model, self.heads, self.d_head
));
}
if self.seq_len == 0 {
return Err("seq_len must be > 0".into());
}
if self.vocab == 0 {
return Err("vocab must be > 0".into());
}
Ok(())
}
}
impl Default for FixedShape {
fn default() -> Self {
Self::baseline()
}
}
/// Quantization specification
///
/// Explicit quantization metadata ensuring reproducible inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantSpec {
    /// Weight bit width (1, 2, 4, 8)
    pub w_bits: u8,
    /// Activation bit width (4, 8, 16)
    pub a_bits: u8,
    /// Scale factor (Q16.16 fixed point; 1 << 16 == 1.0)
    pub scale_q: i32,
    /// Zero point (Q16.16 fixed point)
    pub zero_q: i32,
    /// Memory layout
    pub layout: Layout,
}
impl QuantSpec {
/// Create a new QuantSpec
pub const fn new(w_bits: u8, a_bits: u8, scale_q: i32, zero_q: i32, layout: Layout) -> Self {
Self {
w_bits,
a_bits,
scale_q,
zero_q,
layout,
}
}
/// INT4 weights, INT8 activations (common for edge)
pub const fn int4_int8() -> Self {
Self {
w_bits: 4,
a_bits: 8,
scale_q: 1 << 16, // 1.0 in Q16.16
zero_q: 0,
layout: Layout::Blocked { block: 32 },
}
}
/// INT8 weights and activations
pub const fn int8() -> Self {
Self {
w_bits: 8,
a_bits: 8,
scale_q: 1 << 16,
zero_q: 0,
layout: Layout::RowMajor,
}
}
/// Bytes per weight element
pub const fn bytes_per_weight(&self) -> usize {
match self.w_bits {
1 => 1, // Packed 8 per byte, but minimum 1 byte
2 => 1, // Packed 4 per byte
4 => 1, // Packed 2 per byte
8 => 1,
16 => 2,
_ => 4,
}
}
/// Weights packed per byte
pub const fn weights_per_byte(&self) -> usize {
match self.w_bits {
1 => 8,
2 => 4,
4 => 2,
_ => 1,
}
}
}
impl Default for QuantSpec {
fn default() -> Self {
Self::int8()
}
}
/// Memory layout for quantized tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Layout {
    /// Standard row-major layout
    RowMajor,
    /// Blocked layout for SIMD/hardware efficiency (`block` elements per tile)
    Blocked { block: u16 },
    /// Heads interleaved for attention computation
    InterleavedHeads,
}
impl Default for Layout {
    /// Row-major is the default layout.
    fn default() -> Self {
        Self::RowMajor
    }
}
/// Hint for gating decisions
///
/// Carried on each inference request to bound how much compute the backend
/// may spend (see [`ComputeClass`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct GateHint {
    /// Coherence score (Q8.8 fixed point, higher = more coherent)
    pub coherence_score_q: i16,
    /// Whether a boundary was crossed in the input
    pub boundary_crossed: bool,
    /// Maximum compute class allowed
    pub max_compute_class: ComputeClass,
}
impl GateHint {
    /// Create a new GateHint
    pub const fn new(
        coherence_score_q: i16,
        boundary_crossed: bool,
        max_compute_class: ComputeClass,
    ) -> Self {
        Self {
            coherence_score_q,
            boundary_crossed,
            max_compute_class,
        }
    }
    /// Default hint allowing full computation (max coherence, deliberative tier)
    pub const fn allow_all() -> Self {
        Self {
            coherence_score_q: i16::MAX,
            boundary_crossed: false,
            max_compute_class: ComputeClass::Deliberative,
        }
    }
    /// Reflex-only hint for the fast path (zero coherence, reflex tier)
    pub const fn reflex_only() -> Self {
        Self {
            coherence_score_q: 0,
            boundary_crossed: false,
            max_compute_class: ComputeClass::Reflex,
        }
    }
}
impl Default for GateHint {
    /// Defaults to `allow_all` (no gating restriction).
    fn default() -> Self {
        Self::allow_all()
    }
}
/// Compute class for tiered inference
///
/// Ordered from cheapest to most expensive; the `Ord` derive lets gating
/// code compare tiers directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum ComputeClass {
    /// Fastest path, minimal computation (1-2 layers)
    Reflex = 0,
    /// Medium path, associative memory (4-6 layers)
    Associative = 1,
    /// Full deliberative computation (all layers)
    Deliberative = 2,
}
impl ComputeClass {
    /// Convert from u8; returns `None` for values outside 0..=2.
    pub const fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(Self::Reflex),
            1 => Some(Self::Associative),
            2 => Some(Self::Deliberative),
            _ => None,
        }
    }
}
impl Default for ComputeClass {
    /// Full deliberative computation is the default.
    fn default() -> Self {
        Self::Deliberative
    }
}
/// Inference request
///
/// Borrows its token and mask buffers, so no allocation is needed per call.
#[derive(Debug, Clone)]
pub struct InferenceRequest<'a> {
    /// Model to use
    pub model: ModelId,
    /// Expected shape
    pub shape: FixedShape,
    /// Input token IDs (length = seq_len)
    pub tokens: &'a [u16],
    /// Attention mask (length = seq_len or seq_len^2; see `validate`)
    pub attn_mask: &'a [u8],
    /// Gating hint for coherence control
    pub gate_hint: GateHint,
}
impl<'a> InferenceRequest<'a> {
    /// Assemble a request from its parts.
    pub fn new(
        model: ModelId,
        shape: FixedShape,
        tokens: &'a [u16],
        attn_mask: &'a [u8],
        gate_hint: GateHint,
    ) -> Self {
        Self {
            model,
            shape,
            tokens,
            attn_mask,
            gate_hint,
        }
    }
    /// Validate input lengths against the declared shape: tokens must be
    /// exactly `seq_len`, and the attention mask either per-token
    /// (`seq_len`) or a full matrix (`seq_len`^2).
    pub fn validate(&self) -> crate::error::Result<()> {
        let seq = self.shape.seq_len as usize;
        if self.tokens.len() != seq {
            return Err(crate::error::Error::InputLengthMismatch {
                expected: seq,
                actual: self.tokens.len(),
            });
        }
        let mask_len = self.attn_mask.len();
        if mask_len != seq && mask_len != seq * seq {
            return Err(crate::error::Error::InputLengthMismatch {
                expected: seq,
                actual: mask_len,
            });
        }
        Ok(())
    }
}
/// Inference result
#[derive(Debug, Clone)]
pub struct InferenceResult {
    /// Full logits (quantized, length = vocab) or empty if topk_only
    pub logits_q: Vec<i16>,
    /// Top-K predictions (token_id, logit_q), best first
    pub topk: Option<Vec<(u16, i16)>>,
    /// Witness log for audit trail
    pub witness: WitnessLog,
}
impl InferenceResult {
/// Create a new InferenceResult
pub fn new(logits_q: Vec<i16>, topk: Option<Vec<(u16, i16)>>, witness: WitnessLog) -> Self {
Self {
logits_q,
topk,
witness,
}
}
/// Get the argmax token
pub fn argmax(&self) -> Option<u16> {
if let Some(ref topk) = self.topk {
topk.first().map(|(token, _)| *token)
} else if !self.logits_q.is_empty() {
self.logits_q
.iter()
.enumerate()
.max_by_key(|(_, &v)| v)
.map(|(i, _)| i as u16)
} else {
None
}
}
}
/// Witness log for audit trail and ReasoningBank integration
///
/// One record per inference: it identifies the exact model and
/// quantization used, where the work ran, how long it took, and what
/// the coherence gate decided.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WitnessLog {
    /// 32-byte hash of the model used
    pub model_hash: [u8; 32],
    /// 32-byte hash of quantization parameters used
    pub quant_hash: [u8; 32],
    /// Backend that executed the inference
    pub backend: BackendKind,
    /// Compute cycles used (FPGA) or 0 (sim)
    pub cycles: u32,
    /// Latency in nanoseconds
    pub latency_ns: u32,
    /// Gate decision made
    pub gate_decision: GateDecision,
}
impl WitnessLog {
    /// Assemble a witness record from all of its fields.
    pub fn new(
        model_hash: [u8; 32],
        quant_hash: [u8; 32],
        backend: BackendKind,
        cycles: u32,
        latency_ns: u32,
        gate_decision: GateDecision,
    ) -> Self {
        Self {
            gate_decision,
            latency_ns,
            cycles,
            backend,
            quant_hash,
            model_hash,
        }
    }

    /// All-zero witness on the native simulator with a full run;
    /// intended for testing.
    pub fn empty() -> Self {
        Self::new(
            [0u8; 32],
            [0u8; 32],
            BackendKind::NativeSim,
            0,
            0,
            GateDecision::RanFull,
        )
    }
}
/// Backend types
///
/// Declaration order fixes the `as u8` discriminants (0..=3) used by the
/// compact witness encoding — do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BackendKind {
    /// PCIe-connected FPGA
    FpgaPcie,
    /// FPGA via local daemon
    FpgaDaemon,
    /// WASM simulator
    WasmSim,
    /// Native Rust simulator
    NativeSim,
}
impl std::fmt::Display for BackendKind {
    /// Render the stable snake_case name of the backend.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::FpgaPcie => "fpga_pcie",
            Self::FpgaDaemon => "fpga_daemon",
            Self::WasmSim => "wasm_sim",
            Self::NativeSim => "native_sim",
        };
        f.write_str(name)
    }
}
/// Gate decision outcome
///
/// Records how the coherence gate resolved a request; stored in every
/// witness record.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GateDecision {
    /// Full inference completed
    RanFull,
    /// Early exit at specified layer
    EarlyExit { layer: u8 },
    /// Inference was skipped
    Skipped { reason: SkipReason },
}
impl GateDecision {
    /// True unless the inference was skipped entirely.
    pub const fn did_run(&self) -> bool {
        match self {
            Self::RanFull | Self::EarlyExit { .. } => true,
            Self::Skipped { .. } => false,
        }
    }

    /// Layer at which computation stopped: `max_layers` for a full run,
    /// the recorded layer for an early exit, and 0 when skipped.
    pub const fn exit_layer(&self, max_layers: u8) -> u8 {
        match self {
            Self::Skipped { .. } => 0,
            Self::EarlyExit { layer } => *layer,
            Self::RanFull => max_layers,
        }
    }
}
/// Reason for skipping inference
///
/// Declaration order fixes the `as u8` codes (0..=2) used by the compact
/// witness encoding — do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SkipReason {
    /// Coherence score too low
    LowCoherence,
    /// Policy denied the inference
    PolicyDenied,
    /// Compute budget exceeded
    BudgetExceeded,
}
impl std::fmt::Display for SkipReason {
    /// Render the stable snake_case name of the skip reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::LowCoherence => "low_coherence",
            Self::PolicyDenied => "policy_denied",
            Self::BudgetExceeded => "budget_exceeded",
        };
        f.write_str(name)
    }
}
/// Quantized tensor wrapper
///
/// Pairs packed quantized bytes with the spec needed to interpret them
/// and the logical tensor shape.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Raw quantized data (packed; sub-byte formats share bytes)
    pub data: Vec<u8>,
    /// Quantization specification
    pub spec: QuantSpec,
    /// Tensor shape (row-major)
    pub shape: Vec<usize>,
}
impl QuantizedTensor {
    /// Wrap raw packed bytes with their quantization spec and shape.
    pub fn new(data: Vec<u8>, spec: QuantSpec, shape: Vec<usize>) -> Self {
        Self { data, spec, shape }
    }

    /// Total number of logical elements (product of all dimensions).
    ///
    /// An empty shape yields 1 (scalar), the identity of `product`.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Expected packed size in bytes, rounding up when several sub-byte
    /// weights share one byte.
    pub fn expected_bytes(&self) -> usize {
        // div_ceil replaces the manual (n + d - 1) / d idiom; identical
        // result, clearer intent, no overflow on the `+ d - 1` step.
        self.numel().div_ceil(self.spec.weights_per_byte())
    }

    /// Validate that `data` holds exactly the packed byte count implied
    /// by `shape` and `spec`.
    ///
    /// # Errors
    /// Returns `QuantizationError` on a size mismatch.
    pub fn validate(&self) -> crate::error::Result<()> {
        let expected = self.expected_bytes();
        if self.data.len() != expected {
            return Err(crate::error::Error::QuantizationError(format!(
                "Data size mismatch: expected {} bytes, got {}",
                expected,
                self.data.len()
            )));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Hex round-trip of a model identifier must be lossless.
    #[test]
    fn test_model_id_hex_roundtrip() {
        let bytes = [0x12u8; 32];
        let id = ModelId::new(bytes);
        let hex = id.to_hex();
        let parsed = ModelId::from_hex(&hex).unwrap();
        assert_eq!(id, parsed);
    }

    // A shape validates only when heads * head_dim equals the model
    // width (8 * 32 == 256 in the valid case below).
    #[test]
    fn test_fixed_shape_validate() {
        let valid = FixedShape::new(64, 256, 8, 32, 32000);
        assert!(valid.validate().is_ok());
        let invalid = FixedShape::new(64, 256, 8, 16, 32000); // 8*16 != 256
        assert!(invalid.validate().is_err());
    }

    // INT8 packs one weight per byte; the INT4/INT8 scheme packs two.
    #[test]
    fn test_quant_spec_bytes() {
        assert_eq!(QuantSpec::int8().weights_per_byte(), 1);
        assert_eq!(QuantSpec::int4_int8().weights_per_byte(), 2);
    }

    // Only a Skipped decision counts as "did not run".
    #[test]
    fn test_gate_decision() {
        assert!(GateDecision::RanFull.did_run());
        assert!(GateDecision::EarlyExit { layer: 3 }.did_run());
        assert!(!GateDecision::Skipped {
            reason: SkipReason::LowCoherence
        }
        .did_run());
    }
}

View File

@@ -0,0 +1,215 @@
//! Witness hashing for integrity verification
use crate::types::WitnessLog;
use sha2::{Digest, Sha256};
/// Compute a hash of the witness log for integrity verification
pub fn compute_witness_hash(witness: &WitnessLog) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(&witness.model_hash);
hasher.update(&witness.quant_hash);
hasher.update(&[witness.backend as u8]);
hasher.update(&witness.cycles.to_le_bytes());
hasher.update(&witness.latency_ns.to_le_bytes());
// Hash gate decision
match witness.gate_decision {
crate::types::GateDecision::RanFull => {
hasher.update(&[0u8]);
}
crate::types::GateDecision::EarlyExit { layer } => {
hasher.update(&[1u8, layer]);
}
crate::types::GateDecision::Skipped { reason } => {
hasher.update(&[2u8, reason as u8]);
}
}
hasher.finalize().into()
}
/// Check a witness against an expected digest.
///
/// Comparison is not constant-time; this guards integrity, not secrets.
pub fn verify_witness_hash(witness: &WitnessLog, expected: &[u8; 32]) -> bool {
    compute_witness_hash(witness) == *expected
}
/// Hash an ordered sequence of witnesses into a single digest.
///
/// Each witness is hashed individually and the per-witness digests are
/// folded into one SHA-256 state, so both content and order matter.
/// Useful for verifying an entire inference chain.
pub fn compute_chain_hash(witnesses: &[WitnessLog]) -> [u8; 32] {
    witnesses
        .iter()
        .fold(Sha256::new(), |acc, w| {
            acc.chain_update(compute_witness_hash(w))
        })
        .finalize()
        .into()
}
/// Witness proof for verification
///
/// Binds a witness digest to a creation timestamp, optionally with an
/// Ed25519 signature over (hash || timestamp).
#[derive(Debug, Clone)]
pub struct WitnessProof {
    /// Hash of the witness (see `compute_witness_hash`)
    pub hash: [u8; 32],
    /// Timestamp when proof was created (ns since Unix epoch; 0 if the
    /// system clock predates the epoch)
    pub timestamp_ns: u64,
    /// Optional Ed25519 signature over hash || timestamp
    pub signature: Option<[u8; 64]>,
}
impl WitnessProof {
    /// Nanoseconds since the Unix epoch, or 0 when the system clock is
    /// set before the epoch.
    fn now_ns() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    }

    /// Message covered by the signature: witness hash || LE timestamp.
    /// Shared by `signed` and `verify_signature` so the two can never
    /// drift apart.
    #[cfg(feature = "sign")]
    fn signing_message(hash: &[u8; 32], timestamp_ns: u64) -> [u8; 40] {
        let mut message = [0u8; 40];
        message[..32].copy_from_slice(hash);
        message[32..40].copy_from_slice(&timestamp_ns.to_le_bytes());
        message
    }

    /// Create an unsigned proof from a witness.
    pub fn new(witness: &WitnessLog) -> Self {
        Self {
            hash: compute_witness_hash(witness),
            timestamp_ns: Self::now_ns(),
            signature: None,
        }
    }

    /// Create a proof signed with an Ed25519 secret key.
    ///
    /// NOTE(review): `feature = "sign"` is not declared in Cargo.toml's
    /// [features] table, so this item is never compiled with the current
    /// manifest — confirm whether the feature should be added or the
    /// gate renamed (e.g. to `witness`).
    #[cfg(feature = "sign")]
    pub fn signed(witness: &WitnessLog, secret_key: &[u8; 32]) -> Self {
        use ed25519_dalek::{Signer, SigningKey};
        let hash = compute_witness_hash(witness);
        let timestamp_ns = Self::now_ns();
        let message = Self::signing_message(&hash, timestamp_ns);
        let signing_key = SigningKey::from_bytes(secret_key);
        let signature = signing_key.sign(&message);
        Self {
            hash,
            timestamp_ns,
            signature: Some(signature.to_bytes()),
        }
    }

    /// Check that this proof's hash matches the given witness.
    pub fn verify(&self, witness: &WitnessLog) -> bool {
        verify_witness_hash(witness, &self.hash)
    }

    /// Verify the Ed25519 signature against a public key.
    ///
    /// Returns false when the proof is unsigned, the key is malformed,
    /// or the signature does not match hash || timestamp.
    #[cfg(feature = "sign")]
    pub fn verify_signature(&self, pubkey: &[u8; 32]) -> bool {
        use ed25519_dalek::{Signature, Verifier, VerifyingKey};
        let Some(sig_bytes) = self.signature else {
            return false;
        };
        let Ok(verifying_key) = VerifyingKey::from_bytes(pubkey) else {
            return false;
        };
        let signature = Signature::from_bytes(&sig_bytes);
        let message = Self::signing_message(&self.hash, self.timestamp_ns);
        verifying_key.verify(&message, &signature).is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BackendKind, GateDecision};

    // Same witness hashed twice must yield the same digest.
    #[test]
    fn test_witness_hash_deterministic() {
        let witness = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1000,
            50000,
            GateDecision::RanFull,
        );
        let hash1 = compute_witness_hash(&witness);
        let hash2 = compute_witness_hash(&witness);
        assert_eq!(hash1, hash2);
    }

    // Changing any field (here: cycles) must change the digest.
    #[test]
    fn test_witness_hash_changes() {
        let witness1 = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1000,
            50000,
            GateDecision::RanFull,
        );
        let witness2 = WitnessLog::new(
            [1u8; 32],
            [2u8; 32],
            BackendKind::NativeSim,
            1001, // Different cycles
            50000,
            GateDecision::RanFull,
        );
        let hash1 = compute_witness_hash(&witness1);
        let hash2 = compute_witness_hash(&witness2);
        assert_ne!(hash1, hash2);
    }

    // verify accepts the matching digest and rejects a wrong one.
    #[test]
    fn test_verify_witness_hash() {
        let witness = WitnessLog::empty();
        let hash = compute_witness_hash(&witness);
        assert!(verify_witness_hash(&witness, &hash));
        assert!(!verify_witness_hash(&witness, &[0u8; 32]));
    }

    // Chain hashing over the same sequence is deterministic.
    #[test]
    fn test_chain_hash() {
        let witnesses: Vec<WitnessLog> = (0..5)
            .map(|i| {
                WitnessLog::new(
                    [i as u8; 32],
                    [0u8; 32],
                    BackendKind::NativeSim,
                    i * 100,
                    i * 1000,
                    GateDecision::RanFull,
                )
            })
            .collect();
        let chain_hash1 = compute_chain_hash(&witnesses);
        let chain_hash2 = compute_chain_hash(&witnesses);
        assert_eq!(chain_hash1, chain_hash2);
    }

    // A freshly created proof verifies its own witness and carries a
    // non-zero creation timestamp.
    #[test]
    fn test_witness_proof() {
        let witness = WitnessLog::empty();
        let proof = WitnessProof::new(&witness);
        assert!(proof.verify(&witness));
        assert!(proof.timestamp_ns > 0);
    }
}

View File

@@ -0,0 +1,311 @@
//! Witness log builder and utilities
use crate::types::{BackendKind, GateDecision, WitnessLog};
use std::time::Instant;
/// Builder for creating witness logs
///
/// Captures `Instant::now()` at construction so `build` can record the
/// elapsed latency automatically.
pub struct WitnessBuilder {
    // Hash of the model used; zeroed until set via `model_hash`.
    model_hash: [u8; 32],
    // Hash of quantization params; zeroed until set via `quant_hash`.
    quant_hash: [u8; 32],
    // Backend that will be recorded in the witness.
    backend: BackendKind,
    // Start of the measured interval (taken in `new`).
    start_time: Instant,
    // Compute cycles consumed; 0 until set via `cycles`.
    cycles: u32,
    // Gate outcome; defaults to `RanFull`.
    gate_decision: GateDecision,
}
impl WitnessBuilder {
    /// Start building a new witness; the latency clock starts here.
    pub fn new(backend: BackendKind) -> Self {
        Self {
            model_hash: [0u8; 32],
            quant_hash: [0u8; 32],
            backend,
            start_time: Instant::now(),
            cycles: 0,
            gate_decision: GateDecision::RanFull,
        }
    }

    /// Set model hash.
    pub fn model_hash(mut self, hash: [u8; 32]) -> Self {
        self.model_hash = hash;
        self
    }

    /// Set quantization hash.
    pub fn quant_hash(mut self, hash: [u8; 32]) -> Self {
        self.quant_hash = hash;
        self
    }

    /// Set compute cycles.
    pub fn cycles(mut self, cycles: u32) -> Self {
        self.cycles = cycles;
        self
    }

    /// Set gate decision (defaults to `RanFull`).
    pub fn gate_decision(mut self, decision: GateDecision) -> Self {
        self.gate_decision = decision;
        self
    }

    /// Build the witness log, recording latency since `new` was called.
    pub fn build(self) -> WitnessLog {
        // Saturate at u32::MAX (~4.29 s) instead of the previous `as u32`
        // cast, which silently wrapped and produced a bogus small latency
        // for long-running inferences.
        let latency_ns =
            u32::try_from(self.start_time.elapsed().as_nanos()).unwrap_or(u32::MAX);
        WitnessLog::new(
            self.model_hash,
            self.quant_hash,
            self.backend,
            self.cycles,
            latency_ns,
            self.gate_decision,
        )
    }
}
impl WitnessLog {
    /// Compact encoding length: 32 + 32 + 1 + 4 + 4 + 2 bytes.
    const ENCODED_LEN: usize = 75;

    /// Serialize to the compact fixed-width layout: model hash, quant
    /// hash, backend tag, cycles (LE), latency (LE), then a two-byte
    /// gate-decision code.
    ///
    /// The tag bytes rely on `BackendKind`/`SkipReason` declaration
    /// order via `as u8` — keep `from_bytes` in sync.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::ENCODED_LEN);
        bytes.extend_from_slice(&self.model_hash);
        bytes.extend_from_slice(&self.quant_hash);
        bytes.push(self.backend as u8);
        bytes.extend_from_slice(&self.cycles.to_le_bytes());
        bytes.extend_from_slice(&self.latency_ns.to_le_bytes());
        // Gate decision: tag byte then payload byte (0 when unused).
        match self.gate_decision {
            GateDecision::RanFull => {
                bytes.push(0);
                bytes.push(0);
            }
            GateDecision::EarlyExit { layer } => {
                bytes.push(1);
                bytes.push(layer);
            }
            GateDecision::Skipped { reason } => {
                bytes.push(2);
                bytes.push(reason as u8);
            }
        }
        bytes
    }

    /// Parse from the compact layout produced by `to_bytes`.
    ///
    /// Returns `None` when the buffer is too short or any tag byte is
    /// invalid. (Previously, unknown backend/gate/reason tags were
    /// silently coerced to default variants, letting corrupt bytes
    /// decode as seemingly valid audit records.)
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < Self::ENCODED_LEN {
            return None;
        }
        let model_hash: [u8; 32] = bytes[0..32].try_into().ok()?;
        let quant_hash: [u8; 32] = bytes[32..64].try_into().ok()?;
        let backend = match bytes[64] {
            0 => BackendKind::FpgaPcie,
            1 => BackendKind::FpgaDaemon,
            2 => BackendKind::WasmSim,
            3 => BackendKind::NativeSim,
            _ => return None,
        };
        let cycles = u32::from_le_bytes(bytes[65..69].try_into().ok()?);
        let latency_ns = u32::from_le_bytes(bytes[69..73].try_into().ok()?);
        let gate_decision = match bytes[73] {
            0 => GateDecision::RanFull,
            1 => GateDecision::EarlyExit { layer: bytes[74] },
            2 => GateDecision::Skipped {
                reason: match bytes[74] {
                    0 => crate::types::SkipReason::LowCoherence,
                    1 => crate::types::SkipReason::PolicyDenied,
                    2 => crate::types::SkipReason::BudgetExceeded,
                    _ => return None,
                },
            },
            _ => return None,
        };
        Some(Self {
            model_hash,
            quant_hash,
            backend,
            cycles,
            latency_ns,
            gate_decision,
        })
    }

    /// Get latency in microseconds.
    pub fn latency_us(&self) -> f64 {
        self.latency_ns as f64 / 1000.0
    }

    /// Get latency in milliseconds.
    pub fn latency_ms(&self) -> f64 {
        self.latency_ns as f64 / 1_000_000.0
    }

    /// True when the full model ran to completion.
    pub fn is_full_inference(&self) -> bool {
        matches!(self.gate_decision, GateDecision::RanFull)
    }

    /// True when the gate stopped computation at an intermediate layer.
    pub fn is_early_exit(&self) -> bool {
        matches!(self.gate_decision, GateDecision::EarlyExit { .. })
    }

    /// True when the inference was skipped entirely.
    pub fn is_skipped(&self) -> bool {
        matches!(self.gate_decision, GateDecision::Skipped { .. })
    }
}
/// Witness log aggregator for collecting statistics
///
/// Accumulates running totals so means, rates, and the latency standard
/// deviation can be computed without retaining individual witnesses.
#[derive(Debug, Default)]
pub struct WitnessAggregator {
    /// Total inferences
    pub count: u64,
    /// Total cycles
    pub total_cycles: u64,
    /// Total latency (ns)
    pub total_latency_ns: u64,
    /// Full inference count
    pub full_count: u64,
    /// Early exit count
    pub early_exit_count: u64,
    /// Skipped count
    pub skipped_count: u64,
    // Sum of squared latencies (ns^2) for the variance identity
    // E[x^2] - mean^2; u128 to avoid overflow of squared u32 sums.
    latency_sq_sum: u128,
}
impl WitnessAggregator {
    /// Create an empty aggregator.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fold one witness into the running totals.
    pub fn add(&mut self, witness: &WitnessLog) {
        self.count += 1;
        self.total_cycles += witness.cycles as u64;
        self.total_latency_ns += witness.latency_ns as u64;
        self.latency_sq_sum += (witness.latency_ns as u128).pow(2);
        match witness.gate_decision {
            GateDecision::RanFull => self.full_count += 1,
            GateDecision::EarlyExit { .. } => self.early_exit_count += 1,
            GateDecision::Skipped { .. } => self.skipped_count += 1,
        }
    }

    /// Mean latency in nanoseconds (0.0 when empty).
    pub fn avg_latency_ns(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_latency_ns as f64 / self.count as f64
        }
    }

    /// Mean cycles per inference (0.0 when empty).
    pub fn avg_cycles(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_cycles as f64 / self.count as f64
        }
    }

    /// Population standard deviation of latency in nanoseconds.
    ///
    /// Uses the E[x^2] - mean^2 identity. The difference is clamped at
    /// zero before `sqrt`: floating-point cancellation can make it
    /// slightly negative (e.g. when all latencies are equal), which
    /// previously produced NaN.
    pub fn latency_std_ns(&self) -> f64 {
        if self.count <= 1 {
            return 0.0;
        }
        let mean = self.avg_latency_ns();
        let variance =
            ((self.latency_sq_sum as f64 / self.count as f64) - mean * mean).max(0.0);
        variance.sqrt()
    }

    /// Fraction of inferences that exited early (0.0 when empty).
    pub fn early_exit_rate(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.early_exit_count as f64 / self.count as f64
        }
    }

    /// Fraction of inferences that were skipped (0.0 when empty).
    pub fn skip_rate(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.skipped_count as f64 / self.count as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters land in the final witness; defaults fill the rest.
    #[test]
    fn test_witness_builder() {
        let witness = WitnessBuilder::new(BackendKind::NativeSim)
            .model_hash([1u8; 32])
            .quant_hash([2u8; 32])
            .cycles(1000)
            .gate_decision(GateDecision::RanFull)
            .build();
        assert_eq!(witness.model_hash, [1u8; 32]);
        assert_eq!(witness.backend, BackendKind::NativeSim);
        assert_eq!(witness.cycles, 1000);
    }

    // Compact encoding round-trips every field checked here.
    // NOTE(review): gate_decision equality is not asserted — consider
    // adding it to cover the two-byte decision encoding.
    #[test]
    fn test_witness_bytes_roundtrip() {
        let witness = WitnessLog::new(
            [0x42u8; 32],
            [0x24u8; 32],
            BackendKind::FpgaDaemon,
            5000,
            100_000,
            GateDecision::EarlyExit { layer: 4 },
        );
        let bytes = witness.to_bytes();
        let parsed = WitnessLog::from_bytes(&bytes).unwrap();
        assert_eq!(witness.model_hash, parsed.model_hash);
        assert_eq!(witness.quant_hash, parsed.quant_hash);
        assert_eq!(witness.backend, parsed.backend);
        assert_eq!(witness.cycles, parsed.cycles);
        assert_eq!(witness.latency_ns, parsed.latency_ns);
    }

    // Counts and the early-exit rate reflect the mix of added witnesses
    // (3 early exits out of 10).
    #[test]
    fn test_witness_aggregator() {
        let mut agg = WitnessAggregator::new();
        for i in 0..10 {
            let mut witness = WitnessLog::empty();
            witness.latency_ns = 1000 * (i + 1);
            witness.cycles = 100 * (i + 1);
            if i < 3 {
                witness.gate_decision = GateDecision::EarlyExit { layer: 2 };
            }
            agg.add(&witness);
        }
        assert_eq!(agg.count, 10);
        assert_eq!(agg.early_exit_count, 3);
        assert!((agg.early_exit_rate() - 0.3).abs() < 0.01);
    }
}

View File

@@ -0,0 +1,12 @@
//! Witness logging for audit trails and ReasoningBank integration
//!
//! Every inference produces a small witness bundle that records
//! what happened and enables verification and replay.

/// Hashing and proof primitives for witness records.
pub mod hash;
/// Builder, compact serialization, and aggregation utilities.
pub mod log;

// Re-export WitnessLog from types as the canonical location
pub use crate::types::WitnessLog;
pub use hash::{compute_witness_hash, verify_witness_hash};
pub use log::{WitnessAggregator, WitnessBuilder};