Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
999
vendor/ruvector/docs/architecture/bitnet-quantizer-module-design.md
vendored
Normal file
999
vendor/ruvector/docs/architecture/bitnet-quantizer-module-design.md
vendored
Normal file
@@ -0,0 +1,999 @@
|
||||
# PT-BitNet Quantizer Module Architecture Design
|
||||
|
||||
**Version:** 1.0
|
||||
**Date:** 2026-02-03
|
||||
**Status:** Design Specification
|
||||
**Relates to:** ADR-017 (AD-1, AD-5, AD-18, AD-19), DDD Section 3.4/4.2/4.3
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document specifies the architecture for the **PT-BitNet post-training quantizer** module that converts FP16/BF16 GLM-4.7-Flash weights to BitNet b1.58 ternary {-1, 0, +1} format via absmean quantization. This is a **design-only specification** — implementation follows in Phase 0.
|
||||
|
||||
**Design Scope:**
|
||||
- Module layout and file organization
|
||||
- Complete struct definitions with field types
|
||||
- Full function signatures (no implementations)
|
||||
- GGUF integration points and format extensions
|
||||
- Error handling strategy
|
||||
- Testing approach
|
||||
|
||||
**Out of Scope:**
|
||||
- Actual implementation code
|
||||
- Performance benchmarks
|
||||
- Calibration dataset selection
|
||||
|
||||
---
|
||||
|
||||
## A. Module Layout
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
crates/ruvllm/src/
|
||||
├── bitnet/ # NEW module
|
||||
│ ├── mod.rs # Module exports and public API
|
||||
│ ├── quantizer.rs # PtBitnetQuantizer + absmean algorithm
|
||||
│ ├── ternary_tensor.rs # TernaryTensor value object
|
||||
│ ├── dequantize.rs # BITNET_T158 dequantization kernel
|
||||
│ └── config.rs # PtBitnetConfig configuration
|
||||
│
|
||||
├── gguf/
|
||||
│ ├── mod.rs # Add pub mod bitnet export
|
||||
│ ├── quantization.rs # MODIFIED: Add BITNET_T158 enum variant
|
||||
│ ├── parser.rs # Unchanged (reused as-is)
|
||||
│ └── ...
|
||||
│
|
||||
└── kernels/
|
||||
└── matmul.rs # Reference for dispatch patterns
|
||||
```
|
||||
|
||||
### Modified Files
|
||||
|
||||
#### `src/gguf/quantization.rs`
|
||||
|
||||
**Changes:**
|
||||
1. Add `BITNET_T158 = 30` variant to `GgufQuantType` enum (after `Bf16 = 29`)
|
||||
2. Update `try_from()` impl to handle type 30
|
||||
3. Update `block_size()` to return 256 for `BITNET_T158`
|
||||
4. Update `type_size()` to return 66 for `BITNET_T158` (64 bytes packed + 2 bytes FP16 scale)
|
||||
5. Update `is_quantized()` to include `BITNET_T158`
|
||||
6. Update `bits_per_weight()` to return 2.06 for `BITNET_T158`
|
||||
7. Add new match arm in `dequantize_tensor()` → `BITNET_T158 => dequantize_bitnet_t158(data, output)`
|
||||
|
||||
**Exact enum addition:**
|
||||
```rust
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(u32)]
|
||||
pub enum GgufQuantType {
|
||||
// ... existing variants 0-29 ...
|
||||
/// BitNet b1.58 ternary quantization (2-bit packed + FP16 scale per 256-element block)
|
||||
BITNET_T158 = 30,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## B. Struct Definitions
|
||||
|
||||
### 1. `PtBitnetConfig` (in `bitnet/config.rs`)
|
||||
|
||||
**Purpose:** Configuration for PT-BitNet quantization process
|
||||
|
||||
```rust
|
||||
/// Configuration for PT-BitNet post-training quantization
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PtBitnetConfig {
|
||||
/// Block size for absmean scale computation (default: 256)
|
||||
pub block_size: usize,
|
||||
|
||||
/// Epsilon for numerical stability in scale computation (default: 1e-8)
|
||||
pub epsilon: f32,
|
||||
|
||||
/// Whether to run calibration pass to optimize scale factors
|
||||
pub use_calibration: bool,
|
||||
|
||||
/// Number of calibration samples (if use_calibration = true)
|
||||
pub calibration_samples: usize,
|
||||
|
||||
/// Maximum sequence length for calibration (default: 2048)
|
||||
pub calibration_max_seq_len: usize,
|
||||
|
||||
/// Device for calibration pass ("cpu", "metal", "cuda:0")
|
||||
pub calibration_device: String,
|
||||
|
||||
/// Clipping threshold for normalized weights before rounding
|
||||
/// (default: 1.0, range typically 0.95-1.05)
|
||||
pub clip_threshold: f32,
|
||||
|
||||
/// Sparsity target: if > 0.0, bias rounding toward zero to achieve target sparsity
|
||||
pub target_sparsity: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for PtBitnetConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_size: 256,
|
||||
epsilon: 1e-8,
|
||||
use_calibration: false,
|
||||
calibration_samples: 1000,
|
||||
calibration_max_seq_len: 2048,
|
||||
calibration_device: "metal".to_string(),
|
||||
clip_threshold: 1.0,
|
||||
target_sparsity: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. `TernaryTensor` (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
**Purpose:** Immutable value object for packed ternary weights
|
||||
|
||||
```rust
|
||||
/// Packed ternary tensor with per-block FP16 scales
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TernaryTensor {
|
||||
/// Packed 2-bit ternary values (4 weights per byte)
|
||||
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved
|
||||
pub packed_data: Vec<u8>,
|
||||
|
||||
/// Per-block FP16 scale factors (absmean values)
|
||||
pub scales: Vec<f16>,
|
||||
|
||||
/// Tensor shape [out_features, in_features] or [rows, cols]
|
||||
pub shape: [usize; 2],
|
||||
|
||||
/// Block size (always 256 for BitNet b1.58)
|
||||
pub block_size: usize,
|
||||
|
||||
/// Total number of weights
|
||||
pub num_elements: usize,
|
||||
|
||||
/// Number of blocks
|
||||
pub num_blocks: usize,
|
||||
|
||||
/// Measured sparsity (fraction of zero weights)
|
||||
pub sparsity: f32,
|
||||
}
|
||||
|
||||
impl TernaryTensor {
|
||||
/// Calculate total storage size in bytes
|
||||
pub fn storage_size(&self) -> usize;
|
||||
|
||||
/// Get expected packed_data size for validation
|
||||
pub fn expected_packed_size(&self) -> usize;
|
||||
|
||||
/// Validate internal consistency
|
||||
pub fn validate(&self) -> Result<()>;
|
||||
}
|
||||
```
|
||||
|
||||
### 3. `TernaryBlock` (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
**Purpose:** Single block of 256 ternary weights with scale
|
||||
|
||||
```rust
|
||||
/// A single 256-element block with ternary weights and FP16 scale
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TernaryBlock {
|
||||
/// 64 bytes of packed 2-bit values (256 weights × 2 bits ÷ 8 bits/byte)
|
||||
pub packed: [u8; 64],
|
||||
|
||||
/// FP16 absmean scale factor
|
||||
pub scale: f16,
|
||||
}
|
||||
|
||||
impl TernaryBlock {
|
||||
/// Size in bytes when stored in GGUF (64 + 2 = 66)
|
||||
pub const STORAGE_SIZE: usize = 66;
|
||||
|
||||
/// Number of elements in a block
|
||||
pub const BLOCK_SIZE: usize = 256;
|
||||
}
|
||||
```
|
||||
|
||||
### 4. `AbsmeanResult` (in `bitnet/quantizer.rs`)
|
||||
|
||||
**Purpose:** Result of absmean quantization on a single block
|
||||
|
||||
```rust
|
||||
/// Result of absmean ternary quantization on a block
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AbsmeanResult {
|
||||
/// Ternary values {-1, 0, +1} for each weight in the block
|
||||
pub ternary_weights: Vec<i8>,
|
||||
|
||||
/// Computed absmean scale factor (gamma = mean(|W|))
|
||||
pub scale: f32,
|
||||
|
||||
/// Measured sparsity (fraction of zeros)
|
||||
pub sparsity: f32,
|
||||
|
||||
/// Mean squared error vs original FP16 values (for calibration)
|
||||
pub mse: f32,
|
||||
}
|
||||
```
|
||||
|
||||
### 5. `QuantizationStats` (in `bitnet/quantizer.rs`)
|
||||
|
||||
**Purpose:** Statistics collected during quantization
|
||||
|
||||
```rust
|
||||
/// Statistics from quantizing a single tensor
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizationStats {
|
||||
/// Tensor name
|
||||
pub name: String,
|
||||
|
||||
/// Mean of all block scales
|
||||
pub mean_scale: f32,
|
||||
|
||||
/// Std dev of block scales
|
||||
pub std_scale: f32,
|
||||
|
||||
/// Overall sparsity across all blocks
|
||||
pub sparsity: f32,
|
||||
|
||||
/// Mean MSE across all blocks
|
||||
pub mean_mse: f32,
|
||||
|
||||
/// Number of blocks
|
||||
pub num_blocks: usize,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## C. Function Signatures
|
||||
|
||||
### Core Quantization Functions (in `bitnet/quantizer.rs`)
|
||||
|
||||
#### 1. Primary Quantization Entry Point
|
||||
|
||||
```rust
|
||||
/// Quantize an FP16/F32 tensor to ternary format using absmean quantization
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensor` - Input FP16 or F32 tensor data (flat vector)
|
||||
/// * `shape` - Tensor shape [out_features, in_features]
|
||||
/// * `config` - Quantization configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `TernaryTensor` - Packed ternary representation
|
||||
/// * `QuantizationStats` - Statistics about the quantization process
|
||||
///
|
||||
/// # Errors
|
||||
/// * `RuvLLMError::Quantization` if tensor size is not divisible by block_size
|
||||
/// * `RuvLLMError::Quantization` if shape product doesn't match tensor length
|
||||
pub fn quantize_tensor(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
config: &PtBitnetConfig,
|
||||
) -> Result<(TernaryTensor, QuantizationStats)>;
|
||||
```
|
||||
|
||||
#### 2. Per-Block Quantization
|
||||
|
||||
```rust
|
||||
/// Apply absmean quantization to a single block of weights
|
||||
///
|
||||
/// Algorithm:
|
||||
/// 1. gamma = mean(|block|) + epsilon
|
||||
/// 2. normalized = block / gamma
|
||||
/// 3. ternary = round(clamp(normalized, -clip_threshold, +clip_threshold))
|
||||
/// 4. Map to {-1, 0, +1}
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `block` - Block of FP16/F32 values (length = config.block_size)
|
||||
/// * `config` - Configuration with epsilon and clip_threshold
|
||||
///
|
||||
/// # Returns
|
||||
/// * `AbsmeanResult` with ternary values, scale, sparsity, MSE
|
||||
///
|
||||
/// # Panics
|
||||
/// * If block.len() != config.block_size
|
||||
pub fn absmean_ternary(
|
||||
block: &[f32],
|
||||
config: &PtBitnetConfig,
|
||||
) -> AbsmeanResult;
|
||||
```
|
||||
|
||||
#### 3. Packing Functions
|
||||
|
||||
```rust
|
||||
/// Pack ternary {-1, 0, +1} values into 2-bit representation
|
||||
///
|
||||
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved (unused)
|
||||
/// 4 values packed per byte: [v3 v2 v1 v0] → byte
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `values` - Ternary values (must be {-1, 0, +1} only)
|
||||
///
|
||||
/// # Returns
|
||||
/// * Packed bytes (length = ceil(values.len() / 4))
|
||||
///
|
||||
/// # Errors
|
||||
/// * If any value is not in {-1, 0, +1}
|
||||
pub fn pack_ternary(values: &[i8]) -> Result<Vec<u8>>;
|
||||
|
||||
/// Unpack 2-bit representation to ternary {-1, 0, +1} values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `packed` - Packed 2-bit data
|
||||
/// * `n` - Number of values to extract
|
||||
///
|
||||
/// # Returns
|
||||
/// * Vector of ternary values (length = n)
|
||||
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8>;
|
||||
```
|
||||
|
||||
#### 4. Calibration (Optional)
|
||||
|
||||
```rust
|
||||
/// Run calibration pass to optimize scale factors
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensor` - Input FP16 tensor
|
||||
/// * `shape` - Tensor shape
|
||||
/// * `config` - Config with calibration settings
|
||||
/// * `calibration_data` - Sample activations for this layer
|
||||
///
|
||||
/// # Returns
|
||||
/// * Optimized `TernaryTensor` with calibrated scales
|
||||
///
|
||||
/// # Note
|
||||
/// This is optional - if not used, falls back to plain absmean
|
||||
pub fn quantize_with_calibration(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
config: &PtBitnetConfig,
|
||||
calibration_data: &[Vec<f32>],
|
||||
) -> Result<(TernaryTensor, QuantizationStats)>;
|
||||
```
|
||||
|
||||
### Dequantization Functions (in `bitnet/dequantize.rs`)
|
||||
|
||||
```rust
|
||||
/// Dequantize BITNET_T158 tensor to FP32
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Raw GGUF tensor bytes (packed ternary + scales)
|
||||
/// * `scales` - Per-block FP16 scales (extracted from data)
|
||||
/// * `n` - Total number of elements to dequantize
|
||||
///
|
||||
/// # Returns
|
||||
/// * Vec<f32> of dequantized values
|
||||
///
|
||||
/// # Format
|
||||
/// Each block: [64 bytes packed ternary][2 bytes FP16 scale]
|
||||
pub fn dequantize_bitnet_t158(
|
||||
data: &[u8],
|
||||
scales: &[f16],
|
||||
n: usize,
|
||||
) -> Vec<f32>;
|
||||
|
||||
/// Dequantize a single BITNET_T158 block
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `block_data` - 64 bytes of packed ternary data
|
||||
/// * `scale` - FP16 scale factor
|
||||
/// * `output` - Output buffer (must have capacity for 256 elements)
|
||||
pub fn dequantize_bitnet_t158_block(
|
||||
block_data: &[u8; 64],
|
||||
scale: f16,
|
||||
output: &mut [f32],
|
||||
);
|
||||
```
|
||||
|
||||
### Tensor Conversion (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
```rust
|
||||
impl TernaryTensor {
|
||||
/// Convert from packed storage to FP32 (for validation/testing)
|
||||
pub fn to_fp32(&self) -> Vec<f32>;
|
||||
|
||||
/// Create from existing GGUF tensor data
|
||||
pub fn from_gguf_data(
|
||||
data: &[u8],
|
||||
shape: [usize; 2],
|
||||
block_size: usize,
|
||||
) -> Result<Self>;
|
||||
|
||||
/// Serialize to GGUF tensor bytes
|
||||
pub fn to_gguf_data(&self) -> Vec<u8>;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## D. GGUF Integration Points
|
||||
|
||||
### 1. New Quantization Type Variant
|
||||
|
||||
**File:** `crates/ruvllm/src/gguf/quantization.rs`
|
||||
|
||||
**Changes to `GgufQuantType` enum:**
|
||||
|
||||
```rust
|
||||
#[repr(u32)]
|
||||
pub enum GgufQuantType {
|
||||
// ... existing 0-29 ...
|
||||
|
||||
/// BitNet b1.58 ternary quantization
|
||||
/// Block size: 256 elements
|
||||
/// Storage: 64 bytes packed (2-bit) + 2 bytes FP16 scale = 66 bytes/block
|
||||
/// Bits per weight: 2.06 bpw
|
||||
BITNET_T158 = 30,
|
||||
}
|
||||
|
||||
impl GgufQuantType {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => 256,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn type_size(&self) -> usize {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => 66, // 64 + 2
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => "BITNET_T158",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for GgufQuantType {
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
// ... existing 0-29 ...
|
||||
30 => Ok(Self::BITNET_T158),
|
||||
_ => Err(/* ... */),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Dequantization Dispatch
|
||||
|
||||
**File:** `crates/ruvllm/src/gguf/quantization.rs`
|
||||
|
||||
**Modification to `dequantize_tensor()` function:**
|
||||
|
||||
```rust
|
||||
pub fn dequantize_tensor(
|
||||
data: &[u8],
|
||||
dtype: GgufQuantType,
|
||||
num_elements: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let mut output = vec![0.0f32; num_elements];
|
||||
|
||||
match dtype {
|
||||
// ... existing cases ...
|
||||
GgufQuantType::BITNET_T158 => {
|
||||
// Extract scales and packed data
|
||||
let num_blocks = (num_elements + 255) / 256;
|
||||
let mut scales = Vec::with_capacity(num_blocks);
|
||||
|
||||
for i in 0..num_blocks {
|
||||
let block_offset = i * 66;
|
||||
let scale_offset = block_offset + 64;
|
||||
let scale_bytes = [data[scale_offset], data[scale_offset + 1]];
|
||||
scales.push(f16::from_le_bytes(scale_bytes));
|
||||
}
|
||||
|
||||
crate::bitnet::dequantize::dequantize_bitnet_t158(
|
||||
data,
|
||||
&scales,
|
||||
num_elements,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
return Err(RuvLLMError::Model(format!(
|
||||
"Dequantization not implemented for {:?}",
|
||||
dtype
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
```
|
||||
|
||||
### 3. GGUF Metadata Keys
|
||||
|
||||
**New metadata keys for BitNet models** (written during quantization, read during load):
|
||||
|
||||
```rust
|
||||
// In quantizer when exporting GGUF
|
||||
pub const BITNET_METADATA_KEYS: &[(&str, &str)] = &[
|
||||
("craftsman.bitnet.version", "1"),
|
||||
("craftsman.bitnet.weight_encoding", "absmean_ternary"),
|
||||
("craftsman.bitnet.activation_bits", "8"),
|
||||
("craftsman.bitnet.block_size", "256"),
|
||||
("craftsman.bitnet.kernel_hint", "tl1"), // or "tl2", "i2s"
|
||||
];
|
||||
```
|
||||
|
||||
**Metadata reading in model loader:**
|
||||
|
||||
```rust
|
||||
// In backend when loading model
|
||||
fn detect_bitnet_model(metadata: &HashMap<String, GgufValue>) -> bool {
|
||||
metadata.get("craftsman.bitnet.version")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|v| v == "1")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tensor Info Extension
|
||||
|
||||
**No changes needed** - existing `TensorInfo` struct in `parser.rs` already supports:
|
||||
- `name: String`
|
||||
- `shape: Vec<usize>`
|
||||
- `dtype: GgufQuantType` ← Will now include `BITNET_T158`
|
||||
- `offset: u64`
|
||||
|
||||
---
|
||||
|
||||
## E. Error Handling Strategy
|
||||
|
||||
### Error Types
|
||||
|
||||
All errors use existing `RuvLLMError` enum from `crates/ruvllm/src/error.rs`:
|
||||
|
||||
```rust
|
||||
pub enum RuvLLMError {
|
||||
// Existing variants...
|
||||
|
||||
// Quantization-specific errors
|
||||
Quantization(String), // Use this variant for all quantization errors
|
||||
Model(String), // For GGUF format issues
|
||||
Config(String), // For invalid configuration
|
||||
}
|
||||
```
|
||||
|
||||
### Error Scenarios and Handling
|
||||
|
||||
| Scenario | Error Type | Recovery Strategy |
|
||||
|----------|-----------|-------------------|
|
||||
| Tensor size not divisible by block_size | `Quantization` | Pad last block with zeros |
|
||||
| Invalid ternary value during packing | `Quantization` | Fail-fast - indicates bug |
|
||||
| GGUF file has wrong BITNET_T158 block size | `Model` | Fail-fast - corrupted file |
|
||||
| Calibration device unavailable | `Config` | Fall back to non-calibrated quantization |
|
||||
| Out of memory during quantization | System panic | Let Rust OOM handler catch |
|
||||
| Shape mismatch in tensor | `Quantization` | Fail-fast - validate before processing |
|
||||
| FP16 scale is NaN/Inf | `Quantization` | Clamp to epsilon value |
|
||||
| Empty tensor / zero elements | `Quantization` | Skip with warning |
|
||||
|
||||
### Validation Functions
|
||||
|
||||
```rust
|
||||
/// Validate quantization config
|
||||
pub fn validate_config(config: &PtBitnetConfig) -> Result<()> {
|
||||
if config.block_size == 0 || config.block_size % 4 != 0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"block_size must be non-zero and divisible by 4".into()
|
||||
));
|
||||
}
|
||||
|
||||
if config.epsilon <= 0.0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"epsilon must be positive".into()
|
||||
));
|
||||
}
|
||||
|
||||
if config.clip_threshold <= 0.0 || config.clip_threshold > 2.0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"clip_threshold must be in range (0.0, 2.0]".into()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate tensor shape and size
|
||||
pub fn validate_tensor(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
block_size: usize,
|
||||
) -> Result<()> {
|
||||
let expected_size = shape[0] * shape[1];
|
||||
|
||||
if tensor.len() != expected_size {
|
||||
return Err(RuvLLMError::Quantization(format!(
|
||||
"Tensor length {} doesn't match shape {:?} (expected {})",
|
||||
tensor.len(), shape, expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
if expected_size % block_size != 0 {
|
||||
// Could pad, but for simplicity require exact multiple
|
||||
return Err(RuvLLMError::Quantization(format!(
|
||||
"Tensor size {} is not divisible by block_size {}",
|
||||
expected_size, block_size
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## F. Testing Strategy
|
||||
|
||||
### Unit Tests
|
||||
|
||||
#### 1. Absmean Quantization Correctness
|
||||
|
||||
**File:** `crates/ruvllm/src/bitnet/tests/quantizer_tests.rs`
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_absmean_ternary_basic() {
|
||||
// Test that absmean correctly quantizes known values
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
// Block with known mean(|x|) = 1.0
|
||||
let block: Vec<f32> = vec![
|
||||
2.0, -2.0, 1.0, -1.0, // gamma = mean(2,2,1,1,...) ≈ 1.0
|
||||
0.5, -0.5, 0.0, 0.0,
|
||||
// ... (pad to 256 elements)
|
||||
];
|
||||
|
||||
let result = absmean_ternary(&block, &config);
|
||||
|
||||
// After normalization: 2.0/1.0 = 2.0 → clamp to 1.0 → round to +1
|
||||
assert_eq!(result.ternary_weights[0], 1); // 2.0 → +1
|
||||
assert_eq!(result.ternary_weights[1], -1); // -2.0 → -1
|
||||
assert_eq!(result.ternary_weights[2], 1); // 1.0 → +1
|
||||
assert_eq!(result.ternary_weights[6], 0); // 0.0 → 0
|
||||
|
||||
assert!(result.scale > 0.9 && result.scale < 1.1); // gamma ≈ 1.0
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absmean_all_zeros() {
|
||||
let config = PtBitnetConfig::default();
|
||||
let block = vec![0.0; 256];
|
||||
|
||||
let result = absmean_ternary(&block, &config);
|
||||
|
||||
// All zeros → scale = epsilon, all ternary = 0
|
||||
assert_eq!(result.scale, config.epsilon);
|
||||
assert!(result.ternary_weights.iter().all(|&x| x == 0));
|
||||
assert_eq!(result.sparsity, 1.0);
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Pack/Unpack Round-Trip
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_pack_unpack_roundtrip() {
|
||||
let original = vec![1i8, -1, 0, 1, 0, -1, 1, 0];
|
||||
|
||||
let packed = pack_ternary(&original).unwrap();
|
||||
assert_eq!(packed.len(), 2); // 8 values → 2 bytes
|
||||
|
||||
let unpacked = unpack_ternary(&packed, 8);
|
||||
assert_eq!(unpacked, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_invalid_value() {
|
||||
let invalid = vec![1i8, 2, 0]; // 2 is not ternary
|
||||
|
||||
let result = pack_ternary(&invalid);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. Tensor Validation
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_validate_tensor_shape_mismatch() {
|
||||
let tensor = vec![1.0; 100];
|
||||
let shape = [10, 11]; // 10*11 = 110 ≠ 100
|
||||
|
||||
let result = validate_tensor(&tensor, shape, 256);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tensor_block_alignment() {
|
||||
let tensor = vec![1.0; 257]; // Not divisible by 256
|
||||
let shape = [1, 257];
|
||||
|
||||
let result = validate_tensor(&tensor, shape, 256);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
|
||||
#### 4. Full Quantization Pipeline
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_quantize_tensor_full_pipeline() {
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
// Create a 512-element tensor (2 blocks)
|
||||
let tensor: Vec<f32> = (0..512).map(|i| (i as f32) / 512.0).collect();
|
||||
let shape = [2, 256];
|
||||
|
||||
let (ternary, stats) = quantize_tensor(&tensor, shape, &config).unwrap();
|
||||
|
||||
assert_eq!(ternary.num_blocks, 2);
|
||||
assert_eq!(ternary.packed_data.len(), 2 * 64); // 2 blocks × 64 bytes
|
||||
assert_eq!(ternary.scales.len(), 2);
|
||||
assert_eq!(stats.num_blocks, 2);
|
||||
|
||||
// Verify reconstruction quality
|
||||
let reconstructed = ternary.to_fp32();
|
||||
assert_eq!(reconstructed.len(), 512);
|
||||
}
|
||||
```
|
||||
|
||||
#### 5. GGUF Round-Trip
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_gguf_serialization_roundtrip() {
|
||||
let config = PtBitnetConfig::default();
|
||||
let tensor = vec![1.0; 256];
|
||||
let shape = [1, 256];
|
||||
|
||||
let (ternary, _) = quantize_tensor(&tensor, shape, &config).unwrap();
|
||||
|
||||
// Serialize to GGUF format
|
||||
let gguf_data = ternary.to_gguf_data();
|
||||
assert_eq!(gguf_data.len(), 66); // 1 block = 66 bytes
|
||||
|
||||
// Deserialize
|
||||
let recovered = TernaryTensor::from_gguf_data(&gguf_data, shape, 256).unwrap();
|
||||
|
||||
assert_eq!(recovered.packed_data, ternary.packed_data);
|
||||
assert_eq!(recovered.scales, ternary.scales);
|
||||
}
|
||||
```
|
||||
|
||||
### Benchmark Tests
|
||||
|
||||
#### 6. Performance Regression
|
||||
|
||||
```rust
|
||||
#[bench]
|
||||
fn bench_absmean_ternary_256(b: &mut Bencher) {
|
||||
let config = PtBitnetConfig::default();
|
||||
let block: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect();
|
||||
|
||||
b.iter(|| {
|
||||
let _ = absmean_ternary(&block, &config);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_pack_ternary_1024(b: &mut Bencher) {
|
||||
let values = vec![1i8; 1024];
|
||||
|
||||
b.iter(|| {
|
||||
let _ = pack_ternary(&values);
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
### Correctness Validation Tests
|
||||
|
||||
#### 7. Bit-Exact Validation Against Reference
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_dequantize_matches_reference() {
|
||||
// Reference implementation (naive)
|
||||
fn reference_dequant(ternary: &[i8], scale: f32) -> Vec<f32> {
|
||||
ternary.iter().map(|&t| (t as f32) * scale).collect()
|
||||
}
|
||||
|
||||
let config = PtBitnetConfig::default();
|
||||
let tensor = vec![1.5, -2.3, 0.1, -0.4]; // Extend to 256
|
||||
let tensor_256 = /* pad to 256 */;
|
||||
let shape = [1, 256];
|
||||
|
||||
let (ternary, _) = quantize_tensor(&tensor_256, shape, &config).unwrap();
|
||||
|
||||
// Unpack and dequantize
|
||||
let unpacked = unpack_ternary(&ternary.packed_data, 256);
|
||||
let reference = reference_dequant(&unpacked, ternary.scales[0].to_f32());
|
||||
let optimized = ternary.to_fp32();
|
||||
|
||||
// Allow small floating-point error
|
||||
for (r, o) in reference.iter().zip(optimized.iter()) {
|
||||
assert!((r - o).abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Test Organization
|
||||
|
||||
```
|
||||
crates/ruvllm/src/bitnet/tests/
|
||||
├── quantizer_tests.rs # absmean, pack/unpack
|
||||
├── tensor_tests.rs # TernaryTensor validation
|
||||
├── dequantize_tests.rs # BITNET_T158 dequant
|
||||
├── integration_tests.rs # Full pipeline, GGUF round-trip
|
||||
└── benches.rs # Performance benchmarks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## G. Implementation Phases
|
||||
|
||||
### Phase 0.1: Core Data Structures (~2-3 days)
|
||||
1. `bitnet/mod.rs` - module structure
|
||||
2. `bitnet/config.rs` - `PtBitnetConfig`
|
||||
3. `bitnet/ternary_tensor.rs` - `TernaryTensor`, `TernaryBlock`
|
||||
4. Unit tests for validation
|
||||
|
||||
### Phase 0.2: Quantization Algorithm (~3-4 days)
|
||||
1. `bitnet/quantizer.rs` - `absmean_ternary()`
|
||||
2. Pack/unpack functions
|
||||
3. `quantize_tensor()` main entry point
|
||||
4. Unit tests for correctness
|
||||
|
||||
### Phase 0.3: Dequantization (~2 days)
|
||||
1. `bitnet/dequantize.rs` - block and tensor dequant
|
||||
2. Integration with existing `quantization.rs`
|
||||
3. Round-trip tests
|
||||
|
||||
### Phase 0.4: GGUF Integration (~2-3 days)
|
||||
1. Modify `gguf/quantization.rs` - add `BITNET_T158` enum variant
|
||||
2. Add metadata keys
|
||||
3. GGUF serialization/deserialization
|
||||
4. Integration tests
|
||||
|
||||
### Phase 0.5: Validation & Benchmarks (~2 days)
|
||||
1. Full pipeline integration tests
|
||||
2. Performance benchmarks
|
||||
3. Bit-exact validation
|
||||
4. Documentation
|
||||
|
||||
**Total Estimated Effort:** ~13-16 days for clean, well-tested implementation
|
||||
|
||||
---
|
||||
|
||||
## H. Open Design Questions
|
||||
|
||||
| # | Question | Impact | Recommendation |
|
||||
|---|----------|--------|----------------|
|
||||
| 1 | Use `IQ1_S` (type 19) or new `BITNET_T158` (type 30)? | Compatibility | **New type 30** - cleaner separation, avoids confusion with IQ1_S's codebook format |
|
||||
| 2 | Padding strategy for last block if not aligned? | Correctness | **Zero-pad** - simplest, matches BitNet spec |
|
||||
| 3 | Should calibration be mandatory or optional? | Quality vs Speed | **Optional** - Phase 0 can work without it, add later if needed |
|
||||
| 4 | F16 or F32 for internal scale computation? | Precision | **F32 internally, store as F16** - extra precision during compute |
|
||||
| 5 | Handle NaN/Inf in input tensors? | Robustness | **Fail-fast** - corrupted weights should not be silently ignored |
|
||||
| 6 | Support block sizes other than 256? | Flexibility | **No** - BitNet spec is 256, simplifies code |
|
||||
| 7 | Multi-threading for per-block quantization? | Performance | **Not in Phase 0** - can add via rayon later |
|
||||
| 8 | Store sparsity per-block in GGUF? | Kernel optimization | **No** - compute on-the-fly during dequant, saves space |
|
||||
|
||||
---
|
||||
|
||||
## I. Dependencies and Prerequisites
|
||||
|
||||
### Existing RuvLLM Components (Reused)
|
||||
- `crates/ruvllm/src/error.rs` - `RuvLLMError` enum
|
||||
- `crates/ruvllm/src/gguf/parser.rs` - GGUF parsing (unchanged)
|
||||
- `crates/ruvllm/src/gguf/quantization.rs` - Enum + dispatch (modified)
|
||||
- `half` crate - FP16 support (already in Cargo.toml)
|
||||
|
||||
### New External Dependencies
|
||||
None - uses only existing dependencies
|
||||
|
||||
### Minimum Rust Version
|
||||
Same as RuvLLM (likely 1.70+)
|
||||
|
||||
---
|
||||
|
||||
## J. Non-Goals (Out of Scope)
|
||||
|
||||
1. **Calibration implementation** - Deferred to future phase
|
||||
2. **TL1/TL2 kernel implementation** - Separate ADR/DDD
|
||||
3. **Model loader integration** - Separate backend implementation
|
||||
4. **Performance optimization** - Phase 0 is correctness-first
|
||||
5. **WASM support** - Desktop/server only for Phase 0
|
||||
6. **Dynamic quantization** - Only post-training static
|
||||
7. **Mixed-precision strategies** - All-or-nothing ternary for Phase 0
|
||||
|
||||
---
|
||||
|
||||
## K. Success Criteria
|
||||
|
||||
**This design is complete when:**
|
||||
|
||||
1. All struct definitions have complete field specifications
|
||||
2. All function signatures are documented with arguments, returns, errors
|
||||
3. Module organization is clear and follows Rust conventions
|
||||
4. GGUF integration points are precisely specified
|
||||
5. Error handling covers all failure modes
|
||||
6. Test plan covers correctness, integration, and performance
|
||||
7. Implementation phases are realistic and sequenced
|
||||
8. Open questions are documented with recommendations
|
||||
|
||||
**Implementation is successful when:**
|
||||
|
||||
1. All unit tests pass
|
||||
2. Round-trip GGUF serialization is bit-exact
|
||||
3. Dequantization produces correct FP32 output
|
||||
4. Integration with existing GGUF pipeline works
|
||||
5. Quantization of GLM-4.7-Flash completes without errors
|
||||
6. Exported GGUF file is loadable by model loader
|
||||
|
||||
---
|
||||
|
||||
## Appendix A: Code Size Estimates
|
||||
|
||||
| File | Estimated Lines | Complexity |
|
||||
|------|----------------|------------|
|
||||
| `bitnet/mod.rs` | ~50 | Low |
|
||||
| `bitnet/config.rs` | ~80 | Low |
|
||||
| `bitnet/ternary_tensor.rs` | ~200 | Medium |
|
||||
| `bitnet/quantizer.rs` | ~350 | High |
|
||||
| `bitnet/dequantize.rs` | ~150 | Medium |
|
||||
| `gguf/quantization.rs` (changes) | ~100 | Low |
|
||||
| Tests | ~800 | Medium |
|
||||
| **Total** | **~1,730 lines** | |
|
||||
|
||||
**Comparison to ADR-018 estimate:** ~200-300 lines core quantizer → Actual ~350 lines (reasonable given struct overhead)
|
||||
|
||||
---
|
||||
|
||||
## Appendix B: Memory Layout Examples
|
||||
|
||||
### TernaryBlock Storage (66 bytes)
|
||||
|
||||
```
|
||||
Byte Offset | Content
|
||||
------------|--------
|
||||
0-63 | Packed 2-bit ternary (256 values)
|
||||
64-65 | FP16 scale (little-endian)
|
||||
```
|
||||
|
||||
### 2-Bit Packing Example
|
||||
|
||||
```
|
||||
Values: [+1, -1, 0, +1]
|
||||
Encoding: [10, 00, 01, 10]
|
||||
Packed byte: 10_00_01_10 = 0x86
|
||||
```
|
||||
|
||||
### GGUF Tensor Data Layout
|
||||
|
||||
```
|
||||
[TensorInfo] (in header)
|
||||
name: "model.layers.0.mlp.gate_proj.weight"
|
||||
shape: [4096, 11008]
|
||||
dtype: BITNET_T158 (30)
|
||||
offset: 0x1000
|
||||
|
||||
[Tensor Data] (at offset 0x1000)
|
||||
Block 0: [64 bytes packed][2 bytes scale]
|
||||
Block 1: [64 bytes packed][2 bytes scale]
|
||||
...
|
||||
Block N: [64 bytes packed][2 bytes scale]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**End of Design Document**
|
||||
|
||||
Reference in New Issue
Block a user