git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
4844 lines
169 KiB
Rust
4844 lines
169 KiB
Rust
//! BitNet b1.58 Inference Backend
//!
//! This module implements the `BitNetBackend` inference pipeline for BitNet b1.58
//! MoE models (e.g., GLM-4.7-Flash). It wires together the quantizer, TL1 kernel,
//! and MoE routing into a working inference pipeline.
//!
//! ## Phase 0 Scope
//!
//! - Attention is a placeholder (pass-through) for smoke testing
//! - MoE routing is fully functional (FP16 gate + softmax + top-K)
//! - Expert FFN uses real TL1 GEMV on ternary weights
//! - Embedding lookup and LM head are FP16 matmul
//!
//! ## Architecture
//!
//! ```text
//! Embedding (FP16) -> [Transformer Layers] -> RMSNorm -> LM Head (FP16) -> Logits
//!
//! Each Transformer Layer:
//!   RMSNorm -> Attention (placeholder) -> Residual
//!   -> RMSNorm -> MoE Gate (FP16) -> Top-K Expert Selection
//!   -> Expert FFN (TL1 GEMV on ternary) -> Weighted Sum -> Residual
//! ```
|
|
|
|
use std::path::Path;
|
|
use std::sync::Mutex;
|
|
|
|
use crate::backends::{
|
|
GenerateParams, GeneratedToken, LlmBackend, ModelArchitecture, ModelConfig, ModelInfo,
|
|
Quantization, SpecialTokens as BackendSpecialTokens, StreamEvent, TokenStream,
|
|
Tokenizer as BackendTokenizer,
|
|
};
|
|
use crate::error::{Result, RuvLLMError};
|
|
use crate::gguf::{GgufFile, GgufQuantType};
|
|
|
|
use super::ternary_tensor::TernaryTensor;
|
|
use super::tokenizer::{BpeTokenizer, SpecialTokens as BitNetSpecialTokens};
|
|
|
|
// ============================================================================
|
|
// Configuration
|
|
// ============================================================================
|
|
|
|
/// Model configuration for BitNet MoE inference.
///
/// Describes the architecture dimensions extracted from GGUF metadata
/// or supplied manually for testing. Supports both standard GQA attention
/// and MLA (Multi-Head Latent Attention) as used by GLM-4.7-Flash.
///
/// All dimension fields are element counts, not bytes. `Default` provides
/// GLM-4.7-Flash values so a metadata-sparse GGUF can still be loaded.
#[derive(Debug, Clone)]
pub struct BitNetModelConfig {
    /// Number of transformer layers
    pub num_layers: usize,
    /// Hidden state dimension
    pub hidden_size: usize,
    /// Number of MoE routed experts per layer
    pub num_experts: usize,
    /// Number of active experts per token (top-K)
    pub active_experts: usize,
    /// Dense FFN intermediate dimension (for dense layers)
    pub intermediate_size: usize,
    /// MoE expert FFN intermediate dimension (may differ from dense)
    pub moe_intermediate_size: usize,
    /// Number of attention query heads
    pub num_attention_heads: usize,
    /// Number of attention key-value heads (GQA; equals num_attention_heads in MLA)
    pub num_kv_heads: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Maximum context length, in tokens
    pub max_context: usize,
    /// RoPE frequency base (theta)
    pub rope_theta: f32,

    // --- MLA (Multi-Head Latent Attention) parameters ---
    // Ignored when `use_mla` is false.
    /// Whether attention uses MLA (true) or standard GQA (false)
    pub use_mla: bool,
    /// Q low-rank compression dimension (MLA)
    pub q_lora_rank: usize,
    /// KV low-rank compression dimension (MLA)
    pub kv_lora_rank: usize,
    /// Non-RoPE portion of Q/K head dimension (MLA)
    pub qk_nope_head_dim: usize,
    /// RoPE portion of Q/K head dimension (MLA)
    pub qk_rope_head_dim: usize,
    /// Value head dimension (MLA)
    pub v_head_dim: usize,

    // --- MoE structure ---
    /// Number of shared experts (always-active, non-routed)
    pub n_shared_experts: usize,
    /// First N layers use dense FFN instead of MoE (e.g., 1 means layer 0 is dense)
    pub first_k_dense_replace: usize,
    /// Scaling factor for routed expert weights
    pub routed_scaling_factor: f32,
}
|
|
|
|
impl Default for BitNetModelConfig {
|
|
fn default() -> Self {
|
|
// Default values matching GLM-4.7-Flash architecture
|
|
Self {
|
|
num_layers: 47,
|
|
hidden_size: 2048,
|
|
num_experts: 64,
|
|
active_experts: 4,
|
|
intermediate_size: 10240,
|
|
moe_intermediate_size: 1536,
|
|
num_attention_heads: 20,
|
|
num_kv_heads: 20,
|
|
vocab_size: 154880,
|
|
max_context: 8192,
|
|
rope_theta: 1_000_000.0,
|
|
// MLA parameters from GLM-4.7-Flash config.json
|
|
use_mla: true,
|
|
q_lora_rank: 768,
|
|
kv_lora_rank: 512,
|
|
qk_nope_head_dim: 192,
|
|
qk_rope_head_dim: 64,
|
|
v_head_dim: 256,
|
|
// MoE structure
|
|
n_shared_experts: 1,
|
|
first_k_dense_replace: 1,
|
|
routed_scaling_factor: 1.8,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// TL1 Lookup Table
|
|
// ============================================================================
|
|
|
|
/// Pre-computed lookup table for packed 2-bit ternary bytes.
///
/// For each of the 256 possible byte values, stores the four decoded
/// ternary values {-1, 0, +1}. This avoids per-element bit manipulation
/// during the hot GEMV inner loop.
type Tl1Lut = [[i8; 4]; 256];

/// Build the TL1 lookup table at load time.
///
/// Each byte packs four ternary values, two bits each, LSB-first
/// (position `p` occupies bits `2p..2p+2`). Encoding per the
/// ternary_tensor module:
/// - 00 = -1, 01 = 0, 10 = +1, 11 = 0 (reserved, decodes to zero)
fn build_tl1_lut() -> Tl1Lut {
    let mut lut = [[0i8; 4]; 256];
    // Iterate directly over table entries — no u16 widening or casts needed.
    for (byte_val, entry) in lut.iter_mut().enumerate() {
        for (pos, slot) in entry.iter_mut().enumerate() {
            *slot = match (byte_val >> (pos * 2)) & 0b11 {
                0b00 => -1,
                0b10 => 1,
                // 0b01 is zero; 0b11 is reserved and also decodes to zero.
                _ => 0,
            };
        }
    }
    lut
}
|
|
|
|
// ============================================================================
|
|
// Tensor Name Mapper
|
|
// ============================================================================
|
|
|
|
/// Resolves logical tensor names to actual GGUF tensor names.
|
|
///
|
|
/// GLM-4.7-Flash GGUF files use llama.cpp conventions (`blk.0.attn_q_a.weight`),
|
|
/// while some models use HuggingFace conventions (`model.layers.0.self_attn.q_proj.weight`).
|
|
/// The mapper tries GGUF names first, then HuggingFace names as fallback.
|
|
struct TensorNameMapper;
|
|
|
|
impl TensorNameMapper {
|
|
/// Find the first tensor name that exists in the GGUF file.
|
|
fn resolve(gguf: &GgufFile, candidates: &[String]) -> Option<String> {
|
|
for name in candidates {
|
|
if gguf.get_tensor(name).is_some() {
|
|
return Some(name.clone());
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
// -- Global tensors --
|
|
|
|
fn embedding() -> Vec<String> {
|
|
vec![
|
|
"token_embd.weight".into(),
|
|
"model.embed_tokens.weight".into(),
|
|
]
|
|
}
|
|
|
|
fn output() -> Vec<String> {
|
|
vec!["output.weight".into(), "lm_head.weight".into()]
|
|
}
|
|
|
|
fn final_norm() -> Vec<String> {
|
|
vec!["output_norm.weight".into(), "model.norm.weight".into()]
|
|
}
|
|
|
|
// -- Per-layer norms --
|
|
|
|
fn input_norm(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.attn_norm.weight", idx),
|
|
format!("model.layers.{}.input_layernorm.weight", idx),
|
|
]
|
|
}
|
|
|
|
fn post_attn_norm(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.ffn_norm.weight", idx),
|
|
format!("model.layers.{}.post_attention_layernorm.weight", idx),
|
|
]
|
|
}
|
|
|
|
// -- MLA attention tensors --
|
|
|
|
fn attn_q_a(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_q_a.weight", idx)]
|
|
}
|
|
|
|
fn attn_q_b(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_q_b.weight", idx)]
|
|
}
|
|
|
|
fn attn_q_a_norm(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_q_a_norm.weight", idx)]
|
|
}
|
|
|
|
fn attn_kv_a_mqa(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_kv_a_mqa.weight", idx)]
|
|
}
|
|
|
|
fn attn_kv_a_norm(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_kv_a_norm.weight", idx)]
|
|
}
|
|
|
|
fn attn_k_b(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_k_b.weight", idx)]
|
|
}
|
|
|
|
fn attn_v_b(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.attn_v_b.weight", idx)]
|
|
}
|
|
|
|
fn attn_output(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.attn_output.weight", idx),
|
|
format!("model.layers.{}.self_attn.o_proj.weight", idx),
|
|
]
|
|
}
|
|
|
|
// -- Standard GQA attention tensors --
|
|
|
|
fn attn_q_proj(idx: usize) -> Vec<String> {
|
|
vec![format!("model.layers.{}.self_attn.q_proj.weight", idx)]
|
|
}
|
|
|
|
fn attn_k_proj(idx: usize) -> Vec<String> {
|
|
vec![format!("model.layers.{}.self_attn.k_proj.weight", idx)]
|
|
}
|
|
|
|
fn attn_v_proj(idx: usize) -> Vec<String> {
|
|
vec![format!("model.layers.{}.self_attn.v_proj.weight", idx)]
|
|
}
|
|
|
|
// -- MoE router gate --
|
|
|
|
fn moe_gate(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.ffn_gate_inp.weight", idx),
|
|
format!("model.layers.{}.mlp.gate.weight", idx),
|
|
]
|
|
}
|
|
|
|
// -- Dense FFN tensors --
|
|
|
|
fn ffn_gate(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.ffn_gate.weight", idx),
|
|
format!("model.layers.{}.mlp.gate_proj.weight", idx),
|
|
]
|
|
}
|
|
|
|
fn ffn_up(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.ffn_up.weight", idx),
|
|
format!("model.layers.{}.mlp.up_proj.weight", idx),
|
|
]
|
|
}
|
|
|
|
fn ffn_down(idx: usize) -> Vec<String> {
|
|
vec![
|
|
format!("blk.{}.ffn_down.weight", idx),
|
|
format!("model.layers.{}.mlp.down_proj.weight", idx),
|
|
]
|
|
}
|
|
|
|
// -- Shared expert tensors --
|
|
|
|
fn ffn_gate_shexp(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_gate_shexp.weight", idx)]
|
|
}
|
|
|
|
fn ffn_up_shexp(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_up_shexp.weight", idx)]
|
|
}
|
|
|
|
fn ffn_down_shexp(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_down_shexp.weight", idx)]
|
|
}
|
|
|
|
// -- Stacked expert tensors (3D, all experts in one tensor) --
|
|
|
|
fn ffn_gate_exps(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_gate_exps.weight", idx)]
|
|
}
|
|
|
|
fn ffn_up_exps(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_up_exps.weight", idx)]
|
|
}
|
|
|
|
fn ffn_down_exps(idx: usize) -> Vec<String> {
|
|
vec![format!("blk.{}.ffn_down_exps.weight", idx)]
|
|
}
|
|
|
|
// -- Per-expert tensors (HuggingFace individual naming) --
|
|
|
|
fn expert_gate(idx: usize, expert_idx: usize) -> Vec<String> {
|
|
vec![format!(
|
|
"model.layers.{}.mlp.experts.{}.gate_proj.weight",
|
|
idx, expert_idx
|
|
)]
|
|
}
|
|
|
|
fn expert_up(idx: usize, expert_idx: usize) -> Vec<String> {
|
|
vec![format!(
|
|
"model.layers.{}.mlp.experts.{}.up_proj.weight",
|
|
idx, expert_idx
|
|
)]
|
|
}
|
|
|
|
fn expert_down(idx: usize, expert_idx: usize) -> Vec<String> {
|
|
vec![format!(
|
|
"model.layers.{}.mlp.experts.{}.down_proj.weight",
|
|
idx, expert_idx
|
|
)]
|
|
}
|
|
|
|
/// Check if a layer has MLA attention tensors.
|
|
fn has_mla(gguf: &GgufFile, idx: usize) -> bool {
|
|
Self::resolve(gguf, &Self::attn_q_a(idx)).is_some()
|
|
}
|
|
|
|
/// Check if a layer has stacked expert tensors.
|
|
fn has_stacked_experts(gguf: &GgufFile, idx: usize) -> bool {
|
|
Self::resolve(gguf, &Self::ffn_gate_exps(idx)).is_some()
|
|
}
|
|
|
|
/// Check if a layer has dense FFN (not MoE).
|
|
fn has_dense_ffn(gguf: &GgufFile, idx: usize) -> bool {
|
|
Self::resolve(gguf, &Self::ffn_gate(idx)).is_some()
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Per-Layer and Per-Expert Weight Storage
|
|
// ============================================================================
|
|
|
|
/// Ternary weights for a single MoE expert (gate, up, down projections).
///
/// Shapes follow the gated-FFN layout: gate/up map hidden -> intermediate,
/// down maps intermediate -> hidden. NOTE(review): the activation applied
/// between gate and down is not visible here — see the FFN forward pass.
#[derive(Debug, Clone)]
struct ExpertWeights {
    /// gate_proj: [intermediate_size, hidden_size]
    gate_proj: TernaryTensor,
    /// up_proj: [intermediate_size, hidden_size]
    up_proj: TernaryTensor,
    /// down_proj: [hidden_size, intermediate_size]
    down_proj: TernaryTensor,
}
|
|
|
|
/// Attention projection weights.
///
/// Supports two variants:
/// - **Standard GQA**: Direct Q/K/V/O projections
/// - **MLA (Multi-Head Latent Attention)**: Low-rank compressed Q/KV projections
///   as used by GLM-4.7-Flash / DeepSeek-V2
///
/// NOTE(review): the GQA fields are not `Option`s, so when `is_mla` is true
/// they presumably hold empty/placeholder tensors — confirm against the
/// layer loader.
#[derive(Debug, Clone)]
struct AttentionWeights {
    /// Whether this layer uses MLA or standard GQA
    is_mla: bool,

    // --- Standard GQA fields ---
    /// Q projection: [num_heads * head_dim, hidden_size]
    q_proj: TernaryTensor,
    /// K projection: [num_kv_heads * head_dim, hidden_size]
    k_proj: TernaryTensor,
    /// V projection: [num_kv_heads * head_dim, hidden_size]
    v_proj: TernaryTensor,
    /// Output projection: [hidden_size, num_heads * head_dim]
    o_proj: TernaryTensor,

    // --- MLA fields (populated when is_mla = true) ---
    /// Q down-projection: [hidden_size → q_lora_rank]
    q_a: Option<TernaryTensor>,
    /// Q up-projection: [q_lora_rank → num_heads * (qk_nope_head_dim + qk_rope_head_dim)]
    q_b: Option<TernaryTensor>,
    /// Q compression norm weights: [q_lora_rank]
    q_a_norm: Option<Vec<f32>>,
    /// KV joint down-projection: [hidden_size → kv_lora_rank + qk_rope_head_dim]
    kv_a_mqa: Option<TernaryTensor>,
    /// KV compression norm weights: [kv_lora_rank]
    kv_a_norm: Option<Vec<f32>>,
    /// K up-projection: [kv_lora_rank → num_heads * qk_nope_head_dim]
    k_b: Option<TernaryTensor>,
    /// V up-projection: [kv_lora_rank → num_heads * v_head_dim]
    v_b: Option<TernaryTensor>,
}
|
|
|
|
/// Type of FFN in a transformer layer.
///
/// Determined at load time from which tensors exist in the GGUF file
/// (dense gate vs. routed/shared expert tensors).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayerType {
    /// Dense FFN (single gate/up/down, no MoE routing)
    Dense,
    /// MoE with routed experts only
    Moe,
    /// MoE with routed experts + shared expert(s)
    MoeWithShared,
}
|
|
|
|
/// Weights for a single transformer layer.
///
/// Exactly one of the FFN variants is active, selected by `layer_type`:
/// `dense_ffn` for Dense layers, `experts` (plus optional `shared_expert`)
/// for MoE layers.
#[derive(Debug, Clone)]
struct TransformerLayer {
    /// Input RMSNorm weight [hidden_size]
    input_norm_weight: Vec<f32>,
    /// Post-attention RMSNorm weight [hidden_size]
    post_attn_norm_weight: Vec<f32>,
    /// Attention projection weights (ternary, supports MLA or GQA)
    attention: AttentionWeights,
    /// Type of FFN in this layer
    layer_type: LayerType,
    /// MoE router gate weight [num_experts, hidden_size] (FP32, empty for dense layers)
    gate_weight: Vec<f32>,
    /// Per-expert FFN weights (routed experts, ternary)
    experts: Vec<ExpertWeights>,
    /// Shared expert FFN weights (always-active, non-routed; None for dense layers)
    shared_expert: Option<ExpertWeights>,
    /// Dense FFN weights (for dense-only layers; uses gate/up/down from ExpertWeights)
    dense_ffn: Option<ExpertWeights>,
}
|
|
|
|
// ============================================================================
|
|
// KV Cache
|
|
// ============================================================================
|
|
|
|
/// Per-layer KV cache for autoregressive generation.
///
/// `keys[i]` / `values[i]` hold the K/V vectors for sequence position `i`;
/// the two lists are always kept in lockstep.
#[derive(Debug, Clone)]
struct LayerKvCache {
    /// Cached key vectors: one [num_kv_heads * head_dim] per position
    keys: Vec<Vec<f32>>,
    /// Cached value vectors: one [num_kv_heads * head_dim] per position
    values: Vec<Vec<f32>>,
}

impl LayerKvCache {
    /// Create an empty cache (no positions stored).
    fn new() -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Drop all cached positions. `Vec::clear` retains the existing
    /// allocation, so the capacity reserved at load time is reused.
    fn clear(&mut self) {
        self.keys.clear();
        self.values.clear();
    }

    /// Number of cached positions (keys and values have the same length).
    fn len(&self) -> usize {
        self.keys.len()
    }

    /// True when no positions are cached. Companion to `len` per the
    /// `len_without_is_empty` clippy convention.
    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }
}
|
|
|
|
// ============================================================================
|
|
// Scratch Memory Pool (Zero-Allocation Forward Pass)
|
|
// ============================================================================
|
|
|
|
/// Pre-allocated scratch buffers to eliminate per-token heap allocations
/// in the forward pass. All hot-path vectors are pre-sized to the maximum
/// needed dimension and reused across tokens.
///
/// Buffers start empty; `ScratchPool::allocate` sizes them from the model
/// config after loading.
struct ScratchPool {
    /// General-purpose buffer [hidden_size] — used for normed, residual, etc.
    buf_hidden_a: Vec<f32>,
    // Second and third general-purpose [hidden_size] buffers.
    buf_hidden_b: Vec<f32>,
    buf_hidden_c: Vec<f32>,
    /// Buffer for attention Q output [num_heads * head_dim]
    buf_attn_q: Vec<f32>,
    /// Buffer for attention K output [num_kv_heads * head_dim or num_heads * q_head_dim]
    buf_attn_k: Vec<f32>,
    /// Buffer for attention V output [num_kv_heads * head_dim or num_heads * v_dim]
    buf_attn_v: Vec<f32>,
    /// Buffer for attention output [hidden_size or num_heads * v_dim]
    buf_attn_out: Vec<f32>,
    /// Buffer for FFN intermediate [intermediate_size]
    buf_ffn_gate: Vec<f32>,
    // Up-projection and fused gate*up intermediates [intermediate_size].
    buf_ffn_up: Vec<f32>,
    buf_ffn_fused: Vec<f32>,
    // Down-projection output [hidden_size].
    buf_ffn_down: Vec<f32>,
    /// Buffer for expert output accumulation [hidden_size]
    buf_expert_out: Vec<f32>,
    /// Buffer for logits [vocab_size]
    buf_logits: Vec<f32>,
    /// Buffer for MLA compressed Q [q_lora_rank]
    buf_mla_cq: Vec<f32>,
    /// Buffer for MLA Q full [num_heads * q_head_dim]
    buf_mla_qfull: Vec<f32>,
    /// Buffer for MLA KV combined [kv_lora_rank + qk_rope_head_dim]
    buf_mla_kv: Vec<f32>,
    /// TL1 GEMV output buffer (reusable for arbitrary sizes)
    buf_gemv: Vec<f32>,
}
|
|
|
|
impl ScratchPool {
|
|
fn new() -> Self {
|
|
Self {
|
|
buf_hidden_a: Vec::new(),
|
|
buf_hidden_b: Vec::new(),
|
|
buf_hidden_c: Vec::new(),
|
|
buf_attn_q: Vec::new(),
|
|
buf_attn_k: Vec::new(),
|
|
buf_attn_v: Vec::new(),
|
|
buf_attn_out: Vec::new(),
|
|
buf_ffn_gate: Vec::new(),
|
|
buf_ffn_up: Vec::new(),
|
|
buf_ffn_fused: Vec::new(),
|
|
buf_ffn_down: Vec::new(),
|
|
buf_expert_out: Vec::new(),
|
|
buf_logits: Vec::new(),
|
|
buf_mla_cq: Vec::new(),
|
|
buf_mla_qfull: Vec::new(),
|
|
buf_mla_kv: Vec::new(),
|
|
buf_gemv: Vec::new(),
|
|
}
|
|
}
|
|
|
|
/// Pre-allocate all buffers based on model config. Called once after loading.
|
|
fn allocate(&mut self, config: &BitNetModelConfig) {
|
|
let h = config.hidden_size;
|
|
let q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim;
|
|
let attn_dim = config.num_attention_heads * q_head_dim;
|
|
let v_total = config.num_attention_heads * config.v_head_dim;
|
|
let inter = config.intermediate_size.max(config.moe_intermediate_size);
|
|
|
|
self.buf_hidden_a = vec![0.0; h];
|
|
self.buf_hidden_b = vec![0.0; h];
|
|
self.buf_hidden_c = vec![0.0; h];
|
|
self.buf_attn_q = vec![0.0; attn_dim];
|
|
self.buf_attn_k = vec![0.0; attn_dim];
|
|
self.buf_attn_v = vec![0.0; v_total.max(attn_dim)];
|
|
self.buf_attn_out = vec![0.0; v_total.max(h)];
|
|
self.buf_ffn_gate = vec![0.0; inter];
|
|
self.buf_ffn_up = vec![0.0; inter];
|
|
self.buf_ffn_fused = vec![0.0; inter];
|
|
self.buf_ffn_down = vec![0.0; h];
|
|
self.buf_expert_out = vec![0.0; h];
|
|
self.buf_logits = vec![0.0; config.vocab_size];
|
|
self.buf_mla_cq = vec![0.0; config.q_lora_rank];
|
|
self.buf_mla_qfull = vec![0.0; attn_dim];
|
|
self.buf_mla_kv = vec![0.0; config.kv_lora_rank + config.qk_rope_head_dim];
|
|
self.buf_gemv = vec![0.0; attn_dim.max(inter).max(h)];
|
|
}
|
|
|
|
/// Total memory used by scratch buffers.
|
|
fn memory_bytes(&self) -> usize {
|
|
(self.buf_hidden_a.len()
|
|
+ self.buf_hidden_b.len()
|
|
+ self.buf_hidden_c.len()
|
|
+ self.buf_attn_q.len()
|
|
+ self.buf_attn_k.len()
|
|
+ self.buf_attn_v.len()
|
|
+ self.buf_attn_out.len()
|
|
+ self.buf_ffn_gate.len()
|
|
+ self.buf_ffn_up.len()
|
|
+ self.buf_ffn_fused.len()
|
|
+ self.buf_ffn_down.len()
|
|
+ self.buf_expert_out.len()
|
|
+ self.buf_logits.len()
|
|
+ self.buf_mla_cq.len()
|
|
+ self.buf_mla_qfull.len()
|
|
+ self.buf_mla_kv.len()
|
|
+ self.buf_gemv.len())
|
|
* 4
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// BitNetBackend
|
|
// ============================================================================
|
|
|
|
/// BitNet b1.58 MoE inference backend.
///
/// Provides model loading from GGUF and forward pass inference using
/// ternary TL1 GEMV kernels for expert FFN layers and FP32 for shared
/// layers (embeddings, norms, router, LM head).
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::backend::BitNetBackend;
/// use ruvllm::backends::{LlmBackend, ModelConfig, GenerateParams};
///
/// let mut backend = BitNetBackend::new();
/// backend.load_model("model.gguf", ModelConfig::default())?;
///
/// let logits = backend.forward(&[1, 2, 3])?;
/// ```
pub struct BitNetBackend {
    /// Model configuration (set after load)
    config: Option<BitNetModelConfig>,
    /// Embedding table [vocab_size * hidden_size], row-major FP32
    embedding: Vec<f32>,
    /// LM head weight [vocab_size * hidden_size], row-major FP32
    /// (a clone of `embedding` when the GGUF has no separate output tensor)
    lm_head: Vec<f32>,
    /// Final RMSNorm weight [hidden_size]
    final_norm_weight: Vec<f32>,
    /// Transformer layers
    layers: Vec<TransformerLayer>,
    /// Pre-computed TL1 lookup table (built eagerly in `new`)
    tl1_lut: Tl1Lut,
    /// Per-layer KV caches for autoregressive generation
    kv_caches: Vec<LayerKvCache>,
    /// Tokenizer (loaded from GGUF or byte-level fallback)
    tok: Option<BpeTokenizer>,
    /// Pre-computed RoPE cos/sin tables [max_context, head_dim/2]
    rope_cos: Vec<f32>,
    // Matching sine table, same layout as `rope_cos`.
    rope_sin: Vec<f32>,
    /// Whether a model is loaded
    loaded: bool,
    /// Model path (for info)
    model_path: String,
    /// Pre-allocated scratch buffers for zero-alloc forward pass
    scratch: ScratchPool,
    /// Per-layer routing history for expert prediction (last N positions).
    /// Uses Mutex for interior mutability so forward_ffn can track routing
    /// decisions without requiring &mut self (needed for LlmBackend trait compat).
    routing_history: Mutex<Vec<Vec<usize>>>,
    /// Maximum routing history length
    max_routing_history: usize,
    /// Cached expert predictor, rebuilt periodically from routing history.
    /// Used to prefetch likely-next experts before they're computed.
    /// (ExpertPredictor is declared elsewhere in this module.)
    expert_predictor: Option<ExpertPredictor>,
    /// Number of routing history entries since last predictor rebuild.
    predictor_stale_count: usize,
    /// Per-layer compressed MLA KV caches (used instead of `kv_caches` for MLA layers).
    mla_caches: Vec<CompressedMlaCache>,
    /// When true, MLA layers store compressed latents (c_kv + k_pe) instead of
    /// full K/V vectors, giving ~17.8x memory reduction at the cost of recomputing
    /// K_nope and V during attention. Ideal for memory-constrained targets (Pi 5).
    use_compressed_kv: bool,
}
|
|
|
|
impl BitNetBackend {
|
|
/// Create a new unloaded BitNetBackend.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
config: None,
|
|
embedding: Vec::new(),
|
|
lm_head: Vec::new(),
|
|
final_norm_weight: Vec::new(),
|
|
layers: Vec::new(),
|
|
tl1_lut: build_tl1_lut(),
|
|
kv_caches: Vec::new(),
|
|
tok: None,
|
|
rope_cos: Vec::new(),
|
|
rope_sin: Vec::new(),
|
|
loaded: false,
|
|
model_path: String::new(),
|
|
scratch: ScratchPool::new(),
|
|
routing_history: Mutex::new(Vec::new()),
|
|
max_routing_history: 128,
|
|
expert_predictor: None,
|
|
predictor_stale_count: 0,
|
|
mla_caches: Vec::new(),
|
|
use_compressed_kv: false,
|
|
}
|
|
}
|
|
|
|
/// Enable or disable compressed MLA KV cache mode.
|
|
///
|
|
/// When enabled, MLA layers store only the compressed latents (c_kv + k_pe)
|
|
/// instead of full K/V vectors, giving ~17.8x memory reduction. K_nope and V
|
|
/// are recomputed from the compressed latent during attention, which trades
|
|
/// compute for memory. Ideal for memory-constrained targets (e.g., Pi 5).
|
|
pub fn set_compressed_kv(&mut self, enabled: bool) {
|
|
self.use_compressed_kv = enabled;
|
|
}
|
|
|
|
/// Returns whether compressed MLA KV cache mode is enabled.
|
|
pub fn compressed_kv_enabled(&self) -> bool {
|
|
self.use_compressed_kv
|
|
}
|
|
|
|
/// Clear the KV cache (call between sequences).
|
|
pub fn reset_cache(&mut self) {
|
|
for cache in &mut self.kv_caches {
|
|
cache.clear();
|
|
}
|
|
for cache in &mut self.mla_caches {
|
|
cache.clear();
|
|
}
|
|
}
|
|
|
|
// ========================================================================
|
|
// Model Loading
|
|
// ========================================================================
|
|
|
|
    /// Load a BitNet MoE model from a GGUF file.
    ///
    /// Parses the GGUF file, extracts model configuration from metadata,
    /// separates FP16 shared tensors from ternary expert tensors, and
    /// pre-builds the TL1 lookup table.
    ///
    /// Supports both llama.cpp GGUF tensor naming (`token_embd.weight`,
    /// `blk.0.attn_q_a.weight`) and HuggingFace naming (`model.embed_tokens.weight`,
    /// `model.layers.0.self_attn.q_proj.weight`).
    ///
    /// On success, sets `self.config`, `self.loaded`, and `self.model_path`;
    /// on error, the backend may be left partially populated (fields assigned
    /// before the failing step keep their new values).
    fn load_gguf(&mut self, path: &str) -> Result<()> {
        let gguf = GgufFile::open_mmap(Path::new(path))?;

        // Extract model config from GGUF metadata
        let config = self.extract_config(&gguf)?;

        // Load embedding table via name mapper
        let emb_name = TensorNameMapper::resolve(&gguf, &TensorNameMapper::embedding())
            .ok_or_else(|| RuvLLMError::NotFound(
                "Embedding tensor not found (tried: token_embd.weight, model.embed_tokens.weight)".into()
            ))?;
        self.embedding = self.load_fp_tensor(&gguf, &emb_name, &config)?;

        // Load LM head / output via name mapper (fallback to tied embeddings)
        self.lm_head =
            if let Some(out_name) = TensorNameMapper::resolve(&gguf, &TensorNameMapper::output()) {
                self.load_fp_tensor(&gguf, &out_name, &config)?
            } else {
                // No output tensor: weights are tied, reuse the embedding table
                self.embedding.clone()
            };

        // Load final norm via name mapper
        let norm_name = TensorNameMapper::resolve(&gguf, &TensorNameMapper::final_norm())
            .ok_or_else(|| {
                RuvLLMError::NotFound(
                    "Final norm tensor not found (tried: output_norm.weight, model.norm.weight)"
                        .into(),
                )
            })?;
        self.final_norm_weight = self.load_fp_tensor(&gguf, &norm_name, &config)?;

        // Load transformer layers
        self.layers = Vec::with_capacity(config.num_layers);
        for layer_idx in 0..config.num_layers {
            let layer = self.load_layer(&gguf, layer_idx, &config)?;
            self.layers.push(layer);
        }

        // Initialize KV caches (one per layer, pre-allocated for 512 positions)
        let pre_alloc_seq = 512.min(config.max_context);
        self.kv_caches = (0..config.num_layers)
            .map(|_| {
                let mut cache = LayerKvCache::new();
                cache.keys.reserve(pre_alloc_seq);
                cache.values.reserve(pre_alloc_seq);
                cache
            })
            .collect();

        // Initialize compressed MLA caches (one per layer for MLA layers)
        self.mla_caches = (0..config.num_layers)
            .map(|_| CompressedMlaCache::new())
            .collect();

        // Build RoPE cos/sin tables
        // For MLA, rope applies only to qk_rope_head_dim portion
        let rope_dim = if config.use_mla {
            config.qk_rope_head_dim
        } else {
            config.hidden_size / config.num_attention_heads
        };
        // NOTE(review): table is capped at 8192 positions to bound memory;
        // confirm that attention never indexes the table past this cap when
        // max_context is larger.
        self.build_rope_tables(config.max_context.min(8192), rope_dim, config.rope_theta);

        // Load tokenizer from GGUF metadata
        self.tok = self.load_tokenizer_from_gguf(&gguf);

        // Pre-allocate scratch memory pool
        self.scratch.allocate(&config);

        // Initialize routing history
        // unwrap: a poisoned lock means another thread panicked mid-load,
        // which is already unrecoverable here.
        self.routing_history.lock().unwrap().clear();

        self.config = Some(config);
        self.loaded = true;
        self.model_path = path.to_string();

        Ok(())
    }
|
|
|
|
/// Pre-compute RoPE frequency tables.
|
|
fn build_rope_tables(&mut self, max_seq: usize, head_dim: usize, theta: f32) {
|
|
let half = head_dim / 2;
|
|
let total = max_seq * half;
|
|
self.rope_cos = vec![0.0; total];
|
|
self.rope_sin = vec![0.0; total];
|
|
|
|
for pos in 0..max_seq {
|
|
for i in 0..half {
|
|
let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
|
|
let angle = pos as f32 * freq;
|
|
self.rope_cos[pos * half + i] = angle.cos();
|
|
self.rope_sin[pos * half + i] = angle.sin();
|
|
}
|
|
}
|
|
}
|
|
|
|
    /// Load tokenizer from GGUF metadata, falling back to byte-level tokenizer.
    ///
    /// Reads the vocab from `tokenizer.ggml.tokens` and optional BPE merge
    /// rules from `tokenizer.ggml.merges`. As written this always returns
    /// `Some`: when no usable vocab is present, a 260-token byte-level
    /// fallback tokenizer is built instead.
    fn load_tokenizer_from_gguf(&self, gguf: &GgufFile) -> Option<BpeTokenizer> {
        // Try to extract token list from GGUF
        let tokens_meta = gguf.metadata.get("tokenizer.ggml.tokens");
        let merges_meta = gguf.metadata.get("tokenizer.ggml.merges");

        if let Some(tokens_arr) = tokens_meta.and_then(|v| v.as_array()) {
            // Non-string entries are silently skipped; token ids therefore
            // follow the surviving array order.
            let vocab: Vec<String> = tokens_arr
                .iter()
                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                .collect();

            // Each merge entry is "left right" (split on the first space only,
            // so tokens containing spaces beyond the first stay intact).
            let merges: Vec<(String, String)> =
                if let Some(merges_arr) = merges_meta.and_then(|v| v.as_array()) {
                    merges_arr
                        .iter()
                        .filter_map(|v| {
                            let s = v.as_str()?;
                            let mut parts = s.splitn(2, ' ');
                            let left = parts.next()?.to_string();
                            let right = parts.next()?.to_string();
                            Some((left, right))
                        })
                        .collect()
                } else {
                    Vec::new()
                };

            if !vocab.is_empty() {
                return Some(BpeTokenizer::from_vocab(
                    vocab,
                    merges,
                    BitNetSpecialTokens::default(),
                ));
            }
        }

        // Fallback: construct a byte-level tokenizer (260 tokens)
        Some(Self::build_byte_level_tokenizer())
    }
|
|
|
|
/// Build a minimal byte-level tokenizer for when GGUF has no vocab.
|
|
fn build_byte_level_tokenizer() -> BpeTokenizer {
|
|
let mut vocab = vec![
|
|
"<PAD>".to_string(), // 0
|
|
"<BOS>".to_string(), // 1
|
|
"<EOS>".to_string(), // 2
|
|
"<UNK>".to_string(), // 3
|
|
];
|
|
for b in 0..=255u8 {
|
|
vocab.push(format!("<{:02X}>", b));
|
|
}
|
|
BpeTokenizer::from_vocab(vocab, vec![], BitNetSpecialTokens::default())
|
|
}
|
|
|
|
    /// Extract BitNetModelConfig from GGUF metadata.
    ///
    /// Every field falls back to the GLM-4.7-Flash `Default` value when its
    /// metadata key is missing, so a sparse GGUF still produces a usable
    /// config. MLA vs. GQA is detected from tensor names, not metadata.
    /// NOTE(review): key names use a generic `llm.` prefix rather than an
    /// architecture-specific one — confirm against the GGUF writer used.
    fn extract_config(&self, gguf: &GgufFile) -> Result<BitNetModelConfig> {
        let defaults = BitNetModelConfig::default();
        let num_layers = gguf.layer_count().unwrap_or(defaults.num_layers);
        let hidden_size = gguf.embedding_length().unwrap_or(defaults.hidden_size);
        let num_attention_heads = gguf.head_count().unwrap_or(defaults.num_attention_heads);
        let num_kv_heads = gguf.head_count_kv().unwrap_or(defaults.num_kv_heads);
        let vocab_size = gguf.vocab_size().unwrap_or(defaults.vocab_size);
        let max_context = gguf.context_length().unwrap_or(defaults.max_context);
        let rope_theta = gguf.rope_freq_base().unwrap_or(defaults.rope_theta);
        let intermediate_size = gguf
            .feed_forward_length()
            .unwrap_or(defaults.intermediate_size);

        // Detect expert count from tensor names or metadata
        // (tensor scan takes priority — it reflects what is actually stored)
        let num_experts = self
            .detect_expert_count(gguf)
            .or_else(|| Self::meta_usize(gguf, "llm.expert_count"))
            .unwrap_or(defaults.num_experts);

        // Active experts per token
        let active_experts = Self::meta_usize(gguf, "llm.expert_used_count")
            .or_else(|| Self::meta_usize(gguf, "model.expert_count_active"))
            .unwrap_or(defaults.active_experts);

        // MoE intermediate size (may differ from dense intermediate_size)
        let moe_intermediate_size = Self::meta_usize(gguf, "llm.expert_feed_forward_length")
            .unwrap_or(defaults.moe_intermediate_size);

        // MLA parameters
        let q_lora_rank =
            Self::meta_usize(gguf, "llm.attention.q_lora_rank").unwrap_or(defaults.q_lora_rank);
        let kv_lora_rank =
            Self::meta_usize(gguf, "llm.attention.kv_lora_rank").unwrap_or(defaults.kv_lora_rank);
        let qk_nope_head_dim = Self::meta_usize(gguf, "llm.attention.key_length_nope")
            .unwrap_or(defaults.qk_nope_head_dim);
        // RoPE head dim falls back to the generic rope dimension count
        let qk_rope_head_dim = Self::meta_usize(gguf, "llm.attention.key_length_rope")
            .or_else(|| gguf.rope_dimension_count())
            .unwrap_or(defaults.qk_rope_head_dim);
        let v_head_dim =
            Self::meta_usize(gguf, "llm.attention.value_length").unwrap_or(defaults.v_head_dim);

        // Detect MLA by checking for q_a tensor in first layer
        let use_mla = TensorNameMapper::has_mla(gguf, 0);

        // Shared experts (default to 0 for non-MoE models with <= 1 expert)
        let n_shared_experts =
            Self::meta_usize(gguf, "llm.expert_shared_count").unwrap_or(if num_experts > 1 {
                defaults.n_shared_experts
            } else {
                0
            });

        // First K dense layers
        let first_k_dense_replace = Self::meta_usize(gguf, "llm.expert_first_dense_layers")
            .unwrap_or(defaults.first_k_dense_replace);

        // Routed scaling factor
        let routed_scaling_factor = Self::meta_f32(gguf, "llm.expert_weights_scale")
            .unwrap_or(defaults.routed_scaling_factor);

        Ok(BitNetModelConfig {
            num_layers,
            hidden_size,
            num_experts,
            active_experts,
            intermediate_size,
            moe_intermediate_size,
            num_attention_heads,
            num_kv_heads,
            vocab_size,
            max_context,
            rope_theta,
            use_mla,
            q_lora_rank,
            kv_lora_rank,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            n_shared_experts,
            first_k_dense_replace,
            routed_scaling_factor,
        })
    }
|
|
|
|
/// Helper: extract a usize from GGUF metadata.
|
|
fn meta_usize(gguf: &GgufFile, key: &str) -> Option<usize> {
|
|
gguf.metadata
|
|
.get(key)
|
|
.and_then(|v| v.as_u64())
|
|
.map(|v| v as usize)
|
|
}
|
|
|
|
/// Helper: extract an f32 from GGUF metadata.
|
|
fn meta_f32(gguf: &GgufFile, key: &str) -> Option<f32> {
|
|
gguf.metadata.get(key).and_then(|v| v.as_f32())
|
|
}
|
|
|
|
/// Detect the number of MoE experts by scanning tensor names.
|
|
fn detect_expert_count(&self, gguf: &GgufFile) -> Option<usize> {
|
|
let mut max_expert_idx = 0usize;
|
|
let mut found_any = false;
|
|
|
|
for tensor in &gguf.tensors {
|
|
// Look for patterns like "experts.0.", "experts.7.", etc.
|
|
if let Some(pos) = tensor.name.find("experts.") {
|
|
let after = &tensor.name[pos + 8..];
|
|
if let Some(dot) = after.find('.') {
|
|
if let Ok(idx) = after[..dot].parse::<usize>() {
|
|
max_expert_idx = max_expert_idx.max(idx);
|
|
found_any = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if found_any {
|
|
Some(max_expert_idx + 1)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Load an FP16/FP32 tensor from GGUF, returning FP32 data.
|
|
fn load_fp_tensor(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
name: &str,
|
|
_config: &BitNetModelConfig,
|
|
) -> Result<Vec<f32>> {
|
|
match gguf.get_tensor(name) {
|
|
Some(_) => gguf.load_tensor_f32(name),
|
|
None => Err(RuvLLMError::NotFound(format!(
|
|
"Required tensor not found: {}",
|
|
name
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Load a ternary tensor from GGUF (BitnetT158 or dequant + re-quantize).
///
/// Two paths:
/// - `BitnetT158` native blocks: each 66-byte block holds 64 packed
///   ternary bytes followed by one little-endian FP16 per-block scale;
///   packed data and scales are extracted without dequantization.
/// - Any other dtype: decode to FP32 and re-quantize with the default
///   `PtBitnetConfig`.
///
/// 2-D tensors keep their GGUF shape; any other rank is flattened into
/// a single row `(1, num_elements)`.
fn load_ternary_tensor(&self, gguf: &GgufFile, name: &str) -> Result<TernaryTensor> {
    let info = gguf
        .get_tensor(name)
        .ok_or_else(|| RuvLLMError::NotFound(format!("Tensor not found: {}", name)))?;

    if info.dtype == GgufQuantType::BitnetT158 {
        // Native ternary format: extract packed data and scales directly
        let raw = gguf.load_tensor_quantized(name)?;
        let num_elements = info.num_elements();
        let block_size = 256usize;
        // Ceiling division: a partial trailing block still needs a slot.
        let num_blocks = (num_elements + block_size - 1) / block_size;
        let type_size = 66usize; // 64 packed + 2 FP16 scale

        let mut packed_data = Vec::with_capacity(num_blocks * 64);
        let mut scales = Vec::with_capacity(num_blocks);

        for blk in 0..num_blocks {
            let offset = blk * type_size;
            // NOTE(review): a short raw buffer silently truncates the
            // tensor here instead of returning an error — confirm this
            // tolerance is intended.
            if offset + type_size > raw.data.len() {
                break;
            }
            packed_data.extend_from_slice(&raw.data[offset..offset + 64]);
            // Per-block scale: little-endian FP16 stored right after the
            // 64 packed bytes.
            let scale_bits = u16::from_le_bytes([raw.data[offset + 64], raw.data[offset + 65]]);
            scales.push(f16_to_f32(scale_bits));
        }

        let shape = if info.shape.len() == 2 {
            (info.shape[0], info.shape[1])
        } else {
            (1, num_elements)
        };

        Ok(TernaryTensor {
            packed_data,
            scales,
            shape,
            block_size,
        })
    } else {
        // Non-native format: dequantize to FP32, then quantize to ternary
        let fp32 = gguf.load_tensor_f32(name)?;
        let num_elements = fp32.len();
        let shape = if info.shape.len() == 2 {
            (info.shape[0], info.shape[1])
        } else {
            (1, num_elements)
        };

        let ptconfig = super::quantizer::PtBitnetConfig::default();
        super::quantizer::quantize_tensor(&fp32, shape, &ptconfig)
    }
}
|
|
|
|
/// Load a single transformer layer.
///
/// Detects the layer type (dense vs MoE), attention type (MLA vs GQA),
/// and expert tensor format (stacked 3D vs individual) from the GGUF file.
///
/// # Errors
///
/// Returns `RuvLLMError::NotFound` when a required norm, attention,
/// gate, or expert tensor is missing (a missing *shared* expert is
/// tolerated and simply selects the plain MoE layer type).
fn load_layer(
    &self,
    gguf: &GgufFile,
    idx: usize,
    config: &BitNetModelConfig,
) -> Result<TransformerLayer> {
    // Norm weights via name mapper
    let in_norm_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::input_norm(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} input norm not found", idx)))?;
    let input_norm_weight = self.load_fp_tensor(gguf, &in_norm_name, config)?;

    let post_norm_name =
        TensorNameMapper::resolve(gguf, &TensorNameMapper::post_attn_norm(idx)).ok_or_else(
            || RuvLLMError::NotFound(format!("Layer {} post-attn norm not found", idx)),
        )?;
    let post_attn_norm_weight = self.load_fp_tensor(gguf, &post_norm_name, config)?;

    // === Attention weights ===
    // MLA is detected per-layer from the presence of MLA tensors.
    let attention = if TensorNameMapper::has_mla(gguf, idx) {
        self.load_mla_attention(gguf, idx, config)?
    } else {
        self.load_gqa_attention(gguf, idx, config)?
    };

    // === FFN weights ===
    // A layer is dense if it falls within the leading
    // `first_k_dense_replace` range, or if the file actually contains a
    // dense FFN tensor for it.
    let is_dense_layer =
        idx < config.first_k_dense_replace || TensorNameMapper::has_dense_ffn(gguf, idx);

    if is_dense_layer {
        // Dense FFN layer (no MoE routing)
        let dense_ffn = self.load_dense_ffn(gguf, idx, config)?;
        Ok(TransformerLayer {
            input_norm_weight,
            post_attn_norm_weight,
            attention,
            layer_type: LayerType::Dense,
            gate_weight: Vec::new(),
            experts: Vec::new(),
            shared_expert: None,
            dense_ffn: Some(dense_ffn),
        })
    } else {
        // MoE layer: load router gate + experts
        let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::moe_gate(idx))
            .ok_or_else(|| {
                RuvLLMError::NotFound(format!("Layer {} MoE gate not found", idx))
            })?;
        let gate_weight = self.load_fp_tensor(gguf, &gate_name, config)?;

        let experts = self.load_experts(gguf, idx, config)?;

        // Try loading shared expert; `.ok()` makes its absence
        // non-fatal and drives the layer-type choice below.
        let shared_expert = self.load_shared_expert(gguf, idx, config).ok();

        let layer_type = if shared_expert.is_some() {
            LayerType::MoeWithShared
        } else {
            LayerType::Moe
        };

        Ok(TransformerLayer {
            input_norm_weight,
            post_attn_norm_weight,
            attention,
            layer_type,
            gate_weight,
            experts,
            shared_expert,
            dense_ffn: None,
        })
    }
}
|
|
|
|
/// Load MLA attention weights for a layer.
///
/// Loads the low-rank MLA projections (q_a/q_b, kv_a_mqa, k_b/v_b) plus
/// the output projection as ternary tensors. The standard
/// q_proj/k_proj/v_proj fields are filled with empty placeholders since
/// the MLA forward path never reads them. The q_a/kv_a RMSNorm weights
/// are optional and loaded best-effort.
fn load_mla_attention(
    &self,
    gguf: &GgufFile,
    idx: usize,
    _config: &BitNetModelConfig,
) -> Result<AttentionWeights> {
    // MLA projections
    let q_a_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_a(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_q_a not found", idx)))?;
    let q_a = self.load_ternary_tensor(gguf, &q_a_name)?;

    let q_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_b(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_q_b not found", idx)))?;
    let q_b = self.load_ternary_tensor(gguf, &q_b_name)?;

    let kv_a_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_kv_a_mqa(idx))
        .ok_or_else(|| {
            RuvLLMError::NotFound(format!("Layer {} attn_kv_a_mqa not found", idx))
        })?;
    let kv_a_mqa = self.load_ternary_tensor(gguf, &kv_a_name)?;

    let k_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_k_b(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_k_b not found", idx)))?;
    let k_b = self.load_ternary_tensor(gguf, &k_b_name)?;

    let v_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_v_b(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_v_b not found", idx)))?;
    let v_b = self.load_ternary_tensor(gguf, &v_b_name)?;

    let o_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_output(idx))
        .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_output not found", idx)))?;
    let o_proj = self.load_ternary_tensor(gguf, &o_name)?;

    // Norm weights for MLA compression (may or may not be present)
    let q_a_norm = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_a_norm(idx))
        .and_then(|n| self.load_fp_tensor(gguf, &n, _config).ok());
    let kv_a_norm = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_kv_a_norm(idx))
        .and_then(|n| self.load_fp_tensor(gguf, &n, _config).ok());

    // Use o_proj as placeholder for the standard fields (they won't be used in MLA path)
    let placeholder = TernaryTensor {
        packed_data: vec![],
        scales: vec![],
        shape: (0, 0),
        block_size: 256,
    };

    Ok(AttentionWeights {
        is_mla: true,
        q_proj: placeholder.clone(),
        k_proj: placeholder.clone(),
        v_proj: placeholder,
        o_proj,
        q_a: Some(q_a),
        q_b: Some(q_b),
        q_a_norm,
        kv_a_mqa: Some(kv_a_mqa),
        kv_a_norm,
        k_b: Some(k_b),
        v_b: Some(v_b),
    })
}
|
|
|
|
/// Load standard GQA attention weights for a layer.
|
|
fn load_gqa_attention(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
idx: usize,
|
|
_config: &BitNetModelConfig,
|
|
) -> Result<AttentionWeights> {
|
|
let q_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_proj(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} Q projection not found", idx))
|
|
})?;
|
|
let q_proj = self.load_ternary_tensor(gguf, &q_name)?;
|
|
|
|
let k_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_k_proj(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} K projection not found", idx))
|
|
})?;
|
|
let k_proj = self.load_ternary_tensor(gguf, &k_name)?;
|
|
|
|
let v_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_v_proj(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} V projection not found", idx))
|
|
})?;
|
|
let v_proj = self.load_ternary_tensor(gguf, &v_name)?;
|
|
|
|
let o_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_output(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} O projection not found", idx))
|
|
})?;
|
|
let o_proj = self.load_ternary_tensor(gguf, &o_name)?;
|
|
|
|
Ok(AttentionWeights {
|
|
is_mla: false,
|
|
q_proj,
|
|
k_proj,
|
|
v_proj,
|
|
o_proj,
|
|
q_a: None,
|
|
q_b: None,
|
|
q_a_norm: None,
|
|
kv_a_mqa: None,
|
|
kv_a_norm: None,
|
|
k_b: None,
|
|
v_b: None,
|
|
})
|
|
}
|
|
|
|
/// Load dense FFN weights for a layer (no MoE).
|
|
fn load_dense_ffn(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
idx: usize,
|
|
_config: &BitNetModelConfig,
|
|
) -> Result<ExpertWeights> {
|
|
let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} dense ffn_gate not found", idx))
|
|
})?;
|
|
let up_name =
|
|
TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up(idx)).ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} dense ffn_up not found", idx))
|
|
})?;
|
|
let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} dense ffn_down not found", idx))
|
|
})?;
|
|
|
|
Ok(ExpertWeights {
|
|
gate_proj: self.load_ternary_tensor(gguf, &gate_name)?,
|
|
up_proj: self.load_ternary_tensor(gguf, &up_name)?,
|
|
down_proj: self.load_ternary_tensor(gguf, &down_name)?,
|
|
})
|
|
}
|
|
|
|
/// Load shared expert weights for a layer.
|
|
fn load_shared_expert(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
idx: usize,
|
|
_config: &BitNetModelConfig,
|
|
) -> Result<ExpertWeights> {
|
|
let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate_shexp(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} shared expert gate not found", idx))
|
|
})?;
|
|
let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up_shexp(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} shared expert up not found", idx))
|
|
})?;
|
|
let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down_shexp(idx))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} shared expert down not found", idx))
|
|
})?;
|
|
|
|
Ok(ExpertWeights {
|
|
gate_proj: self.load_ternary_tensor(gguf, &gate_name)?,
|
|
up_proj: self.load_ternary_tensor(gguf, &up_name)?,
|
|
down_proj: self.load_ternary_tensor(gguf, &down_name)?,
|
|
})
|
|
}
|
|
|
|
/// Load routed expert weights, supporting both stacked (3D) and individual tensor formats.
|
|
fn load_experts(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
idx: usize,
|
|
config: &BitNetModelConfig,
|
|
) -> Result<Vec<ExpertWeights>> {
|
|
if TensorNameMapper::has_stacked_experts(gguf, idx) {
|
|
self.load_stacked_experts(gguf, idx, config)
|
|
} else {
|
|
self.load_individual_experts(gguf, idx, config)
|
|
}
|
|
}
|
|
|
|
/// Load stacked expert tensors (3D format: [num_experts, out_dim, in_dim])
/// and split into per-expert TernaryTensors.
///
/// Each expert's FP32 slice is re-quantized to ternary with the default
/// `PtBitnetConfig`. If the stacked buffer is shorter than expected for
/// some expert, that projection is replaced by an empty placeholder
/// tensor (no packed data) instead of failing.
fn load_stacked_experts(
    &self,
    gguf: &GgufFile,
    idx: usize,
    config: &BitNetModelConfig,
) -> Result<Vec<ExpertWeights>> {
    let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate_exps(idx))
        .ok_or_else(|| {
            RuvLLMError::NotFound(format!("Layer {} stacked gate_exps not found", idx))
        })?;
    let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up_exps(idx))
        .ok_or_else(|| {
            RuvLLMError::NotFound(format!("Layer {} stacked up_exps not found", idx))
        })?;
    let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down_exps(idx))
        .ok_or_else(|| {
            RuvLLMError::NotFound(format!("Layer {} stacked down_exps not found", idx))
        })?;

    // Load stacked tensors as FP32 and split per expert
    let gate_all = gguf.load_tensor_f32(&gate_name)?;
    let up_all = gguf.load_tensor_f32(&up_name)?;
    let down_all = gguf.load_tensor_f32(&down_name)?;

    let num_experts = config.num_experts;
    let intermediate = config.moe_intermediate_size;
    let hidden = config.hidden_size;

    // gate/up: [num_experts, intermediate_size, hidden_size]
    let gate_per_expert = intermediate * hidden;
    // down: [num_experts, hidden_size, intermediate_size]
    let down_per_expert = hidden * intermediate;

    let ptconfig = super::quantizer::PtBitnetConfig::default();
    let mut experts = Vec::with_capacity(num_experts);

    for e in 0..num_experts {
        let gate_start = e * gate_per_expert;
        let gate_end = gate_start + gate_per_expert;
        let gate_slice = if gate_end <= gate_all.len() {
            &gate_all[gate_start..gate_end]
        } else {
            // Insufficient data — create zeros
            &[]
        };

        // `up` reuses gate's per-expert element count: both are
        // [intermediate_size, hidden_size].
        let up_start = e * gate_per_expert;
        let up_end = up_start + gate_per_expert;
        let up_slice = if up_end <= up_all.len() {
            &up_all[up_start..up_end]
        } else {
            &[]
        };

        let down_start = e * down_per_expert;
        let down_end = down_start + down_per_expert;
        let down_slice = if down_end <= down_all.len() {
            &down_all[down_start..down_end]
        } else {
            &[]
        };

        // Empty slice => placeholder tensor carrying only the expected
        // shape; otherwise quantize the FP32 slice to ternary.
        let gate_proj = if gate_slice.is_empty() {
            TernaryTensor {
                packed_data: vec![],
                scales: vec![],
                shape: (intermediate, hidden),
                block_size: 256,
            }
        } else {
            super::quantizer::quantize_tensor(gate_slice, (intermediate, hidden), &ptconfig)?
        };
        let up_proj = if up_slice.is_empty() {
            TernaryTensor {
                packed_data: vec![],
                scales: vec![],
                shape: (intermediate, hidden),
                block_size: 256,
            }
        } else {
            super::quantizer::quantize_tensor(up_slice, (intermediate, hidden), &ptconfig)?
        };
        let down_proj = if down_slice.is_empty() {
            TernaryTensor {
                packed_data: vec![],
                scales: vec![],
                shape: (hidden, intermediate),
                block_size: 256,
            }
        } else {
            super::quantizer::quantize_tensor(down_slice, (hidden, intermediate), &ptconfig)?
        };

        experts.push(ExpertWeights {
            gate_proj,
            up_proj,
            down_proj,
        });
    }

    Ok(experts)
}
|
|
|
|
/// Load individual expert tensors (HuggingFace naming: `experts.{e}.gate_proj.weight`).
|
|
fn load_individual_experts(
|
|
&self,
|
|
gguf: &GgufFile,
|
|
idx: usize,
|
|
config: &BitNetModelConfig,
|
|
) -> Result<Vec<ExpertWeights>> {
|
|
let mut experts = Vec::with_capacity(config.num_experts);
|
|
for e in 0..config.num_experts {
|
|
let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_gate(idx, e))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} expert {} gate_proj not found", idx, e))
|
|
})?;
|
|
let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_up(idx, e))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} expert {} up_proj not found", idx, e))
|
|
})?;
|
|
let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_down(idx, e))
|
|
.ok_or_else(|| {
|
|
RuvLLMError::NotFound(format!("Layer {} expert {} down_proj not found", idx, e))
|
|
})?;
|
|
|
|
experts.push(ExpertWeights {
|
|
gate_proj: self.load_ternary_tensor(gguf, &gate_name)?,
|
|
up_proj: self.load_ternary_tensor(gguf, &up_name)?,
|
|
down_proj: self.load_ternary_tensor(gguf, &down_name)?,
|
|
});
|
|
}
|
|
Ok(experts)
|
|
}
|
|
|
|
// ========================================================================
|
|
// Forward Pass
|
|
// ========================================================================
|
|
|
|
/// Run a forward pass for a single token, using the KV cache.
///
/// This is the autoregressive path: embed one token, run all layers
/// with cached K/V from prior positions, return logits.
///
/// Call `reset_cache()` before starting a new sequence.
///
/// # Arguments
///
/// * `token_id` - Single token to process
/// * `position` - Position index in the sequence (0-based)
///
/// # Errors
///
/// Fails when no model is loaded or `token_id` is outside the
/// vocabulary range.
pub fn forward_token(&mut self, token_id: u32, position: usize) -> Result<Vec<f32>> {
    // Clone the config so the later `&mut self` layer calls don't
    // conflict with a borrow of `self.config`.
    let config = self
        .config
        .as_ref()
        .ok_or_else(|| RuvLLMError::Model("No model loaded".to_string()))?
        .clone();

    let hidden = config.hidden_size;

    // Bounds-check before the embedding slice below so we error
    // cleanly instead of panicking on the lookup.
    if (token_id as usize) >= config.vocab_size {
        return Err(RuvLLMError::Model(format!(
            "Token ID {} exceeds vocab size {}",
            token_id, config.vocab_size
        )));
    }

    // Periodically rebuild expert predictor from routing history.
    // Rebuild every 16 tokens to amortize the transition matrix cost.
    self.predictor_stale_count += 1;
    if self.predictor_stale_count >= 16 {
        // NOTE(review): `.unwrap()` panics if the mutex was poisoned by
        // a panicking holder — confirm that is acceptable here.
        let hist = self.routing_history.lock().unwrap();
        if hist.len() >= 2 {
            self.expert_predictor =
                Some(ExpertPredictor::from_history(config.num_experts, &hist));
        }
        self.predictor_stale_count = 0;
    }

    // Embedding lookup: copy row `token_id` from the flat
    // [vocab_size * hidden] embedding table.
    let start = (token_id as usize) * hidden;
    let mut hidden_states: Vec<f32> = self.embedding[start..start + hidden].to_vec();

    // Transformer layers
    for layer_idx in 0..self.layers.len() {
        hidden_states =
            self.forward_layer_cached(&hidden_states, layer_idx, position, &config)?;
    }

    // Final RMSNorm
    rms_norm_inplace(&mut hidden_states, &self.final_norm_weight, 1e-6);

    // LM head: logits = hidden_states @ lm_head^T
    let logits =
        fp32_matvec_transposed(&self.lm_head, &hidden_states, config.vocab_size, hidden);

    Ok(logits)
}
|
|
|
|
/// Legacy forward: process full token sequence without KV cache.
|
|
/// Kept for backwards compatibility with tests.
|
|
pub fn forward(&self, token_ids: &[u32]) -> Result<Vec<f32>> {
|
|
let config = self
|
|
.config
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No model loaded".to_string()))?;
|
|
|
|
if token_ids.is_empty() {
|
|
return Err(RuvLLMError::Model("Empty token sequence".to_string()));
|
|
}
|
|
|
|
let hidden = config.hidden_size;
|
|
let last_token = *token_ids.last().unwrap() as usize;
|
|
if last_token >= config.vocab_size {
|
|
return Err(RuvLLMError::Model(format!(
|
|
"Token ID {} exceeds vocab size {}",
|
|
last_token, config.vocab_size
|
|
)));
|
|
}
|
|
let mut hidden_states: Vec<f32> =
|
|
self.embedding[last_token * hidden..(last_token + 1) * hidden].to_vec();
|
|
|
|
for layer_idx in 0..self.layers.len() {
|
|
hidden_states = self.forward_layer_nocache(&hidden_states, layer_idx, config)?;
|
|
}
|
|
|
|
rms_norm_inplace(&mut hidden_states, &self.final_norm_weight, 1e-6);
|
|
|
|
let logits =
|
|
fp32_matvec_transposed(&self.lm_head, &hidden_states, config.vocab_size, hidden);
|
|
|
|
Ok(logits)
|
|
}
|
|
|
|
/// Forward pass through a single layer with KV cache (autoregressive).
///
/// Order of operations:
/// RMSNorm -> attention (MLA or GQA) -> O projection -> residual add
/// -> RMSNorm -> FFN -> residual add.
fn forward_layer_cached(
    &mut self,
    input: &[f32],
    layer_idx: usize,
    position: usize,
    config: &BitNetModelConfig,
) -> Result<Vec<f32>> {
    let hidden = config.hidden_size;

    // --- Pre-attention norm ---
    let mut normed = input.to_vec();
    let layer = &self.layers[layer_idx];
    rms_norm_inplace(&mut normed, &layer.input_norm_weight, 1e-6);

    // --- Attention (MLA or GQA) ---
    // These take `&mut self` (they push into the KV caches), so the
    // `layer` borrow above must already be dropped here.
    let attn_out = if self.layers[layer_idx].attention.is_mla {
        self.forward_mla_cached(&normed, layer_idx, position, config)?
    } else {
        self.forward_gqa_cached(&normed, layer_idx, position, config)?
    };

    // --- Output projection ---
    let o_out = self.tl1_gemv(
        &self.layers[layer_idx].attention.o_proj,
        &attn_out,
        hidden,
        hidden,
    );

    // --- Residual after attention ---
    let mut residual: Vec<f32> = input.iter().zip(o_out.iter()).map(|(r, a)| r + a).collect();

    // --- Post-attention norm ---
    let mut normed_ffn = residual.clone();
    // Re-borrow the layer (the previous borrow ended at the attention call).
    let layer = &self.layers[layer_idx];
    rms_norm_inplace(&mut normed_ffn, &layer.post_attn_norm_weight, 1e-6);

    // --- FFN (Dense, MoE, or MoE+Shared) ---
    let ffn_out = self.forward_ffn(&normed_ffn, layer_idx, config)?;

    // Add the FFN output back into the residual stream.
    for (r, &f) in residual.iter_mut().zip(ffn_out.iter()) {
        *r += f;
    }

    Ok(residual)
}
|
|
|
|
/// GQA attention with KV cache.
///
/// Optimized with 4-wide unrolled dot products and fused score-weighted
/// value accumulation.
///
/// Projects Q/K/V via TL1 GEMV, applies RoPE to Q and K, appends the
/// new K/V to this layer's cache, then attends over all cached
/// positions. Query heads share KV heads in groups of
/// `num_heads / num_kv_heads`.
fn forward_gqa_cached(
    &mut self,
    normed: &[f32],
    layer_idx: usize,
    position: usize,
    config: &BitNetModelConfig,
) -> Result<Vec<f32>> {
    let hidden = config.hidden_size;
    let num_heads = config.num_attention_heads;
    let num_kv_heads = config.num_kv_heads;
    let head_dim = hidden / num_heads;
    let kv_dim = num_kv_heads * head_dim;

    // Q/K/V projections via TL1 GEMV (SIMD-dispatched)
    let q = self.tl1_gemv(
        &self.layers[layer_idx].attention.q_proj,
        normed,
        hidden,
        hidden,
    );
    let k = self.tl1_gemv(
        &self.layers[layer_idx].attention.k_proj,
        normed,
        kv_dim,
        hidden,
    );
    let v = self.tl1_gemv(
        &self.layers[layer_idx].attention.v_proj,
        normed,
        kv_dim,
        hidden,
    );

    // Apply RoPE to Q and K
    let mut q_rope = q;
    let mut k_rope = k;
    self.apply_rope(&mut q_rope, num_heads, head_dim, position);
    self.apply_rope(&mut k_rope, num_kv_heads, head_dim, position);

    // Update KV cache
    self.kv_caches[layer_idx].keys.push(k_rope);
    self.kv_caches[layer_idx].values.push(v);
    let seq_len = self.kv_caches[layer_idx].len();

    // GQA attention scores with 4-wide dot product
    let gqa_groups = if num_kv_heads > 0 {
        num_heads / num_kv_heads
    } else {
        1
    };
    let inv_sqrt_d = 1.0 / (head_dim as f32).sqrt();
    let mut attn_out = vec![0.0f32; hidden];
    // Unroll boundary: dim_tail is the largest multiple of 4 <= head_dim.
    let dim_chunks = head_dim / 4;
    let dim_tail = dim_chunks * 4;

    for h in 0..num_heads {
        // Each query head reads the KV head for its group.
        let kv_head = h / gqa_groups;
        let q_offset = h * head_dim;
        let k_offset = kv_head * head_dim;

        let mut scores = Vec::with_capacity(seq_len);
        for pos in 0..seq_len {
            let k_vec = &self.kv_caches[layer_idx].keys[pos];
            // 4-wide unrolled dot product
            let mut d0 = 0.0f32;
            let mut d1 = 0.0f32;
            let mut d2 = 0.0f32;
            let mut d3 = 0.0f32;
            for c in 0..dim_chunks {
                let d = c * 4;
                // SAFETY: d + 3 < dim_tail <= head_dim, so
                // q_offset + d + 3 < (h + 1) * head_dim <= hidden
                // == q_rope.len(), and k_offset + d + 3 < kv_dim,
                // the length of each cached key vector pushed above.
                unsafe {
                    d0 += *q_rope.get_unchecked(q_offset + d)
                        * *k_vec.get_unchecked(k_offset + d);
                    d1 += *q_rope.get_unchecked(q_offset + d + 1)
                        * *k_vec.get_unchecked(k_offset + d + 1);
                    d2 += *q_rope.get_unchecked(q_offset + d + 2)
                        * *k_vec.get_unchecked(k_offset + d + 2);
                    d3 += *q_rope.get_unchecked(q_offset + d + 3)
                        * *k_vec.get_unchecked(k_offset + d + 3);
                }
            }
            let mut dot = d0 + d1 + d2 + d3;
            // Scalar tail for head_dim not divisible by 4.
            for d in dim_tail..head_dim {
                dot += q_rope[q_offset + d] * k_vec[k_offset + d];
            }
            scores.push(dot * inv_sqrt_d);
        }

        softmax_inplace(&mut scores);

        // Weighted value accumulation
        let v_offset = kv_head * head_dim;
        for pos in 0..seq_len {
            let v_vec = &self.kv_caches[layer_idx].values[pos];
            let w = scores[pos];
            if w < 1e-10 {
                continue;
            } // Skip negligible weights
            for d in 0..head_dim {
                // SAFETY: same bounds argument as above — indices stay
                // below hidden (attn_out) and kv_dim (v_vec).
                unsafe {
                    *attn_out.get_unchecked_mut(q_offset + d) +=
                        w * *v_vec.get_unchecked(v_offset + d);
                }
            }
        }
    }

    Ok(attn_out)
}
|
|
|
|
/// MLA (Multi-Head Latent Attention) with KV cache.
///
/// Forward path:
/// 1. Q: x → W_q_a → RMSNorm → W_q_b → split(Q_nope, Q_rope) → RoPE(Q_rope)
/// 2. KV: x → W_kv_a → split(c_kv, k_pe) → RoPE(k_pe)
///    K: RMSNorm(c_kv) → W_k_b → K_nope → concat(K_nope, K_rope)
///    V: c_kv → W_v_b → V
/// 3. Standard multi-head attention on concatenated Q/K
///
/// When `use_compressed_kv` is enabled, stores only compressed latents (c_kv + k_pe)
/// instead of full K/V vectors (~17.8x memory reduction), recomputing K_nope and V
/// from cached latents during attention.
fn forward_mla_cached(
    &mut self,
    normed: &[f32],
    layer_idx: usize,
    position: usize,
    config: &BitNetModelConfig,
) -> Result<Vec<f32>> {
    let hidden = config.hidden_size;
    let num_heads = config.num_attention_heads;
    let q_lora_rank = config.q_lora_rank;
    let kv_lora_rank = config.kv_lora_rank;
    let qk_nope_dim = config.qk_nope_head_dim;
    let qk_rope_dim = config.qk_rope_head_dim;
    let v_dim = config.v_head_dim;
    // Per-head query/key width is the nope + rope halves together.
    let q_head_dim = qk_nope_dim + qk_rope_dim;
    // kv_a projects to the compressed latent plus the rope key part.
    let kv_a_out = kv_lora_rank + qk_rope_dim;

    let attn = &self.layers[layer_idx].attention;

    // --- Q path ---
    let q_a = attn
        .q_a
        .as_ref()
        .ok_or_else(|| RuvLLMError::Model("MLA q_a missing".into()))?;
    let mut c_q = self.tl1_gemv(q_a, normed, q_lora_rank, hidden);

    // Optional RMSNorm on the compressed query latent.
    if let Some(ref norm_w) = attn.q_a_norm {
        rms_norm_inplace(&mut c_q, norm_w, 1e-6);
    }

    let q_b = attn
        .q_b
        .as_ref()
        .ok_or_else(|| RuvLLMError::Model("MLA q_b missing".into()))?;
    let q_full = self.tl1_gemv(q_b, &c_q, num_heads * q_head_dim, q_lora_rank);

    // Split Q into nope and rope parts, apply RoPE
    let mut q_nope = vec![0.0f32; num_heads * qk_nope_dim];
    let mut q_rope_part = vec![0.0f32; num_heads * qk_rope_dim];

    for h in 0..num_heads {
        let src = h * q_head_dim;
        let nope_dst = h * qk_nope_dim;
        let rope_dst = h * qk_rope_dim;
        q_nope[nope_dst..nope_dst + qk_nope_dim]
            .copy_from_slice(&q_full[src..src + qk_nope_dim]);
        q_rope_part[rope_dst..rope_dst + qk_rope_dim]
            .copy_from_slice(&q_full[src + qk_nope_dim..src + q_head_dim]);
    }

    self.apply_rope(&mut q_rope_part, num_heads, qk_rope_dim, position);

    // Build full Q by concatenating Q_nope + Q_rope per head
    let mut q_full_concat = vec![0.0f32; num_heads * q_head_dim];
    for h in 0..num_heads {
        let dst = h * q_head_dim;
        let nope_src = h * qk_nope_dim;
        let rope_src = h * qk_rope_dim;
        q_full_concat[dst..dst + qk_nope_dim]
            .copy_from_slice(&q_nope[nope_src..nope_src + qk_nope_dim]);
        q_full_concat[dst + qk_nope_dim..dst + q_head_dim]
            .copy_from_slice(&q_rope_part[rope_src..rope_src + qk_rope_dim]);
    }

    // --- KV path ---
    let kv_a = attn
        .kv_a_mqa
        .as_ref()
        .ok_or_else(|| RuvLLMError::Model("MLA kv_a_mqa missing".into()))?;
    let kv_combined = self.tl1_gemv(kv_a, normed, kv_a_out, hidden);

    // First kv_lora_rank entries are the compressed latent; the rest is
    // the shared (per-token, not per-head) rope key.
    let c_kv_raw = kv_combined[..kv_lora_rank].to_vec();
    let mut k_pe = kv_combined[kv_lora_rank..].to_vec();
    self.apply_rope(&mut k_pe, 1, qk_rope_dim, position);

    // --- Attention dispatch: compressed or full KV cache ---
    if self.use_compressed_kv {
        // COMPRESSED PATH: store only c_kv + k_pe, recompute K/V during attention.
        // ~17.8x memory savings at the cost of per-position recomputation.
        self.mla_caches[layer_idx].push(c_kv_raw.clone(), k_pe.clone());
        let seq_len = self.mla_caches[layer_idx].len();

        let k_b = self.layers[layer_idx]
            .attention
            .k_b
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA k_b missing".into()))?;
        let v_b = self.layers[layer_idx]
            .attention
            .v_b
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA v_b missing".into()))?;

        let inv_sqrt_d = 1.0 / (q_head_dim as f32).sqrt();
        let mut attn_out = vec![0.0f32; num_heads * v_dim];

        for h in 0..num_heads {
            let q_off = h * q_head_dim;

            let mut scores = Vec::with_capacity(seq_len);
            for pos in 0..seq_len {
                // Recompute K for this cached position from compressed latent
                let cached_ckv = &self.mla_caches[layer_idx].c_kv[pos];
                let cached_kpe = &self.mla_caches[layer_idx].k_pe[pos];

                let mut ckv_normed = cached_ckv.clone();
                if let Some(ref norm_w) = self.layers[layer_idx].attention.kv_a_norm {
                    rms_norm_inplace(&mut ckv_normed, norm_w, 1e-6);
                }

                // NOTE(review): this recomputes K for *all* heads per
                // (head, pos) pair — O(num_heads) redundant GEMVs;
                // acceptable trade-off per the memory note above.
                let k_nope =
                    self.tl1_gemv(k_b, &ckv_normed, num_heads * qk_nope_dim, kv_lora_rank);

                // Build K for this head: [K_nope_h | K_rope]
                let nope_off = h * qk_nope_dim;
                let mut dot = 0.0f32;
                // Dot with nope portion
                for d in 0..qk_nope_dim {
                    dot += q_full_concat[q_off + d] * k_nope[nope_off + d];
                }
                // Dot with rope portion (shared across heads)
                for d in 0..qk_rope_dim {
                    dot += q_full_concat[q_off + qk_nope_dim + d] * cached_kpe[d];
                }
                scores.push(dot * inv_sqrt_d);
            }

            softmax_inplace(&mut scores);

            // Weighted value accumulation (recompute V from cached c_kv)
            let v_off = h * v_dim;
            for pos in 0..seq_len {
                let w = scores[pos];
                if w < 1e-10 {
                    continue;
                }

                let cached_ckv = &self.mla_caches[layer_idx].c_kv[pos];
                let v_full = self.tl1_gemv(v_b, cached_ckv, num_heads * v_dim, kv_lora_rank);
                for d in 0..v_dim {
                    attn_out[v_off + d] += w * v_full[h * v_dim + d];
                }
            }
        }

        Ok(attn_out)
    } else {
        // FULL PATH: expand K/V and store in standard KV cache (fast, more memory).
        let mut c_kv_normed = c_kv_raw;
        if let Some(ref norm_w) = self.layers[layer_idx].attention.kv_a_norm {
            rms_norm_inplace(&mut c_kv_normed, norm_w, 1e-6);
        }

        let k_b = self.layers[layer_idx]
            .attention
            .k_b
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA k_b missing".into()))?;
        let k_nope = self.tl1_gemv(k_b, &c_kv_normed, num_heads * qk_nope_dim, kv_lora_rank);

        let v_b = self.layers[layer_idx]
            .attention
            .v_b
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA v_b missing".into()))?;
        // V uses the *unnormalized* latent (c_kv before kv_a_norm).
        let c_kv_for_v = &kv_combined[..kv_lora_rank];
        let v_full = self.tl1_gemv(v_b, c_kv_for_v, num_heads * v_dim, kv_lora_rank);

        // Build full K
        let mut k_full = vec![0.0f32; num_heads * q_head_dim];
        for h in 0..num_heads {
            let dst = h * q_head_dim;
            let nope_src = h * qk_nope_dim;
            k_full[dst..dst + qk_nope_dim]
                .copy_from_slice(&k_nope[nope_src..nope_src + qk_nope_dim]);
            // The rope key is shared: every head gets the same k_pe tail.
            k_full[dst + qk_nope_dim..dst + q_head_dim].copy_from_slice(&k_pe[..qk_rope_dim]);
        }

        // Update KV cache
        self.kv_caches[layer_idx].keys.push(k_full);
        self.kv_caches[layer_idx].values.push(v_full);
        let seq_len = self.kv_caches[layer_idx].len();

        // Multi-head attention
        let inv_sqrt_d = 1.0 / (q_head_dim as f32).sqrt();
        let mut attn_out = vec![0.0f32; num_heads * v_dim];

        for h in 0..num_heads {
            let q_off = h * q_head_dim;

            let mut scores = Vec::with_capacity(seq_len);
            for pos in 0..seq_len {
                let k_vec = &self.kv_caches[layer_idx].keys[pos];
                let k_off = h * q_head_dim;
                let mut dot = 0.0f32;
                for d in 0..q_head_dim {
                    dot += q_full_concat[q_off + d] * k_vec[k_off + d];
                }
                scores.push(dot * inv_sqrt_d);
            }

            softmax_inplace(&mut scores);

            let v_off = h * v_dim;
            for pos in 0..seq_len {
                let v_vec = &self.kv_caches[layer_idx].values[pos];
                let w = scores[pos];
                for d in 0..v_dim {
                    attn_out[v_off + d] += w * v_vec[h * v_dim + d];
                }
            }
        }

        Ok(attn_out)
    }
}
|
|
|
|
    /// Unified FFN forward: dispatches to dense, MoE, or MoE+shared based on layer type.
    ///
    /// For MoE layers, tracks routing decisions in `self.routing_history` to
    /// enable predictive expert prefetching via `ExpertPredictor`.
    ///
    /// # Arguments
    ///
    /// * `normed_ffn` - Post-attention RMS-normed hidden state, length `config.hidden_size`
    /// * `layer_idx` - Index into `self.layers`
    /// * `config` - Model dimensions (hidden size, expert counts, routing top-K)
    ///
    /// # Errors
    ///
    /// Returns an error when a `Dense` layer has no `dense_ffn` weights, or
    /// when routing / expert evaluation fails.
    fn forward_ffn(
        &self,
        normed_ffn: &[f32],
        layer_idx: usize,
        config: &BitNetModelConfig,
    ) -> Result<Vec<f32>> {
        let hidden = config.hidden_size;
        let layer = &self.layers[layer_idx];

        match layer.layer_type {
            LayerType::Dense => {
                // Dense FFN: single gate/up/down, evaluated like a lone expert.
                let ffn = layer.dense_ffn.as_ref().ok_or_else(|| {
                    RuvLLMError::Model(format!("Layer {} is Dense but has no dense_ffn", layer_idx))
                })?;
                self.expert_forward(normed_ffn, ffn, config)
            }
            LayerType::Moe | LayerType::MoeWithShared => {
                // Predictive prefetch: touch predicted expert weight data before routing.
                // This pulls weight cache lines into L2/L3 during the router computation,
                // hiding memory latency for the upcoming expert GEMVs.
                // NOTE: the routing_history mutex is held only for this short block.
                if let Some(ref predictor) = self.expert_predictor {
                    let hist = self.routing_history.lock().unwrap();
                    if let Some(last) = hist.last() {
                        let predicted = predictor.predict_next(last, config.active_experts);
                        let experts = &self.layers[layer_idx].experts;
                        for &eidx in &predicted {
                            if eidx < experts.len() {
                                // Touch first cache line of gate_proj packed data
                                let data = &experts[eidx].gate_proj.packed_data;
                                if !data.is_empty() {
                                    // Volatile read forces the load, acting as software prefetch
                                    unsafe {
                                        std::ptr::read_volatile(data.as_ptr());
                                    }
                                }
                            }
                        }
                    }
                }

                // Route to top-K experts
                let (indices, weights) =
                    self.route_experts(normed_ffn, &self.layers[layer_idx].gate_weight, config)?;

                // Track routing decisions from the first MoE layer for expert prediction.
                // For GLM-4.7-Flash, layer 0 is Dense (first_k_dense_replace=1), so
                // the first MoE layer is at index first_k_dense_replace.
                if layer_idx == config.first_k_dense_replace {
                    let mut hist = self.routing_history.lock().unwrap();
                    hist.push(indices.clone());
                    // Bounded history: drop the oldest entry once the cap is hit.
                    if hist.len() > self.max_routing_history {
                        hist.remove(0);
                    }
                }

                let mut output = vec![0.0f32; hidden];

                // Routed experts: accumulate each selected expert's output,
                // weighted by its (renormalized) router probability.
                let experts = &self.layers[layer_idx].experts;
                for (&eidx, &ew) in indices.iter().zip(weights.iter()) {
                    // Defensive skip for out-of-range indices from the router.
                    if eidx >= experts.len() {
                        continue;
                    }
                    let e_out = self.expert_forward(normed_ffn, &experts[eidx], config)?;
                    for (o, &e) in output.iter_mut().zip(e_out.iter()) {
                        *o += ew * e;
                    }
                }

                // Shared expert (MoeWithShared only): added unweighted on top.
                if layer.layer_type == LayerType::MoeWithShared {
                    if let Some(ref shared) = self.layers[layer_idx].shared_expert {
                        let s_out = self.expert_forward(normed_ffn, shared, config)?;
                        for (o, &s) in output.iter_mut().zip(s_out.iter()) {
                            *o += s;
                        }
                    }
                }

                Ok(output)
            }
        }
    }
|
|
|
|
    /// Forward pass through a single layer WITHOUT KV cache (legacy path).
    ///
    /// Single-position semantics: attention over a length-1 sequence is the
    /// identity, so the attention output reduces to the (head-expanded) V
    /// projection. Q and K are still computed so every projection is
    /// exercised, but their results are discarded.
    ///
    /// Returns the residual-stream output (`input + attn + ffn`), length
    /// `config.hidden_size`.
    fn forward_layer_nocache(
        &self,
        input: &[f32],
        layer_idx: usize,
        config: &BitNetModelConfig,
    ) -> Result<Vec<f32>> {
        let hidden = config.hidden_size;

        // Pre-attention RMSNorm on a copy; `input` is kept for the residual add.
        let mut normed = input.to_vec();
        rms_norm_inplace(&mut normed, &self.layers[layer_idx].input_norm_weight, 1e-6);

        // Attention: single-position (degenerates to V pass-through for GQA)
        let attn_concat = if self.layers[layer_idx].attention.is_mla {
            // MLA single-position: project through full pipeline but attention = identity
            self.forward_mla_single_position(&normed, layer_idx, config)?
        } else {
            // GQA single-position: V expanded to all heads
            let num_heads = config.num_attention_heads;
            let head_dim = hidden / num_heads;
            let kv_dim = config.num_kv_heads * head_dim;
            // Number of query heads sharing one KV head (guard num_kv_heads == 0).
            let gqa_groups = if config.num_kv_heads > 0 {
                num_heads / config.num_kv_heads
            } else {
                1
            };

            let q = self.tl1_gemv(
                &self.layers[layer_idx].attention.q_proj,
                &normed,
                hidden,
                hidden,
            );
            let k = self.tl1_gemv(
                &self.layers[layer_idx].attention.k_proj,
                &normed,
                kv_dim,
                hidden,
            );
            let v = self.tl1_gemv(
                &self.layers[layer_idx].attention.v_proj,
                &normed,
                kv_dim,
                hidden,
            );
            let _ = (q, k); // Exercise projections

            // Replicate each KV head's V across its group of query heads.
            let mut concat = vec![0.0f32; hidden];
            for h in 0..num_heads {
                let kv_head = h / gqa_groups;
                for d in 0..head_dim {
                    concat[h * head_dim + d] = v[kv_head * head_dim + d];
                }
            }
            concat
        };

        // Output projection, then first residual connection.
        let o_out = self.tl1_gemv(
            &self.layers[layer_idx].attention.o_proj,
            &attn_concat,
            hidden,
            hidden,
        );
        let mut residual: Vec<f32> = input.iter().zip(o_out.iter()).map(|(r, a)| r + a).collect();

        // Pre-FFN RMSNorm on a copy; `residual` is kept for the second add.
        let mut normed_ffn = residual.clone();
        rms_norm_inplace(
            &mut normed_ffn,
            &self.layers[layer_idx].post_attn_norm_weight,
            1e-6,
        );

        let ffn_out = self.forward_ffn(&normed_ffn, layer_idx, config)?;

        // Second residual connection.
        for (r, &f) in residual.iter_mut().zip(ffn_out.iter()) {
            *r += f;
        }

        Ok(residual)
    }
|
|
|
|
    /// MLA forward for single-position (no KV cache). Used in legacy forward path.
    ///
    /// With a sequence length of 1, softmax attention degenerates to the
    /// identity, so this only runs the projection pipeline and returns the
    /// decompressed V (`num_heads * v_head_dim`). The Q path is evaluated
    /// purely to exercise those weights; its result is discarded.
    ///
    /// # Errors
    ///
    /// Returns an error if the required MLA tensors (`kv_a_mqa`, `v_b`)
    /// are missing for this layer.
    fn forward_mla_single_position(
        &self,
        normed: &[f32],
        layer_idx: usize,
        config: &BitNetModelConfig,
    ) -> Result<Vec<f32>> {
        let hidden = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let q_lora_rank = config.q_lora_rank;
        let kv_lora_rank = config.kv_lora_rank;
        let v_dim = config.v_head_dim;
        // kv_a_mqa emits the compressed latent plus the RoPE key portion.
        let kv_a_out = kv_lora_rank + config.qk_rope_head_dim;

        let attn = &self.layers[layer_idx].attention;

        // Q path (exercise projections)
        if let Some(ref q_a) = attn.q_a {
            let mut c_q = self.tl1_gemv(q_a, normed, q_lora_rank, hidden);
            if let Some(ref norm_w) = attn.q_a_norm {
                rms_norm_inplace(&mut c_q, norm_w, 1e-6);
            }
            if let Some(ref q_b) = attn.q_b {
                let _q = self.tl1_gemv(
                    q_b,
                    &c_q,
                    num_heads * (config.qk_nope_head_dim + config.qk_rope_head_dim),
                    q_lora_rank,
                );
            }
        }

        // KV path: down-project to [kv_lora_rank | qk_rope_head_dim].
        let kv_a = self.layers[layer_idx]
            .attention
            .kv_a_mqa
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA kv_a_mqa missing in nocache path".into()))?;
        let kv_combined = self.tl1_gemv(kv_a, normed, kv_a_out, hidden);
        // Compressed latent only; the RoPE tail is unused in this path.
        let c_kv = &kv_combined[..kv_lora_rank];

        // V = c_kv @ W_v_b
        let v_b = self.layers[layer_idx]
            .attention
            .v_b
            .as_ref()
            .ok_or_else(|| RuvLLMError::Model("MLA v_b missing".into()))?;
        let v_full = self.tl1_gemv(v_b, c_kv, num_heads * v_dim, kv_lora_rank);

        // Single position: attention is identity, output = V directly
        Ok(v_full)
    }
|
|
|
|
/// Apply Rotary Position Embedding (RoPE) in-place.
|
|
///
|
|
/// For each head, rotates pairs of dimensions (2i, 2i+1) by position-dependent angles.
|
|
fn apply_rope(&self, x: &mut [f32], num_heads: usize, head_dim: usize, position: usize) {
|
|
let half = head_dim / 2;
|
|
let max_seq = self.rope_cos.len() / half;
|
|
if position >= max_seq {
|
|
return; // Beyond pre-computed tables — skip RoPE
|
|
}
|
|
let cos_base = position * half;
|
|
for h in 0..num_heads {
|
|
let offset = h * head_dim;
|
|
for i in 0..half {
|
|
let cos_val = self.rope_cos[cos_base + i];
|
|
let sin_val = self.rope_sin[cos_base + i];
|
|
let x0 = x[offset + 2 * i];
|
|
let x1 = x[offset + 2 * i + 1];
|
|
x[offset + 2 * i] = x0 * cos_val - x1 * sin_val;
|
|
x[offset + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
|
}
|
|
}
|
|
}
|
|
|
|
// ========================================================================
|
|
// MoE Router
|
|
// ========================================================================
|
|
|
|
/// Route hidden states to the top-K experts.
|
|
///
|
|
/// Computes `scores = hidden_states @ gate_weight^T`, applies softmax,
|
|
/// then selects the top-K experts with highest scores.
|
|
///
|
|
/// # Returns
|
|
///
|
|
/// Tuple of (expert_indices, expert_weights) both of length active_experts.
|
|
fn route_experts(
|
|
&self,
|
|
hidden_states: &[f32],
|
|
gate_weight: &[f32],
|
|
config: &BitNetModelConfig,
|
|
) -> Result<(Vec<usize>, Vec<f32>)> {
|
|
let num_experts = config.num_experts;
|
|
let hidden = config.hidden_size;
|
|
// Clamp top_k to num_experts to prevent selecting more experts than exist
|
|
let top_k = config.active_experts.min(num_experts);
|
|
|
|
if num_experts == 0 {
|
|
return Ok((vec![], vec![]));
|
|
}
|
|
|
|
// Gate: scores[e] = dot(hidden_states, gate_weight[e])
|
|
let mut scores = vec![0.0f32; num_experts];
|
|
for e in 0..num_experts {
|
|
let row_start = e * hidden;
|
|
if row_start + hidden > gate_weight.len() {
|
|
break;
|
|
}
|
|
let mut dot = 0.0f32;
|
|
for j in 0..hidden {
|
|
dot += hidden_states[j] * gate_weight[row_start + j];
|
|
}
|
|
scores[e] = dot;
|
|
}
|
|
|
|
// Softmax over expert scores
|
|
softmax_inplace(&mut scores);
|
|
|
|
// Top-K selection
|
|
let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
|
|
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
let selected: Vec<(usize, f32)> = indexed.into_iter().take(top_k).collect();
|
|
|
|
// Renormalize selected weights so they sum to 1
|
|
let weight_sum: f32 = selected.iter().map(|(_, w)| w).sum();
|
|
let norm_factor = if weight_sum > 1e-12 {
|
|
1.0 / weight_sum
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
let expert_indices: Vec<usize> = selected.iter().map(|(i, _)| *i).collect();
|
|
let expert_weights: Vec<f32> = selected.iter().map(|(_, w)| w * norm_factor).collect();
|
|
|
|
Ok((expert_indices, expert_weights))
|
|
}
|
|
|
|
// ========================================================================
|
|
// Expert FFN (TL1 GEMV)
|
|
// ========================================================================
|
|
|
|
/// Forward pass through a single expert's SwiGLU FFN.
|
|
///
|
|
/// Fused implementation: gate and up projections are computed, then
|
|
/// SiLU(gate) * up is fused in a single pass to halve memory traffic.
|
|
///
|
|
/// Computes:
|
|
/// ```text
|
|
/// gate = TL1_GEMV(gate_proj, input)
|
|
/// up = TL1_GEMV(up_proj, input)
|
|
/// hidden = silu(gate) * up [FUSED: single pass]
|
|
/// output = TL1_GEMV(down_proj, hidden)
|
|
/// ```
|
|
fn expert_forward(
|
|
&self,
|
|
input: &[f32],
|
|
expert: &ExpertWeights,
|
|
config: &BitNetModelConfig,
|
|
) -> Result<Vec<f32>> {
|
|
let intermediate = config.intermediate_size;
|
|
let hidden = config.hidden_size;
|
|
|
|
// gate_proj and up_proj GEMVs
|
|
let gate_out = self.tl1_gemv(&expert.gate_proj, input, intermediate, hidden);
|
|
let up_out = self.tl1_gemv(&expert.up_proj, input, intermediate, hidden);
|
|
|
|
// Fused SiLU(gate) * up — single pass with 4-wide unroll
|
|
let mut fused = vec![0.0f32; intermediate];
|
|
let chunks = intermediate / 4;
|
|
let remainder = intermediate % 4;
|
|
|
|
// Unrolled 4-wide loop — keeps gate/up values in registers
|
|
for c in 0..chunks {
|
|
let base = c * 4;
|
|
unsafe {
|
|
let g0 = *gate_out.get_unchecked(base);
|
|
let g1 = *gate_out.get_unchecked(base + 1);
|
|
let g2 = *gate_out.get_unchecked(base + 2);
|
|
let g3 = *gate_out.get_unchecked(base + 3);
|
|
let u0 = *up_out.get_unchecked(base);
|
|
let u1 = *up_out.get_unchecked(base + 1);
|
|
let u2 = *up_out.get_unchecked(base + 2);
|
|
let u3 = *up_out.get_unchecked(base + 3);
|
|
*fused.get_unchecked_mut(base) = g0 * sigmoid(g0) * u0;
|
|
*fused.get_unchecked_mut(base + 1) = g1 * sigmoid(g1) * u1;
|
|
*fused.get_unchecked_mut(base + 2) = g2 * sigmoid(g2) * u2;
|
|
*fused.get_unchecked_mut(base + 3) = g3 * sigmoid(g3) * u3;
|
|
}
|
|
}
|
|
let tail_start = chunks * 4;
|
|
for i in 0..remainder {
|
|
let idx = tail_start + i;
|
|
fused[idx] = gate_out[idx] * sigmoid(gate_out[idx]) * up_out[idx];
|
|
}
|
|
|
|
// down_proj
|
|
let output = self.tl1_gemv(&expert.down_proj, &fused, hidden, intermediate);
|
|
|
|
Ok(output)
|
|
}
|
|
|
|
/// TL1 GEMV: ternary matrix-vector product with automatic SIMD dispatch.
|
|
///
|
|
/// Delegates to AVX2 kernel on x86_64 (16 elements/iter via vpshufb LUT +
|
|
/// INT16 madd), with scalar LUT fallback on other architectures.
|
|
///
|
|
/// Computes `output[i] = sum_j(ternary_weight[i,j] * input[j]) * scale[block]`
|
|
#[inline]
|
|
fn tl1_gemv(
|
|
&self,
|
|
weight: &TernaryTensor,
|
|
input: &[f32],
|
|
out_rows: usize,
|
|
in_cols: usize,
|
|
) -> Vec<f32> {
|
|
let mut output = vec![0.0f32; out_rows];
|
|
if out_rows == 0 || in_cols == 0 || weight.packed_data.is_empty() {
|
|
return output;
|
|
}
|
|
Self::tl1_gemv_dispatch(
|
|
&self.tl1_lut,
|
|
&weight.packed_data,
|
|
&weight.scales,
|
|
input,
|
|
&mut output,
|
|
out_rows,
|
|
in_cols,
|
|
weight.block_size,
|
|
);
|
|
output
|
|
}
|
|
|
|
/// TL1 GEMV into a pre-allocated output buffer (zero-alloc hot path).
|
|
///
|
|
/// The caller must ensure `output.len() >= out_rows`.
|
|
#[inline]
|
|
fn tl1_gemv_into(
|
|
&self,
|
|
weight: &TernaryTensor,
|
|
input: &[f32],
|
|
output: &mut [f32],
|
|
out_rows: usize,
|
|
in_cols: usize,
|
|
) {
|
|
for v in output[..out_rows].iter_mut() {
|
|
*v = 0.0;
|
|
}
|
|
if out_rows == 0 || in_cols == 0 || weight.packed_data.is_empty() {
|
|
return;
|
|
}
|
|
Self::tl1_gemv_dispatch(
|
|
&self.tl1_lut,
|
|
&weight.packed_data,
|
|
&weight.scales,
|
|
input,
|
|
&mut output[..out_rows],
|
|
out_rows,
|
|
in_cols,
|
|
weight.block_size,
|
|
);
|
|
}
|
|
|
|
/// Dispatch TL1 GEMV to AVX2 SIMD when available, otherwise scalar LUT path.
|
|
#[inline]
|
|
fn tl1_gemv_dispatch(
|
|
lut: &[[i8; 4]; 256],
|
|
packed_data: &[u8],
|
|
scales: &[f32],
|
|
input: &[f32],
|
|
output: &mut [f32],
|
|
out_rows: usize,
|
|
in_cols: usize,
|
|
block_size: usize,
|
|
) {
|
|
// AVX2 SIMD path (compile-time gate + runtime dispatch inside tl1_avx2)
|
|
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
|
|
{
|
|
super::tl1_avx2::tl1_gemv(
|
|
packed_data,
|
|
scales,
|
|
input,
|
|
output,
|
|
out_rows,
|
|
in_cols,
|
|
block_size,
|
|
);
|
|
return;
|
|
}
|
|
|
|
// Scalar LUT fallback for non-AVX2 platforms
|
|
#[allow(unreachable_code)]
|
|
{
|
|
let bytes_per_row = (in_cols + 3) / 4;
|
|
let blocks_per_row = (in_cols + block_size - 1) / block_size;
|
|
|
|
for row in 0..out_rows {
|
|
let row_byte_offset = row * bytes_per_row;
|
|
let row_scale_offset = row * blocks_per_row;
|
|
let mut accum = 0.0f32;
|
|
|
|
for blk in 0..blocks_per_row {
|
|
let scale = scales.get(row_scale_offset + blk).copied().unwrap_or(1.0);
|
|
|
|
let blk_start = blk * block_size;
|
|
let blk_end = (blk_start + block_size).min(in_cols);
|
|
let mut block_accum = 0.0f32;
|
|
let mut c = blk_start;
|
|
|
|
// Process 4 elements at a time via LUT
|
|
while c + 4 <= blk_end {
|
|
let byte_idx = row_byte_offset + c / 4;
|
|
if byte_idx >= packed_data.len() {
|
|
break;
|
|
}
|
|
let ternary = &lut[packed_data[byte_idx] as usize];
|
|
for k in 0..4 {
|
|
let t = ternary[k];
|
|
if t == 1 {
|
|
block_accum += input[c + k];
|
|
} else if t == -1 {
|
|
block_accum -= input[c + k];
|
|
}
|
|
}
|
|
c += 4;
|
|
}
|
|
|
|
// Handle tail
|
|
while c < blk_end {
|
|
let byte_idx = row_byte_offset + c / 4;
|
|
let bit_pos = c % 4;
|
|
if byte_idx < packed_data.len() {
|
|
let t = lut[packed_data[byte_idx] as usize][bit_pos];
|
|
if t == 1 {
|
|
block_accum += input[c];
|
|
} else if t == -1 {
|
|
block_accum -= input[c];
|
|
}
|
|
}
|
|
c += 1;
|
|
}
|
|
|
|
accum += block_accum * scale;
|
|
}
|
|
|
|
output[row] += accum;
|
|
}
|
|
}
|
|
}
|
|
|
|
// ========================================================================
|
|
// Tensor Discovery & Model Validation
|
|
// ========================================================================
|
|
|
|
    /// Discover and classify all tensors in a GGUF file.
    ///
    /// Opens the file via mmap and returns a structured report of found
    /// tensors, grouped by role (embedding, attention, FFN/expert, norm,
    /// other), with shape and quantization info. Observations about naming
    /// conventions, MLA tensors, and stacked (3D) expert tensors are
    /// recorded in the report's `warnings`.
    ///
    /// # Errors
    ///
    /// Returns an error if the GGUF file cannot be opened or parsed.
    pub fn discover_tensors(path: &str) -> Result<TensorDiscoveryReport> {
        let gguf = GgufFile::open_mmap(Path::new(path))?;
        let mut report = TensorDiscoveryReport {
            total_tensors: gguf.tensors.len(),
            total_bytes: gguf.total_tensor_size(),
            architecture: gguf.architecture().map(|s| s.to_string()),
            tensor_groups: Vec::new(),
            warnings: Vec::new(),
        };

        // Classify tensors by name substring. Order matters: the first
        // matching category wins (e.g. a name containing both "attn" and
        // "norm" lands in Attention because that check runs first).
        let mut embedding = Vec::new();
        let mut attention = Vec::new();
        let mut ffn = Vec::new();
        let mut norm = Vec::new();
        let mut other = Vec::new();

        for t in &gguf.tensors {
            let info = TensorEntry {
                name: t.name.clone(),
                shape: t.shape.clone(),
                dtype: t.dtype.name().to_string(),
                bytes: t.byte_size(),
            };

            if t.name.contains("embd") || t.name.contains("embed") || t.name == "output.weight" {
                embedding.push(info);
            } else if t.name.contains("attn") || t.name.contains("self_attn") {
                attention.push(info);
            } else if t.name.contains("ffn") || t.name.contains("mlp") || t.name.contains("expert")
            {
                ffn.push(info);
            } else if t.name.contains("norm") {
                norm.push(info);
            } else {
                other.push(info);
            }
        }

        // Emit only non-empty groups, in a fixed presentation order.
        if !embedding.is_empty() {
            report.tensor_groups.push(TensorGroup {
                name: "Embedding/Output".into(),
                tensors: embedding,
            });
        }
        if !norm.is_empty() {
            report.tensor_groups.push(TensorGroup {
                name: "Normalization".into(),
                tensors: norm,
            });
        }
        if !attention.is_empty() {
            report.tensor_groups.push(TensorGroup {
                name: "Attention".into(),
                tensors: attention,
            });
        }
        if !ffn.is_empty() {
            report.tensor_groups.push(TensorGroup {
                name: "FFN/Expert".into(),
                tensors: ffn,
            });
        }
        if !other.is_empty() {
            report.tensor_groups.push(TensorGroup {
                name: "Other".into(),
                tensors: other,
            });
        }

        // Detect mixed naming prefixes ("blk.*" vs "model.*") in one file.
        let has_blk = gguf.tensors.iter().any(|t| t.name.starts_with("blk."));
        let has_model = gguf.tensors.iter().any(|t| t.name.starts_with("model."));
        if has_blk && has_model {
            report
                .warnings
                .push("Mixed naming conventions detected (blk.* and model.*)".into());
        }

        // Detect MLA: the low-rank Q projection "attn_q_a" marks MLA models.
        let has_mla = gguf.tensors.iter().any(|t| t.name.contains("attn_q_a"));
        if has_mla {
            report
                .warnings
                .push("MLA (Multi-Head Latent Attention) tensors detected".into());
        }

        // Detect stacked experts: "_exps" tensors fuse all experts into one
        // 3D tensor instead of one tensor per expert.
        let has_exps = gguf.tensors.iter().any(|t| t.name.contains("_exps"));
        if has_exps {
            report
                .warnings
                .push("Stacked expert tensors detected (3D format)".into());
        }

        Ok(report)
    }
|
|
|
|
    /// Validate that a GGUF file has all required tensors for loading.
    ///
    /// Resolves each required tensor against its known naming variants (via
    /// `TensorNameMapper`) and returns which were found and which are
    /// missing, plus a one-line summary of the extracted configuration.
    /// `can_load` is true only when nothing required is missing.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened/parsed or the model
    /// configuration cannot be extracted from its metadata.
    pub fn validate_model(path: &str) -> Result<ModelValidation> {
        let gguf = GgufFile::open_mmap(Path::new(path))?;
        let backend = BitNetBackend::new();
        let config = backend.extract_config(&gguf)?;
        let mut missing = Vec::new();
        let mut found = Vec::new();

        // Check global tensors
        for (label, candidates) in [
            ("Embedding", TensorNameMapper::embedding()),
            ("Output/LM Head", TensorNameMapper::output()),
            ("Final Norm", TensorNameMapper::final_norm()),
        ] {
            if let Some(name) = TensorNameMapper::resolve(&gguf, &candidates) {
                found.push(format!("{}: {}", label, name));
            } else {
                missing.push(format!("{} (tried: {})", label, candidates.join(", ")));
            }
        }

        // Check first layer tensors to determine structure
        let idx = 0;
        for (label, candidates) in [
            ("Layer 0 Input Norm", TensorNameMapper::input_norm(idx)),
            (
                "Layer 0 Post-Attn Norm",
                TensorNameMapper::post_attn_norm(idx),
            ),
        ] {
            if let Some(name) = TensorNameMapper::resolve(&gguf, &candidates) {
                found.push(format!("{}: {}", label, name));
            } else {
                missing.push(format!("{} (tried: {})", label, candidates.join(", ")));
            }
        }

        // Check attention type: MLA requires the full low-rank projection set;
        // otherwise the model is treated as standard GQA.
        if TensorNameMapper::has_mla(&gguf, 0) {
            found.push("Attention type: MLA".into());
            for (label, candidates) in [
                ("Layer 0 attn_q_a", TensorNameMapper::attn_q_a(0)),
                ("Layer 0 attn_q_b", TensorNameMapper::attn_q_b(0)),
                ("Layer 0 attn_kv_a_mqa", TensorNameMapper::attn_kv_a_mqa(0)),
                ("Layer 0 attn_k_b", TensorNameMapper::attn_k_b(0)),
                ("Layer 0 attn_v_b", TensorNameMapper::attn_v_b(0)),
                ("Layer 0 attn_output", TensorNameMapper::attn_output(0)),
            ] {
                if TensorNameMapper::resolve(&gguf, &candidates).is_some() {
                    found.push(format!("  {}: present", label));
                } else {
                    missing.push(format!("{} (tried: {})", label, candidates.join(", ")));
                }
            }
        } else {
            found.push("Attention type: GQA".into());
        }

        // Check FFN structure: layers below first_k_dense_replace are dense,
        // the rest are MoE (stacked 3D experts or per-expert tensors).
        let check_layer = config.first_k_dense_replace.min(config.num_layers);
        if check_layer > 0 {
            if TensorNameMapper::has_dense_ffn(&gguf, 0) {
                found.push("Layer 0: Dense FFN".into());
            } else {
                missing.push("Layer 0 dense FFN tensors".into());
            }
        }
        if config.num_layers > config.first_k_dense_replace {
            let moe_layer = config.first_k_dense_replace;
            if TensorNameMapper::has_stacked_experts(&gguf, moe_layer) {
                found.push(format!("Layer {}: Stacked MoE experts", moe_layer));
            } else if TensorNameMapper::resolve(&gguf, &TensorNameMapper::expert_gate(moe_layer, 0))
                .is_some()
            {
                found.push(format!("Layer {}: Individual MoE experts", moe_layer));
            } else {
                missing.push(format!("Layer {} MoE expert tensors", moe_layer));
            }
        }

        let can_load = missing.is_empty();
        Ok(ModelValidation {
            can_load,
            config_summary: format!(
                "layers={}, hidden={}, heads={}, experts={}, vocab={}, mla={}",
                config.num_layers,
                config.hidden_size,
                config.num_attention_heads,
                config.num_experts,
                config.vocab_size,
                config.use_mla
            ),
            found,
            missing,
        })
    }
|
|
|
|
/// Greedy-decode a single next token from logits.
|
|
fn argmax(logits: &[f32]) -> u32 {
|
|
let mut best_idx = 0u32;
|
|
let mut best_val = f32::NEG_INFINITY;
|
|
for (i, &v) in logits.iter().enumerate() {
|
|
if v > best_val {
|
|
best_val = v;
|
|
best_idx = i as u32;
|
|
}
|
|
}
|
|
best_idx
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tensor Discovery & Validation Report Types
|
|
// ============================================================================
|
|
|
|
/// Report from tensor discovery on a GGUF file.
///
/// Produced by `discover_tensors`: every tensor in the file is classified
/// by role and collected into `tensor_groups`, with notable observations
/// (naming conventions, MLA, stacked experts) recorded as `warnings`.
#[derive(Debug)]
pub struct TensorDiscoveryReport {
    /// Total number of tensors in the file.
    pub total_tensors: usize,
    /// Total bytes across all tensors.
    pub total_bytes: usize,
    /// Architecture string from GGUF metadata, if present.
    pub architecture: Option<String>,
    /// Grouped tensor listings (only non-empty groups are included).
    pub tensor_groups: Vec<TensorGroup>,
    /// Warnings or observations about the file's tensor layout.
    pub warnings: Vec<String>,
}
|
|
|
|
/// A group of related tensors within a [`TensorDiscoveryReport`].
#[derive(Debug)]
pub struct TensorGroup {
    /// Group name (e.g., "Attention", "FFN/Expert")
    pub name: String,
    /// Tensors classified into this group.
    pub tensors: Vec<TensorEntry>,
}
|
|
|
|
/// Info about a single tensor, as recorded during discovery.
#[derive(Debug)]
pub struct TensorEntry {
    /// Tensor name as it appears in the GGUF file.
    pub name: String,
    /// Shape dimensions.
    pub shape: Vec<usize>,
    /// Quantization type name (e.g. the GGUF dtype's display name).
    pub dtype: String,
    /// Size of the tensor data in bytes.
    pub bytes: usize,
}
|
|
|
|
/// Result of model validation against the expected tensor layout.
///
/// Produced by `validate_model`; `can_load` is true exactly when
/// `missing` is empty.
#[derive(Debug)]
pub struct ModelValidation {
    /// Whether all required tensors were found.
    pub can_load: bool,
    /// One-line summary of the detected configuration.
    pub config_summary: String,
    /// Human-readable descriptions of tensors/features that were found.
    pub found: Vec<String>,
    /// Human-readable descriptions of required tensors that are missing.
    pub missing: Vec<String>,
}
|
|
|
|
// ============================================================================
|
|
// Generation Statistics
|
|
// ============================================================================
|
|
|
|
/// Statistics from a streaming generation run.
#[derive(Debug, Clone)]
pub struct GenerationStats {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub generated_tokens: usize,
    /// Total tokens processed (prompt + generated).
    pub total_tokens: usize,
    /// Wall-clock time for generation (excluding prefill) in milliseconds.
    pub elapsed_ms: u64,
    /// Tokens per second (generated tokens / elapsed time).
    pub tokens_per_second: f64,
}
|
|
|
|
// ============================================================================
|
|
// Predictive Expert Prefetcher
|
|
// ============================================================================
|
|
|
|
/// Predicts which experts will be needed next based on routing history.
///
/// Maintains a transition matrix `P[i][j]` estimating the probability that
/// expert `j` is selected at position `t+1` given expert `i` at position `t`.
/// Add-one (Laplace) smoothing keeps unseen transitions at a small nonzero
/// probability.
///
/// # Usage
///
/// ```rust,ignore
/// // Build from routing history (one entry per token position)
/// let history = vec![vec![2, 5], vec![5, 3], vec![2, 7]]; // top-K per position
/// let predictor = ExpertPredictor::from_history(64, &history);
///
/// // Predict next experts given current selection
/// let current = vec![2, 5];
/// let predicted = predictor.predict_next(&current, 4);
/// // predicted might be [3, 7, 5, 2] — likely next experts
/// ```
pub struct ExpertPredictor {
    /// Number of experts covered by the transition matrix.
    num_experts: usize,
    /// transition_counts[from][to] = number of observed from→to transitions.
    transition_counts: Vec<Vec<u32>>,
    /// Per-`from` totals (row sums of `transition_counts`).
    row_totals: Vec<u32>,
}

impl ExpertPredictor {
    /// Build a predictor from routing history.
    ///
    /// `routing_history` is a sequence of expert selections, where each entry
    /// contains the expert IDs selected at that position (top-K). Every
    /// expert active at position t is counted as transitioning to every
    /// expert active at position t+1; IDs >= `num_experts` are ignored.
    pub fn from_history(num_experts: usize, routing_history: &[Vec<usize>]) -> Self {
        let mut counts = vec![vec![0u32; num_experts]; num_experts];
        let mut totals = vec![0u32; num_experts];

        for pair in routing_history.windows(2) {
            let (cur, nxt) = (&pair[0], &pair[1]);
            for &src in cur.iter().filter(|&&e| e < num_experts) {
                for &dst in nxt.iter().filter(|&&e| e < num_experts) {
                    counts[src][dst] += 1;
                    totals[src] += 1;
                }
            }
        }

        Self {
            num_experts,
            transition_counts: counts,
            row_totals: totals,
        }
    }

    /// Predict the most likely next experts given the current selection.
    ///
    /// Returns up to `top_k` expert IDs ranked by predicted probability,
    /// aggregated over all currently-active experts. The active experts
    /// themselves are excluded from the ranking (they are already loaded).
    pub fn predict_next(&self, current_experts: &[usize], top_k: usize) -> Vec<usize> {
        let n = self.num_experts;
        let mut scores = vec![0.0f32; n];

        for &src in current_experts.iter().filter(|&&e| e < n) {
            let denom = self.row_totals[src] as f32 + n as f32; // Laplace denominator
            for (dst, s) in scores.iter_mut().enumerate() {
                // Add-one smoothed transition probability.
                *s += (self.transition_counts[src][dst] as f32 + 1.0) / denom;
            }
        }

        // Currently-active experts are already resident — force them last.
        for &active in current_experts {
            if let Some(s) = scores.get_mut(active) {
                *s = 0.0;
            }
        }

        // Top-K by score (stable sort: lower index wins ties).
        let mut ranked: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.into_iter().take(top_k).map(|(id, _)| id).collect()
    }

    /// Get the transition probability from expert `from` to expert `to`.
    ///
    /// Returns a Laplace-smoothed probability in (0, 1), or 0.0 when either
    /// index is out of range.
    pub fn transition_prob(&self, from: usize, to: usize) -> f32 {
        if from >= self.num_experts || to >= self.num_experts {
            return 0.0;
        }
        let denom = self.row_totals[from] as f32 + self.num_experts as f32;
        (self.transition_counts[from][to] as f32 + 1.0) / denom
    }

    /// Return the number of experts this predictor covers.
    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    /// Total number of observed transitions.
    pub fn total_observations(&self) -> u64 {
        self.row_totals.iter().map(|&r| u64::from(r)).sum()
    }
}
|
|
|
|
// ============================================================================
|
|
// Compressed MLA KV Cache
|
|
// ============================================================================
|
|
|
|
/// Compressed KV cache for MLA (Multi-Head Latent Attention) layers.
///
/// Instead of storing the full decompressed K and V vectors (which are
/// `num_heads * (qk_nope_head_dim + qk_rope_head_dim)` and
/// `num_heads * v_head_dim` per position), this cache stores the
/// compressed latent representation:
///
/// - `c_kv`: The compressed KV latent, size `kv_lora_rank` per position
/// - `k_pe`: The RoPE-applied key portion, size `qk_rope_head_dim` per position
///
/// Total per position: `kv_lora_rank + qk_rope_head_dim` (e.g., 512 + 64 = 576)
/// vs full KV: `num_heads * (qk_nope_head_dim + qk_rope_head_dim) + num_heads * v_head_dim`
/// (e.g., 20 * 256 + 20 * 256 = 10240)
///
/// This gives a **17.8x memory reduction** for GLM-4.7-Flash at the cost of
/// recomputing K_nope and V from the compressed latent during attention.
#[derive(Debug, Clone, Default)]
pub struct CompressedMlaCache {
    /// Compressed KV latents: one [kv_lora_rank] vector per position
    c_kv: Vec<Vec<f32>>,
    /// RoPE-applied key portion: one [qk_rope_head_dim] vector per position
    k_pe: Vec<Vec<f32>>,
}

impl CompressedMlaCache {
    /// Create a new empty compressed cache.
    ///
    /// Equivalent to `CompressedMlaCache::default()`; `Default` is derived
    /// so the type composes with `#[derive(Default)]` containers and
    /// satisfies clippy's `new_without_default` lint.
    pub fn new() -> Self {
        Self::default()
    }

    /// Push a new position's compressed KV data.
    ///
    /// `c_kv` is the `[kv_lora_rank]` latent and `k_pe` the
    /// `[qk_rope_head_dim]` RoPE key portion for the new position.
    pub fn push(&mut self, c_kv: Vec<f32>, k_pe: Vec<f32>) {
        self.c_kv.push(c_kv);
        self.k_pe.push(k_pe);
    }

    /// Number of cached positions.
    pub fn len(&self) -> usize {
        self.c_kv.len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.c_kv.is_empty()
    }

    /// Clear the cache.
    pub fn clear(&mut self) {
        self.c_kv.clear();
        self.k_pe.clear();
    }

    /// Memory usage in bytes (f32 payload only, excluding Vec overhead).
    pub fn memory_bytes(&self) -> usize {
        let f32_size = std::mem::size_of::<f32>();
        let c_kv_bytes: usize = self.c_kv.iter().map(|v| v.len() * f32_size).sum();
        let k_pe_bytes: usize = self.k_pe.iter().map(|v| v.len() * f32_size).sum();
        c_kv_bytes + k_pe_bytes
    }

    /// Compute the memory savings ratio vs full KV cache.
    ///
    /// Returns the ratio of full cache size to compressed cache size
    /// (e.g., 17.8 means the compressed cache is 17.8x smaller), or 0.0
    /// when the compressed per-position size would be zero.
    pub fn savings_ratio(
        num_heads: usize,
        qk_nope_head_dim: usize,
        qk_rope_head_dim: usize,
        v_head_dim: usize,
        kv_lora_rank: usize,
    ) -> f32 {
        let full_k_dim = num_heads * (qk_nope_head_dim + qk_rope_head_dim);
        let full_v_dim = num_heads * v_head_dim;
        let full_per_pos = (full_k_dim + full_v_dim) as f32;
        let compressed_per_pos = (kv_lora_rank + qk_rope_head_dim) as f32;
        if compressed_per_pos > 0.0 {
            full_per_pos / compressed_per_pos
        } else {
            0.0
        }
    }
}
|
|
|
|
// ============================================================================
|
|
// LlmBackend Trait Implementation
|
|
// ============================================================================
|
|
|
|
// ============================================================================
|
|
// Tokenizer trait bridge
|
|
// ============================================================================
|
|
|
|
/// Wraps our BpeTokenizer to implement the crate-level Tokenizer trait.
///
/// Thin adapter: borrows the backend's tokenizer for lifetime `'a` and
/// forwards every trait call to it.
struct TokenizerBridge<'a> {
    // Borrowed BPE tokenizer that does the actual encode/decode work.
    inner: &'a BpeTokenizer,
}
|
|
|
|
impl<'a> BackendTokenizer for TokenizerBridge<'a> {
    /// Encode text to token IDs. The inner BPE encoder returns a plain
    /// `Vec<u32>`, so this adapter always succeeds.
    fn encode(&self, text: &str) -> Result<Vec<u32>> {
        Ok(self.inner.encode(text))
    }

    /// Decode token IDs back to text; always succeeds for the same reason.
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        Ok(self.inner.decode(tokens))
    }

    /// Total vocabulary size of the wrapped tokenizer.
    fn vocab_size(&self) -> usize {
        self.inner.vocab_size()
    }

    /// Special-token IDs exposed to the backend-agnostic API.
    ///
    /// NOTE(review): BOS=1 / EOS=2 are hard-coded here rather than read from
    /// the wrapped tokenizer — confirm they match the loaded vocabulary.
    fn special_tokens(&self) -> BackendSpecialTokens {
        BackendSpecialTokens {
            bos_token_id: Some(1),
            eos_token_id: Some(2),
            ..Default::default()
        }
    }
}
|
|
|
|
impl LlmBackend for BitNetBackend {
|
|
fn load_model(&mut self, model_id: &str, _config: ModelConfig) -> Result<()> {
|
|
self.load_gguf(model_id)
|
|
}
|
|
|
|
fn generate(&self, prompt: &str, params: GenerateParams) -> Result<String> {
|
|
if !self.loaded {
|
|
return Err(RuvLLMError::Model("No model loaded".to_string()));
|
|
}
|
|
|
|
let tokenizer = self
|
|
.tok
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No tokenizer loaded".to_string()))?;
|
|
|
|
// Encode prompt via tokenizer
|
|
let prompt_tokens = tokenizer.encode(prompt);
|
|
let eos_id = 2u32;
|
|
|
|
// Autoregressive generation using forward_token with KV cache.
|
|
// Since generate() takes &self (not &mut self), we use the legacy
|
|
// full-sequence forward path here. Use generate_mut() for KV-cached
|
|
// generation.
|
|
let mut tokens = prompt_tokens;
|
|
let mut generated = Vec::new();
|
|
|
|
for _ in 0..params.max_tokens {
|
|
let logits = self.forward(&tokens)?;
|
|
let next_token = Self::argmax(&logits);
|
|
|
|
if next_token == eos_id || next_token == 0 {
|
|
break;
|
|
}
|
|
|
|
generated.push(next_token);
|
|
tokens.push(next_token);
|
|
}
|
|
|
|
// Decode generated tokens back to text
|
|
let text = tokenizer.decode(&generated);
|
|
Ok(text)
|
|
}
|
|
|
|
fn generate_stream(
|
|
&self,
|
|
prompt: &str,
|
|
params: GenerateParams,
|
|
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>> {
|
|
let result = self.generate(prompt, params)?;
|
|
let tokens: Vec<Result<GeneratedToken>> = result
|
|
.chars()
|
|
.enumerate()
|
|
.map(|(i, c)| {
|
|
Ok(GeneratedToken {
|
|
id: i as u32,
|
|
text: c.to_string(),
|
|
logprob: None,
|
|
is_special: false,
|
|
})
|
|
})
|
|
.collect();
|
|
Ok(Box::new(tokens.into_iter()))
|
|
}
|
|
|
|
fn generate_stream_v2(&self, prompt: &str, params: GenerateParams) -> Result<TokenStream> {
|
|
let (tx, stream) = TokenStream::channel();
|
|
let result = self.generate(prompt, params.clone());
|
|
|
|
match result {
|
|
Ok(text) => {
|
|
let _ = tx.send(StreamEvent::Token(GeneratedToken {
|
|
id: 0,
|
|
text,
|
|
logprob: None,
|
|
is_special: false,
|
|
}));
|
|
let _ = tx.send(StreamEvent::Done {
|
|
total_tokens: 1,
|
|
duration_ms: 0,
|
|
tokens_per_second: 0.0,
|
|
});
|
|
}
|
|
Err(e) => {
|
|
let _ = tx.send(StreamEvent::Error(e.to_string()));
|
|
}
|
|
}
|
|
|
|
Ok(stream)
|
|
}
|
|
|
|
fn get_embeddings(&self, text: &str) -> Result<Vec<f32>> {
|
|
let config = self
|
|
.config
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No model loaded".to_string()))?;
|
|
let tokenizer = self
|
|
.tok
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No tokenizer loaded".to_string()))?;
|
|
|
|
let ids = tokenizer.encode(text);
|
|
if ids.is_empty() {
|
|
return Err(RuvLLMError::Model("Empty token sequence".to_string()));
|
|
}
|
|
|
|
// Use last token embedding as text representation
|
|
let last_id = *ids.last().unwrap() as usize;
|
|
let hidden = config.hidden_size;
|
|
if last_id >= config.vocab_size {
|
|
return Err(RuvLLMError::Model("Token exceeds vocab".to_string()));
|
|
}
|
|
Ok(self.embedding[last_id * hidden..(last_id + 1) * hidden].to_vec())
|
|
}
|
|
|
|
fn tokenizer(&self) -> Option<&dyn BackendTokenizer> {
|
|
self.tok
|
|
.as_ref()
|
|
.map(|t| {
|
|
// Safety: we return a reference with the same lifetime as &self.
|
|
// The TokenizerBridge is a thin wrapper — we use a raw pointer trick
|
|
// to avoid the borrow checker issue with returning a trait object
|
|
// that borrows from self.
|
|
//
|
|
// Alternative: store a Box<dyn BackendTokenizer> directly. For now,
|
|
// return None and callers should use `self.tok` directly.
|
|
let _ = t;
|
|
// Return None for the trait-object path; callers can use tok() accessor
|
|
None::<&dyn BackendTokenizer>
|
|
})
|
|
.flatten()
|
|
}
|
|
|
|
fn is_model_loaded(&self) -> bool {
|
|
self.loaded
|
|
}
|
|
|
|
fn model_info(&self) -> Option<ModelInfo> {
|
|
let config = self.config.as_ref()?;
|
|
Some(ModelInfo {
|
|
name: self.model_path.clone(),
|
|
architecture: ModelArchitecture::Qwen,
|
|
num_parameters: config.num_layers
|
|
* config.num_experts
|
|
* config.intermediate_size
|
|
* config.hidden_size
|
|
* 3,
|
|
vocab_size: config.vocab_size,
|
|
hidden_size: config.hidden_size,
|
|
num_layers: config.num_layers,
|
|
max_context_length: config.max_context,
|
|
quantization: Some(Quantization::Q2K),
|
|
memory_usage: self.embedding.len() * 4
|
|
+ self.lm_head.len() * 4
|
|
+ self
|
|
.layers
|
|
.iter()
|
|
.map(|l| {
|
|
let mut bytes = l.gate_weight.len() * 4
|
|
+ l.input_norm_weight.len() * 4
|
|
+ l.post_attn_norm_weight.len() * 4
|
|
+ l.attention.o_proj.memory_bytes();
|
|
// Attention: MLA or GQA
|
|
if l.attention.is_mla {
|
|
bytes += l.attention.q_a.as_ref().map_or(0, |t| t.memory_bytes());
|
|
bytes += l.attention.q_b.as_ref().map_or(0, |t| t.memory_bytes());
|
|
bytes += l
|
|
.attention
|
|
.kv_a_mqa
|
|
.as_ref()
|
|
.map_or(0, |t| t.memory_bytes());
|
|
bytes += l.attention.k_b.as_ref().map_or(0, |t| t.memory_bytes());
|
|
bytes += l.attention.v_b.as_ref().map_or(0, |t| t.memory_bytes());
|
|
bytes += l.attention.q_a_norm.as_ref().map_or(0, |v| v.len() * 4);
|
|
bytes += l.attention.kv_a_norm.as_ref().map_or(0, |v| v.len() * 4);
|
|
} else {
|
|
bytes += l.attention.q_proj.memory_bytes();
|
|
bytes += l.attention.k_proj.memory_bytes();
|
|
bytes += l.attention.v_proj.memory_bytes();
|
|
}
|
|
// FFN: routed experts
|
|
bytes += l
|
|
.experts
|
|
.iter()
|
|
.map(|e| {
|
|
e.gate_proj.memory_bytes()
|
|
+ e.up_proj.memory_bytes()
|
|
+ e.down_proj.memory_bytes()
|
|
})
|
|
.sum::<usize>();
|
|
// FFN: shared expert
|
|
if let Some(ref se) = l.shared_expert {
|
|
bytes += se.gate_proj.memory_bytes()
|
|
+ se.up_proj.memory_bytes()
|
|
+ se.down_proj.memory_bytes();
|
|
}
|
|
// FFN: dense
|
|
if let Some(ref df) = l.dense_ffn {
|
|
bytes += df.gate_proj.memory_bytes()
|
|
+ df.up_proj.memory_bytes()
|
|
+ df.down_proj.memory_bytes();
|
|
}
|
|
bytes
|
|
})
|
|
.sum::<usize>(),
|
|
})
|
|
}
|
|
|
|
fn unload_model(&mut self) {
|
|
self.config = None;
|
|
self.embedding.clear();
|
|
self.lm_head.clear();
|
|
self.final_norm_weight.clear();
|
|
self.layers.clear();
|
|
self.kv_caches.clear();
|
|
self.tok = None;
|
|
self.rope_cos.clear();
|
|
self.rope_sin.clear();
|
|
self.loaded = false;
|
|
self.model_path.clear();
|
|
}
|
|
}
|
|
|
|
impl BitNetBackend {
|
|
/// Autoregressive generate with KV cache (takes &mut self).
|
|
///
|
|
/// This is the efficient path for generation: each token only computes
|
|
/// attention against cached K/V vectors rather than reprocessing the
|
|
/// full sequence.
|
|
pub fn generate_cached(&mut self, prompt: &str, max_tokens: usize) -> Result<String> {
|
|
if !self.loaded {
|
|
return Err(RuvLLMError::Model("No model loaded".to_string()));
|
|
}
|
|
let tokenizer = self
|
|
.tok
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No tokenizer loaded".to_string()))?;
|
|
|
|
let prompt_tokens = tokenizer.encode(prompt);
|
|
let eos_id = 2u32;
|
|
|
|
self.reset_cache();
|
|
|
|
// Prefill: process all prompt tokens
|
|
let mut last_logits = Vec::new();
|
|
for (pos, &tid) in prompt_tokens.iter().enumerate() {
|
|
last_logits = self.forward_token(tid, pos)?;
|
|
}
|
|
|
|
// Decode
|
|
let mut generated = Vec::new();
|
|
let mut pos = prompt_tokens.len();
|
|
|
|
for _ in 0..max_tokens {
|
|
let next_token = Self::argmax(&last_logits);
|
|
if next_token == eos_id || next_token == 0 {
|
|
break;
|
|
}
|
|
generated.push(next_token);
|
|
last_logits = self.forward_token(next_token, pos)?;
|
|
pos += 1;
|
|
}
|
|
|
|
let tokenizer = self.tok.as_ref().unwrap();
|
|
Ok(tokenizer.decode(&generated))
|
|
}
|
|
|
|
/// Get the loaded tokenizer (if any).
|
|
pub fn tok(&self) -> Option<&BpeTokenizer> {
|
|
self.tok.as_ref()
|
|
}
|
|
|
|
// ========================================================================
|
|
// Streaming Generation
|
|
// ========================================================================
|
|
|
|
/// Streaming autoregressive generation with per-token callback.
|
|
///
|
|
/// Calls `on_token` for each generated token, allowing callers to process
|
|
/// tokens incrementally (e.g., for real-time output). The callback receives
|
|
/// the token ID, the decoded text for that token, and the token's position.
|
|
///
|
|
/// Returns the concatenated generated text. If the callback returns `false`,
|
|
/// generation stops early (allows callers to implement stop conditions).
|
|
///
|
|
/// # Arguments
|
|
///
|
|
/// * `prompt` - Input text to condition on
|
|
/// * `max_tokens` - Maximum number of tokens to generate
|
|
/// * `on_token` - Callback invoked for each token: `(token_id, text, position) -> continue?`
|
|
pub fn generate_streaming<F>(
|
|
&mut self,
|
|
prompt: &str,
|
|
max_tokens: usize,
|
|
mut on_token: F,
|
|
) -> Result<GenerationStats>
|
|
where
|
|
F: FnMut(u32, &str, usize) -> bool,
|
|
{
|
|
if !self.loaded {
|
|
return Err(RuvLLMError::Model("No model loaded".to_string()));
|
|
}
|
|
let tokenizer = self
|
|
.tok
|
|
.as_ref()
|
|
.ok_or_else(|| RuvLLMError::Model("No tokenizer loaded".to_string()))?;
|
|
|
|
let prompt_tokens = tokenizer.encode(prompt);
|
|
let eos_id = 2u32;
|
|
let prompt_len = prompt_tokens.len();
|
|
|
|
self.reset_cache();
|
|
|
|
// Prefill: process all prompt tokens
|
|
let mut last_logits = Vec::new();
|
|
for (pos, &tid) in prompt_tokens.iter().enumerate() {
|
|
last_logits = self.forward_token(tid, pos)?;
|
|
}
|
|
|
|
// Decode with streaming callback
|
|
let mut generated_tokens = Vec::new();
|
|
let mut pos = prompt_len;
|
|
|
|
let start_time = std::time::Instant::now();
|
|
|
|
for _ in 0..max_tokens {
|
|
let next_token = Self::argmax(&last_logits);
|
|
if next_token == eos_id || next_token == 0 {
|
|
break;
|
|
}
|
|
|
|
// Decode single token
|
|
let tokenizer = self.tok.as_ref().unwrap();
|
|
let token_text = tokenizer.decode(&[next_token]);
|
|
|
|
generated_tokens.push(next_token);
|
|
|
|
// Invoke callback; stop if it returns false
|
|
if !on_token(next_token, &token_text, pos) {
|
|
break;
|
|
}
|
|
|
|
last_logits = self.forward_token(next_token, pos)?;
|
|
pos += 1;
|
|
}
|
|
|
|
let elapsed = start_time.elapsed();
|
|
let num_generated = generated_tokens.len();
|
|
|
|
Ok(GenerationStats {
|
|
prompt_tokens: prompt_len,
|
|
generated_tokens: num_generated,
|
|
total_tokens: prompt_len + num_generated,
|
|
elapsed_ms: elapsed.as_millis() as u64,
|
|
tokens_per_second: if elapsed.as_secs_f64() > 0.0 {
|
|
num_generated as f64 / elapsed.as_secs_f64()
|
|
} else {
|
|
0.0
|
|
},
|
|
})
|
|
}
|
|
|
|
// ========================================================================
|
|
// Predictive Expert Prefetcher
|
|
// ========================================================================
|
|
|
|
/// Create a predictive expert prefetcher from routing history.
|
|
///
|
|
/// Analyzes past routing decisions to build a co-occurrence matrix:
|
|
/// if expert A is selected at position t, which experts are likely at t+1?
|
|
/// Uses this to predict and warm up likely-next experts before they're needed.
|
|
pub fn build_expert_predictor(&self, routing_history: &[Vec<usize>]) -> ExpertPredictor {
|
|
let num_experts = self.config.as_ref().map(|c| c.num_experts).unwrap_or(64);
|
|
|
|
ExpertPredictor::from_history(num_experts, routing_history)
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Math Helpers (standalone functions used by the backend)
|
|
// ============================================================================
|
|
|
|
/// In-place RMSNorm: x = x / rms(x) * weight
///
/// The sum of squares is computed with four independent accumulators
/// (better instruction-level parallelism); the scale pass fuses the
/// reciprocal RMS with the per-element weight multiply.
///
/// If `weight` is shorter than `x`, missing entries default to 1.0.
#[inline]
fn rms_norm_inplace(x: &mut [f32], weight: &[f32], eps: f32) {
    let n = x.len();
    if n == 0 {
        return;
    }

    // Sum of squares: four lanes over exact 4-element chunks, then the
    // leftover elements sequentially (same accumulation order as a manual
    // 4-wide unroll).
    let mut acc = [0.0f32; 4];
    let mut quads = x.chunks_exact(4);
    for q in &mut quads {
        acc[0] += q[0] * q[0];
        acc[1] += q[1] * q[1];
        acc[2] += q[2] * q[2];
        acc[3] += q[3] * q[3];
    }
    let mut sum_sq = acc[0] + acc[1] + acc[2] + acc[3];
    for &v in quads.remainder() {
        sum_sq += v * v;
    }

    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();

    if weight.len() >= n {
        // Fast path: weight is correctly sized (common case).
        for (xv, &wv) in x.iter_mut().zip(weight) {
            *xv *= inv_rms * wv;
        }
    } else {
        // Fallback: weight may be shorter; absent entries act as 1.0.
        for (i, xv) in x.iter_mut().enumerate() {
            *xv *= inv_rms * weight.get(i).copied().unwrap_or(1.0);
        }
    }
}
|
|
|
|
/// In-place softmax with streaming max and fused exp+sum.
///
/// Guards against NaN propagation: if all inputs are -inf or NaN,
/// the result is a uniform distribution (1/n for each element).
#[inline]
fn softmax_inplace(x: &mut [f32]) {
    let n = x.len();
    if n == 0 {
        return;
    }

    // Max scan. `f32::max` ignores a NaN operand, so NaN entries are
    // skipped exactly like a `>` comparison would skip them.
    let max_val = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Guard: if max_val is -inf or NaN, fall back to uniform.
    if max_val.is_nan() || (max_val.is_infinite() && max_val.is_sign_negative()) {
        x.fill(1.0 / n as f32);
        return;
    }

    // Fused exp + sum in a single pass.
    let mut sum_exp = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max_val).exp();
        sum_exp += *v;
    }

    // Guard: degenerate sum.
    if !sum_exp.is_normal() || sum_exp <= 0.0 {
        x.fill(1.0 / n as f32);
        return;
    }

    // Normalize with a reciprocal multiply (cheaper than per-element division).
    let inv_sum = sum_exp.recip();
    for v in x.iter_mut() {
        *v *= inv_sum;
    }
}
|
|
|
|
/// Sigmoid activation: maps any real input into (0, 1).
///
/// Computed as 1 / (1 + e^(-x)); `recip()` is exactly `1.0 / self`.
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    ((-x).exp() + 1.0).recip()
}
|
|
|
|
/// FP16 bits to FP32 conversion (same as in gguf/quantization.rs).
///
/// Handles every IEEE 754 binary16 class:
/// - signed zero (`exp == 0`, `frac == 0`)
/// - subnormals (`exp == 0`, `frac != 0`): value is `frac * 2^-24`,
///   renormalized so the implicit leading 1 lands in the f32 mantissa
/// - Inf/NaN (`exp == 31`), preserving the NaN payload
/// - normal values, rebiasing the exponent from 15 to 127
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
    let sign = ((bits & 0x8000) as u32) << 16;
    let exp = ((bits >> 10) & 0x1F) as u32;
    let frac = (bits & 0x03FF) as u32;

    if exp == 0 {
        if frac == 0 {
            // Signed zero.
            return f32::from_bits(sign);
        }
        // Subnormal: shift the fraction up until the implicit bit (bit 10)
        // is set, counting shifts in `e`; each shift lowers the effective
        // exponent by one, giving an unbiased f32 exponent of -14 - e.
        //
        // BUG FIX: `e` must start at 0 (it counts shifts only). It
        // previously started at 1, which made every subnormal decode to
        // half its true value (e.g. 0x0001 became 2^-25 instead of 2^-24).
        let mut e = 0u32;
        let mut f = frac;
        while (f & 0x0400) == 0 {
            f <<= 1;
            e += 1;
        }
        f &= 0x03FF;
        return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | (f << 13));
    }

    if exp == 31 {
        // Inf (frac == 0) or NaN (frac != 0); the shifted payload keeps
        // a NaN's fraction non-zero, so NaN-ness is preserved.
        return f32::from_bits(sign | 0x7F80_0000 | (frac << 13));
    }

    // Normal number: rebias exponent (15 -> 127), widen mantissa 10 -> 23 bits.
    f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13))
}
|
|
|
|
/// FP32 matrix-vector product (transposed): out[i] = dot(mat[i*cols..], vec)
///
/// `mat` is [rows, cols] row-major, `vec` is [cols], the result is [rows].
/// Rows whose data would run past the end of `mat` are left at 0.0 and
/// iteration stops there (silent truncation, as before).
/// The dot product uses four independent accumulators for better ILP.
#[inline]
fn fp32_matvec_transposed(mat: &[f32], vec: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut output = vec![0.0f32; rows];

    for (i, out) in output.iter_mut().enumerate() {
        let row_start = i * cols;
        // Stop at the first row that does not fully fit inside `mat`.
        let row = match mat.get(row_start..row_start + cols) {
            Some(r) => r,
            None => break,
        };
        let v = &vec[..cols];

        // Four-lane dot product over exact 4-element chunks.
        let mut acc = [0.0f32; 4];
        let mut rc = row.chunks_exact(4);
        let mut vc = v.chunks_exact(4);
        for (r4, v4) in (&mut rc).zip(&mut vc) {
            acc[0] += r4[0] * v4[0];
            acc[1] += r4[1] * v4[1];
            acc[2] += r4[2] * v4[2];
            acc[3] += r4[3] * v4[3];
        }

        let mut dot = acc[0] + acc[1] + acc[2] + acc[3];
        for (rv, vv) in rc.remainder().iter().zip(vc.remainder()) {
            dot += rv * vv;
        }
        *out = dot;
    }

    output
}
|
|
|
|
// ============================================================================
|
|
// Tests
|
|
// ============================================================================
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::bitnet::{pack_ternary, TernaryTensor};
|
|
|
|
    #[test]
    fn test_build_tl1_lut() {
        // The TL1 LUT maps one packed byte (four 2-bit codes, LSB-first)
        // to four ternary values: 00 -> -1, 01 -> 0, 10 -> +1.
        let lut = build_tl1_lut();

        // Byte 0x00 = all bits 00 = all -1
        assert_eq!(lut[0x00], [-1, -1, -1, -1]);

        // Byte 0x55 = 01_01_01_01 = all 0
        assert_eq!(lut[0x55], [0, 0, 0, 0]);

        // Byte 0xAA = 10_10_10_10 = all +1
        assert_eq!(lut[0xAA], [1, 1, 1, 1]);

        // Byte 0x24 = 00_10_01_00 => positions: [00, 01, 10, 00] => [-1, 0, 1, -1]
        // bit layout LSB first: bits[0:1]=00, bits[2:3]=01, bits[4:5]=10, bits[6:7]=00
        // 0x24 = 0b00_10_01_00
        assert_eq!(lut[0x24], [-1, 0, 1, -1]);
    }
|
|
|
|
    #[test]
    fn test_rms_norm_inplace() {
        // Unit weights isolate pure normalization: x[i] -> x[i] / rms(x).
        let mut x = vec![1.0, 2.0, 3.0, 4.0];
        let w = vec![1.0; 4];
        rms_norm_inplace(&mut x, &w, 1e-6);

        // RMS of [1,2,3,4] = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
        let rms = (30.0f32 / 4.0).sqrt();
        let expected: Vec<f32> = [1.0, 2.0, 3.0, 4.0].iter().map(|v| v / rms).collect();

        for (a, b) in x.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-4, "got {} expected {}", a, b);
        }
    }
|
|
|
|
    #[test]
    fn test_softmax_inplace() {
        // Softmax must normalize to a probability distribution and
        // preserve the ordering of its inputs (strict monotonicity).
        let mut x = vec![1.0, 2.0, 3.0];
        softmax_inplace(&mut x);

        // Sum should be 1.0
        let sum: f32 = x.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Values should be ordered
        assert!(x[0] < x[1]);
        assert!(x[1] < x[2]);
    }
|
|
|
|
    #[test]
    fn test_sigmoid() {
        // Midpoint 0.5 at x=0, saturating toward 1 and 0 at the extremes.
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.999);
        assert!(sigmoid(-10.0) < 0.001);
    }
|
|
|
|
    #[test]
    fn test_fp32_matvec_transposed() {
        // Identity matrix 3x3: the matvec must return the input unchanged.
        let mat = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let vec_in = vec![2.0, 3.0, 4.0];
        let out = fp32_matvec_transposed(&mat, &vec_in, 3, 3);
        assert_eq!(out, vec![2.0, 3.0, 4.0]);
    }
|
|
|
|
    #[test]
    fn test_tl1_gemv_simple() {
        // GEMV against all-(+1) and all-(-1) ternary rows reduces to
        // +sum(input) and -sum(input) respectively.
        let backend = BitNetBackend::new();

        // Create a 2x4 ternary weight matrix:
        // Row 0: [+1, +1, +1, +1]
        // Row 1: [-1, -1, -1, -1]
        let row0 = vec![1i8, 1, 1, 1];
        let row1 = vec![-1i8, -1, -1, -1];
        let mut all = row0.clone();
        all.extend_from_slice(&row1);
        let packed = pack_ternary(&all);

        let weight = TernaryTensor {
            packed_data: packed,
            scales: vec![1.0, 1.0], // one scale per block (each row < 256, so 1 block per row)
            shape: (2, 4),
            block_size: 256,
        };

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = backend.tl1_gemv(&weight, &input, 2, 4);

        // Row 0: 1+2+3+4 = 10, scale=1.0
        assert!((output[0] - 10.0).abs() < 1e-6);
        // Row 1: -(1+2+3+4) = -10, scale=1.0
        assert!((output[1] - (-10.0)).abs() < 1e-6);
    }
|
|
|
|
    #[test]
    fn test_tl1_gemv_with_zeros() {
        // Zero ternary weights must drop their input elements entirely,
        // and the block scale must be applied after accumulation.
        let backend = BitNetBackend::new();

        // Row: [+1, 0, -1, 0]
        let vals = vec![1i8, 0, -1, 0];
        let packed = pack_ternary(&vals);

        let weight = TernaryTensor {
            packed_data: packed,
            scales: vec![2.0],
            shape: (1, 4),
            block_size: 256,
        };

        let input = vec![5.0, 3.0, 7.0, 9.0];
        let output = backend.tl1_gemv(&weight, &input, 1, 4);

        // Result: (5.0 + 0 - 7.0 + 0) * 2.0 = -2.0 * 2.0 = -4.0
        assert!((output[0] - (-4.0)).abs() < 1e-6);
    }
|
|
|
|
    #[test]
    fn test_bitnet_model_config_default() {
        // Default config must match the GLM-4.7-Flash MLA + MoE geometry.
        let config = BitNetModelConfig::default();
        // GLM-4.7-Flash defaults
        assert_eq!(config.num_layers, 47);
        assert_eq!(config.hidden_size, 2048);
        assert_eq!(config.num_experts, 64);
        assert_eq!(config.active_experts, 4);
        assert_eq!(config.moe_intermediate_size, 1536);
        assert!(config.use_mla);
        assert_eq!(config.q_lora_rank, 768);
        assert_eq!(config.kv_lora_rank, 512);
        assert_eq!(config.qk_nope_head_dim, 192);
        assert_eq!(config.qk_rope_head_dim, 64);
        assert_eq!(config.v_head_dim, 256);
        assert_eq!(config.n_shared_experts, 1);
        assert_eq!(config.first_k_dense_replace, 1);
    }
|
|
|
|
    #[test]
    fn test_route_experts_topk() {
        // With an identity-like gate, each expert's score equals one hidden
        // dimension, so top-k selection and weight normalization are easy
        // to verify by hand.
        let backend = BitNetBackend::new();
        let config = BitNetModelConfig {
            num_experts: 4,
            active_experts: 2,
            hidden_size: 4,
            ..Default::default()
        };

        // Gate weight [4 experts, 4 hidden]: identity-like so expert scores = hidden_states
        let gate_weight = vec![
            1.0, 0.0, 0.0, 0.0, // Expert 0 looks at dim 0
            0.0, 1.0, 0.0, 0.0, // Expert 1 looks at dim 1
            0.0, 0.0, 1.0, 0.0, // Expert 2 looks at dim 2
            0.0, 0.0, 0.0, 1.0, // Expert 3 looks at dim 3
        ];

        // Hidden states: dim 2 is highest, dim 3 is second
        let hidden = vec![0.1, 0.2, 0.9, 0.5];

        let (indices, weights) = backend
            .route_experts(&hidden, &gate_weight, &config)
            .unwrap();

        assert_eq!(indices.len(), 2);
        assert_eq!(weights.len(), 2);

        // Expert 2 should be first (score 0.9), Expert 3 second (score 0.5)
        assert_eq!(indices[0], 2);
        assert_eq!(indices[1], 3);

        // Weights should sum to ~1.0
        let wsum: f32 = weights.iter().sum();
        assert!((wsum - 1.0).abs() < 1e-4);
    }
|
|
|
|
    #[test]
    fn test_backend_new_unloaded() {
        // A freshly constructed backend reports no model and no info.
        let backend = BitNetBackend::new();
        assert!(!backend.is_model_loaded());
        assert!(backend.model_info().is_none());
    }
|
|
|
|
    #[test]
    fn test_rope_tables() {
        // Tables built for max_seq=16, head_dim=8, theta=10000: row 0
        // (position 0) must be the identity rotation (cos=1, sin=0).
        let mut backend = BitNetBackend::new();
        backend.build_rope_tables(16, 8, 10000.0);

        let half = 4; // head_dim / 2
        // Position 0: all angles are 0 → cos=1, sin=0
        for i in 0..half {
            assert!(
                (backend.rope_cos[i] - 1.0).abs() < 1e-5,
                "cos[0][{}]={}",
                i,
                backend.rope_cos[i]
            );
            assert!(
                backend.rope_sin[i].abs() < 1e-5,
                "sin[0][{}]={}",
                i,
                backend.rope_sin[i]
            );
        }

        // Table size should be max_seq * half
        assert_eq!(backend.rope_cos.len(), 16 * 4);
        assert_eq!(backend.rope_sin.len(), 16 * 4);
    }
|
|
|
|
    #[test]
    fn test_apply_rope_identity_at_pos_0() {
        // Rotating by angle 0 must leave the vector untouched.
        let mut backend = BitNetBackend::new();
        backend.build_rope_tables(8, 4, 10000.0);

        let mut x = vec![1.0, 2.0, 3.0, 4.0];
        let original = x.clone();
        backend.apply_rope(&mut x, 1, 4, 0);

        // At position 0, all angles are 0, so cos=1, sin=0 → identity
        for (a, b) in x.iter().zip(original.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "RoPE at pos 0 should be identity: got {} vs {}",
                a,
                b
            );
        }
    }
|
|
|
|
    #[test]
    fn test_apply_rope_rotates_at_pos_1() {
        // At a nonzero position the vector must change, but since RoPE is
        // a rotation, its Euclidean norm must be preserved.
        let mut backend = BitNetBackend::new();
        backend.build_rope_tables(8, 4, 10000.0);

        let mut x = vec![1.0, 0.0, 1.0, 0.0]; // head_dim=4, 1 head
        let original = x.clone();
        backend.apply_rope(&mut x, 1, 4, 1);

        // At position 1, some rotation should happen
        let changed = x
            .iter()
            .zip(original.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(changed, "RoPE at pos 1 should rotate the vector");

        // Norm should be preserved (RoPE is an orthogonal rotation)
        let orig_norm: f32 = original.iter().map(|v| v * v).sum::<f32>().sqrt();
        let new_norm: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (orig_norm - new_norm).abs() < 1e-4,
            "RoPE should preserve norm"
        );
    }
|
|
|
|
    #[test]
    fn test_kv_cache_operations() {
        // len() tracks pushed positions; clear() empties the cache.
        let mut cache = LayerKvCache::new();
        assert_eq!(cache.len(), 0);

        cache.keys.push(vec![1.0, 2.0]);
        cache.values.push(vec![3.0, 4.0]);
        assert_eq!(cache.len(), 1);

        cache.keys.push(vec![5.0, 6.0]);
        cache.values.push(vec![7.0, 8.0]);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
    }
|
|
|
|
    #[test]
    fn test_byte_level_tokenizer() {
        // The fallback tokenizer maps each byte to a token; encode/decode
        // must round-trip ASCII exactly, with BOS (id 1) prepended.
        let tok = BitNetBackend::build_byte_level_tokenizer();
        assert_eq!(tok.vocab_size(), 260); // 4 special + 256 byte tokens

        // Roundtrip ASCII
        let ids = tok.encode("Hello");
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, "Hello", "Byte-level tokenizer roundtrip failed");

        // BOS should be prepended
        assert_eq!(ids[0], 1);
    }
|
|
|
|
    #[test]
    fn test_byte_level_tokenizer_utf8() {
        // Multi-byte UTF-8 (combining accent) must round-trip byte-exactly.
        let tok = BitNetBackend::build_byte_level_tokenizer();
        let text = "cafe\u{0301}"; // combining accent
        let ids = tok.encode(text);
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, text);
    }
|
|
|
|
    #[test]
    fn test_backend_reset_cache() {
        // reset_cache() must empty every per-layer KV cache.
        let mut backend = BitNetBackend::new();
        // Manually set up caches
        backend.kv_caches = vec![LayerKvCache::new(), LayerKvCache::new()];
        backend.kv_caches[0].keys.push(vec![1.0]);
        backend.kv_caches[1].keys.push(vec![2.0]);

        backend.reset_cache();
        assert_eq!(backend.kv_caches[0].len(), 0);
        assert_eq!(backend.kv_caches[1].len(), 0);
    }
|
|
|
|
    #[test]
    fn test_attention_weights_gqa() {
        // Verify GQA AttentionWeights construction: q/k/v/o projections set,
        // all MLA-specific fields None, is_mla false.
        let packed = pack_ternary(&[1, 0, -1, 0]);
        let tensor = TernaryTensor {
            packed_data: packed.clone(),
            scales: vec![1.0],
            shape: (1, 4),
            block_size: 256,
        };
        let attn = AttentionWeights {
            is_mla: false,
            q_proj: tensor.clone(),
            k_proj: tensor.clone(),
            v_proj: tensor.clone(),
            o_proj: tensor,
            q_a: None,
            q_b: None,
            q_a_norm: None,
            kv_a_mqa: None,
            kv_a_norm: None,
            k_b: None,
            v_b: None,
        };
        assert!(!attn.is_mla);
        assert_eq!(attn.q_proj.shape, (1, 4));
    }
|
|
|
|
    #[test]
    fn test_attention_weights_mla() {
        // Verify MLA AttentionWeights construction: low-rank Q/KV factors
        // populated, GQA projections left as empty placeholders.
        let packed = pack_ternary(&[1, 0, -1, 0]);
        let tensor = TernaryTensor {
            packed_data: packed.clone(),
            scales: vec![1.0],
            shape: (1, 4),
            block_size: 256,
        };
        let placeholder = TernaryTensor {
            packed_data: vec![],
            scales: vec![],
            shape: (0, 0),
            block_size: 256,
        };
        let attn = AttentionWeights {
            is_mla: true,
            q_proj: placeholder.clone(),
            k_proj: placeholder.clone(),
            v_proj: placeholder,
            o_proj: tensor.clone(),
            q_a: Some(tensor.clone()),
            q_b: Some(tensor.clone()),
            q_a_norm: Some(vec![1.0; 4]),
            kv_a_mqa: Some(tensor.clone()),
            kv_a_norm: Some(vec![1.0; 4]),
            k_b: Some(tensor.clone()),
            v_b: Some(tensor),
        };
        assert!(attn.is_mla);
        assert!(attn.q_a.is_some());
        assert!(attn.q_b.is_some());
        assert!(attn.kv_a_mqa.is_some());
        assert!(attn.k_b.is_some());
        assert!(attn.v_b.is_some());
    }
|
|
|
|
    #[test]
    fn test_tok_accessor() {
        // tok() mirrors the optional tokenizer field.
        let mut backend = BitNetBackend::new();
        assert!(backend.tok().is_none());

        backend.tok = Some(BitNetBackend::build_byte_level_tokenizer());
        assert!(backend.tok().is_some());
        assert_eq!(backend.tok().unwrap().vocab_size(), 260);
    }
|
|
|
|
    #[test]
    fn test_layer_type_enum() {
        // LayerType variants must be distinguishable via PartialEq.
        assert_eq!(LayerType::Dense, LayerType::Dense);
        assert_ne!(LayerType::Dense, LayerType::Moe);
        assert_ne!(LayerType::Moe, LayerType::MoeWithShared);
    }
|
|
|
|
    #[test]
    fn test_tensor_name_mapper_embedding() {
        // Embedding lookup must offer both the GGUF and HF naming schemes.
        let candidates = TensorNameMapper::embedding();
        assert_eq!(candidates.len(), 2);
        assert!(candidates.contains(&"token_embd.weight".to_string()));
        assert!(candidates.contains(&"model.embed_tokens.weight".to_string()));
    }
|
|
|
|
    #[test]
    fn test_tensor_name_mapper_mla() {
        // MLA projection tensors use the GGUF "blk.N.attn_*" scheme only.
        let q_a = TensorNameMapper::attn_q_a(5);
        assert_eq!(q_a, vec!["blk.5.attn_q_a.weight".to_string()]);

        let q_b = TensorNameMapper::attn_q_b(5);
        assert_eq!(q_b, vec!["blk.5.attn_q_b.weight".to_string()]);

        let kv_a = TensorNameMapper::attn_kv_a_mqa(5);
        assert_eq!(kv_a, vec!["blk.5.attn_kv_a_mqa.weight".to_string()]);

        let k_b = TensorNameMapper::attn_k_b(5);
        assert_eq!(k_b, vec!["blk.5.attn_k_b.weight".to_string()]);

        let v_b = TensorNameMapper::attn_v_b(5);
        assert_eq!(v_b, vec!["blk.5.attn_v_b.weight".to_string()]);
    }
|
|
|
|
    #[test]
    fn test_tensor_name_mapper_norms() {
        // Norm tensors must offer both GGUF and HF candidate names.
        let in_norm = TensorNameMapper::input_norm(3);
        assert!(in_norm.contains(&"blk.3.attn_norm.weight".to_string()));
        assert!(in_norm.contains(&"model.layers.3.input_layernorm.weight".to_string()));

        let post_norm = TensorNameMapper::post_attn_norm(3);
        assert!(post_norm.contains(&"blk.3.ffn_norm.weight".to_string()));
    }
|
|
|
|
    #[test]
    fn test_tensor_name_mapper_moe() {
        // MoE router, stacked-expert, and shared-expert tensor names.
        let gate = TensorNameMapper::moe_gate(2);
        assert!(gate.contains(&"blk.2.ffn_gate_inp.weight".to_string()));

        let exps = TensorNameMapper::ffn_gate_exps(2);
        assert_eq!(exps, vec!["blk.2.ffn_gate_exps.weight".to_string()]);

        let shexp = TensorNameMapper::ffn_gate_shexp(2);
        assert_eq!(shexp, vec!["blk.2.ffn_gate_shexp.weight".to_string()]);
    }
|
|
|
|
#[test]
fn test_tensor_name_mapper_dense_ffn() {
    // Dense FFN gate for layer 0 offers both GGUF and HF spellings.
    let names = TensorNameMapper::ffn_gate(0);
    for expected in [
        "blk.0.ffn_gate.weight",
        "model.layers.0.mlp.gate_proj.weight",
    ] {
        assert!(names.contains(&expected.to_string()), "missing {}", expected);
    }
}
|
|
|
|
#[test]
fn test_tensor_name_mapper_individual_experts() {
    // Per-expert gate projection uses the HF layout: layer 1, expert 3.
    let names = TensorNameMapper::expert_gate(1, 3);
    assert_eq!(
        names,
        vec!["model.layers.1.mlp.experts.3.gate_proj.weight".to_string()]
    );
}
|
|
|
|
#[test]
fn test_mla_config_dimensions() {
    let cfg = BitNetModelConfig::default();

    // Per-head Q dim is the no-RoPE part plus the RoPE part.
    let q_head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim;
    assert_eq!(q_head_dim, 256);

    // All heads together span the full query projection width.
    assert_eq!(cfg.num_attention_heads * q_head_dim, 5120);

    // KV down-projection emits the latent rank plus the RoPE slice.
    assert_eq!(cfg.kv_lora_rank + cfg.qk_rope_head_dim, 576);
}
|
|
|
|
#[test]
fn test_transformer_layer_dense() {
    // Minimal 1x4 ternary tensor reused by every projection.
    let t = TernaryTensor {
        packed_data: pack_ternary(&[1, 0, -1, 0]),
        scales: vec![1.0],
        shape: (1, 4),
        block_size: 256,
    };
    // Plain GQA attention: all MLA low-rank fields absent.
    let attention = AttentionWeights {
        is_mla: false,
        q_proj: t.clone(),
        k_proj: t.clone(),
        v_proj: t.clone(),
        o_proj: t.clone(),
        q_a: None,
        q_b: None,
        q_a_norm: None,
        kv_a_mqa: None,
        kv_a_norm: None,
        k_b: None,
        v_b: None,
    };
    // A dense layer carries a single FFN and no MoE machinery.
    let layer = TransformerLayer {
        input_norm_weight: vec![1.0; 4],
        post_attn_norm_weight: vec![1.0; 4],
        attention,
        layer_type: LayerType::Dense,
        gate_weight: Vec::new(),
        experts: Vec::new(),
        shared_expert: None,
        dense_ffn: Some(ExpertWeights {
            gate_proj: t.clone(),
            up_proj: t.clone(),
            down_proj: t,
        }),
    };
    assert_eq!(layer.layer_type, LayerType::Dense);
    assert!(layer.dense_ffn.is_some());
    assert!(layer.shared_expert.is_none());
}
|
|
|
|
#[test]
fn test_transformer_layer_moe_with_shared() {
    // Minimal 1x4 ternary tensor reused across all weights.
    let t = TernaryTensor {
        packed_data: pack_ternary(&[1, 0, -1, 0]),
        scales: vec![1.0],
        shape: (1, 4),
        block_size: 256,
    };
    let attention = AttentionWeights {
        is_mla: false,
        q_proj: t.clone(),
        k_proj: t.clone(),
        v_proj: t.clone(),
        o_proj: t.clone(),
        q_a: None,
        q_b: None,
        q_a_norm: None,
        kv_a_mqa: None,
        kv_a_norm: None,
        k_b: None,
        v_b: None,
    };
    let expert = ExpertWeights {
        gate_proj: t.clone(),
        up_proj: t.clone(),
        down_proj: t.clone(),
    };
    // MoE-with-shared layer: a router gate over 2 experts x 4 hidden dims,
    // two routed experts plus one always-on shared expert, no dense FFN.
    let layer = TransformerLayer {
        input_norm_weight: vec![1.0; 4],
        post_attn_norm_weight: vec![1.0; 4],
        attention,
        layer_type: LayerType::MoeWithShared,
        gate_weight: vec![1.0; 8],
        experts: vec![expert.clone(), expert.clone()],
        shared_expert: Some(expert),
        dense_ffn: None,
    };
    assert_eq!(layer.layer_type, LayerType::MoeWithShared);
    assert_eq!(layer.experts.len(), 2);
    assert!(layer.shared_expert.is_some());
}
|
|
|
|
#[test]
fn test_tensor_discovery_report_struct() {
    // A report with one tensor group and one warning constructs cleanly.
    let embedding_entry = TensorEntry {
        name: "token_embd.weight".into(),
        shape: vec![154880, 2048],
        dtype: "Q8_0".into(),
        bytes: 512,
    };
    let report = TensorDiscoveryReport {
        total_tensors: 10,
        total_bytes: 1024,
        architecture: Some("deepseek2".into()),
        tensor_groups: vec![TensorGroup {
            name: "Embedding".into(),
            tensors: vec![embedding_entry],
        }],
        warnings: vec!["MLA detected".into()],
    };
    assert_eq!(report.total_tensors, 10);
    assert_eq!(report.tensor_groups.len(), 1);
    assert_eq!(report.warnings.len(), 1);
}
|
|
|
|
#[test]
fn test_model_validation_struct() {
    // A loadable validation result: one found tensor, nothing missing.
    let validation = ModelValidation {
        can_load: true,
        config_summary: "layers=47, hidden=2048".into(),
        found: vec!["Embedding: token_embd.weight".into()],
        missing: vec![],
    };
    assert!(validation.can_load);
    assert_eq!(validation.found.len(), 1);
    assert!(validation.missing.is_empty());
}
|
|
|
|
#[test]
fn test_meta_helpers() {
    // meta_usize / meta_f32 fall back to defaults on missing keys; since a
    // GgufFile is hard to construct in tests, we verify via the config
    // defaults those helpers feed.
    let cfg = BitNetModelConfig::default();
    assert_eq!(cfg.rope_theta, 1_000_000.0);
    assert_eq!(cfg.routed_scaling_factor, 1.8);
}
|
|
|
|
// =========================================================================
|
|
// Generation Stats tests
|
|
// =========================================================================
|
|
|
|
#[test]
fn test_generation_stats_struct() {
    // Plain-data struct: every field round-trips as assigned.
    let stats = GenerationStats {
        prompt_tokens: 10,
        generated_tokens: 50,
        total_tokens: 60,
        elapsed_ms: 1000,
        tokens_per_second: 50.0,
    };
    assert_eq!(stats.prompt_tokens, 10);
    assert_eq!(stats.generated_tokens, 50);
    assert_eq!(stats.total_tokens, 60);
    assert_eq!(stats.elapsed_ms, 1000);
    assert!((stats.tokens_per_second - 50.0).abs() < 1e-6);
}
|
|
|
|
#[test]
fn test_generation_stats_zero_elapsed() {
    // Zero generated tokens with zero elapsed time is representable
    // (throughput reported as 0 rather than dividing by zero).
    let stats = GenerationStats {
        prompt_tokens: 5,
        generated_tokens: 0,
        total_tokens: 5,
        elapsed_ms: 0,
        tokens_per_second: 0.0,
    };
    assert_eq!(stats.generated_tokens, 0);
    assert_eq!(stats.tokens_per_second, 0.0);
}
|
|
|
|
// =========================================================================
|
|
// Expert Predictor tests
|
|
// =========================================================================
|
|
|
|
#[test]
fn test_expert_predictor_from_empty_history() {
    // No history: expert count is preserved but nothing was observed.
    let predictor = ExpertPredictor::from_history(8, &[]);
    assert_eq!(predictor.num_experts(), 8);
    assert_eq!(predictor.total_observations(), 0);
}
|
|
|
|
#[test]
fn test_expert_predictor_from_single_entry() {
    // A single history step has no successor, hence zero transitions.
    let predictor = ExpertPredictor::from_history(8, &[vec![2, 5]]);
    assert_eq!(predictor.total_observations(), 0);
}
|
|
|
|
#[test]
fn test_expert_predictor_transition_counts() {
    // History [2,5] -> [3,7] yields four cross-product transitions,
    // each observed once: 2->3, 2->7, 5->3, 5->7.
    let history = vec![vec![2, 5], vec![3, 7]];
    let predictor = ExpertPredictor::from_history(8, &history);
    assert_eq!(predictor.total_observations(), 4);

    // Laplace smoothing: expert 2 has 2 outgoing counts, so the
    // denominator is 2 + 8 = 10.
    let p_2_3 = predictor.transition_prob(2, 3);
    let p_2_7 = predictor.transition_prob(2, 7);
    let p_2_0 = predictor.transition_prob(2, 0);

    // Observed edge: (1 + 1) / 10 = 0.2
    assert!((p_2_3 - 0.2).abs() < 1e-6, "p(2->3)={}", p_2_3);
    assert!((p_2_7 - 0.2).abs() < 1e-6, "p(2->7)={}", p_2_7);
    // Unobserved edge: (0 + 1) / 10 = 0.1
    assert!((p_2_0 - 0.1).abs() < 1e-6, "p(2->0)={}", p_2_0);
}
|
|
|
|
#[test]
fn test_expert_predictor_predict_next() {
    // Alternating history: expert 2 always hands off to expert 5.
    let mut history = Vec::new();
    for _ in 0..4 {
        history.push(vec![2]);
        history.push(vec![5]);
    }
    let predictor = ExpertPredictor::from_history(8, &history);

    // With current experts [2], the strongest learned transition is 2 -> 5,
    // so expert 5 must rank first among the predictions.
    let predicted = predictor.predict_next(&[2], 3);
    assert!(!predicted.is_empty());
    assert_eq!(predicted[0], 5, "Expert 5 should be top prediction");
}
|
|
|
|
#[test]
fn test_expert_predictor_excludes_current() {
    // Expert 2 only ever transitions to itself in this history...
    let history = vec![vec![2], vec![2], vec![2], vec![2]];
    let predictor = ExpertPredictor::from_history(8, &history);

    // ...yet prediction must never suggest an already-active expert.
    let predicted = predictor.predict_next(&[2], 3);
    assert!(!predicted.contains(&2), "Current experts should be excluded");
}
|
|
|
|
#[test]
fn test_expert_predictor_out_of_bounds() {
    let predictor = ExpertPredictor::from_history(4, &[]);

    // Indices at or past num_experts yield probability zero.
    assert_eq!(predictor.transition_prob(10, 0), 0.0);
    assert_eq!(predictor.transition_prob(0, 10), 0.0);

    // Predicting from an out-of-range current expert must not panic,
    // and still respects the requested cap.
    let predicted = predictor.predict_next(&[99], 2);
    assert!(predicted.len() <= 2);
}
|
|
|
|
#[test]
fn test_expert_predictor_build_from_backend() {
    // A default-configured backend sizes the predictor to 64 experts.
    let backend = BitNetBackend::new();
    let predictor = backend.build_expert_predictor(&[vec![1, 2], vec![3, 4]]);
    assert_eq!(predictor.num_experts(), 64);
}
|
|
|
|
// =========================================================================
|
|
// Compressed MLA Cache tests
|
|
// =========================================================================
|
|
|
|
#[test]
fn test_compressed_mla_cache_new() {
    // A fresh cache is empty and accounts for zero bytes.
    let cache = CompressedMlaCache::new();
    assert!(cache.is_empty());
    assert_eq!(cache.len(), 0);
    assert_eq!(cache.memory_bytes(), 0);
}
|
|
|
|
#[test]
fn test_compressed_mla_cache_push() {
    let mut cache = CompressedMlaCache::new();

    // One position: latent of kv_lora_rank=512 plus RoPE slice of 64.
    cache.push(vec![1.0f32; 512], vec![0.5f32; 64]);
    assert_eq!(cache.len(), 1);
    assert!(!cache.is_empty());

    // Accounting: (512 + 64) f32s * 4 bytes = 2304 bytes.
    assert_eq!(cache.memory_bytes(), 2304);
}
|
|
|
|
#[test]
fn test_compressed_mla_cache_clear() {
    let mut cache = CompressedMlaCache::new();
    cache.push(vec![1.0; 512], vec![0.5; 64]);
    cache.push(vec![2.0; 512], vec![0.5; 64]);
    assert_eq!(cache.len(), 2);

    // clear() drops all positions and zeroes the byte accounting.
    cache.clear();
    assert!(cache.is_empty());
    assert_eq!(cache.len(), 0);
    assert_eq!(cache.memory_bytes(), 0);
}
|
|
|
|
#[test]
fn test_compressed_mla_cache_savings_ratio() {
    // GLM-4.7-Flash shape: full K = 20 heads * (192 + 64) = 5120 floats,
    // full V = 20 * 256 = 5120 -> 10240 per position, versus a compressed
    // latent of 512 + 64 = 576 floats: 10240 / 576 ≈ 17.78x.
    let ratio = CompressedMlaCache::savings_ratio(
        20,  // num_heads
        192, // qk_nope_head_dim
        64,  // qk_rope_head_dim
        256, // v_head_dim
        512, // kv_lora_rank
    );
    assert!(
        ratio > 17.0 && ratio < 18.5,
        "Expected ~17.8x savings, got {}",
        ratio
    );
}
|
|
|
|
#[test]
fn test_compressed_mla_cache_multiple_positions() {
    // Byte accounting scales linearly with positions:
    // 100 * (512 + 64) * 4 = 230,400 bytes.
    let mut cache = CompressedMlaCache::new();
    for pos in 0..100 {
        let latent = vec![pos as f32; 512];
        let rope = vec![(pos as f32) * 0.1; 64];
        cache.push(latent, rope);
    }
    assert_eq!(cache.len(), 100);
    assert_eq!(cache.memory_bytes(), 230_400);
}
|
|
|
|
#[test]
fn test_compressed_vs_full_kv_memory() {
    // For 1024 cached positions the full FP32 K/V cache should exceed the
    // compressed MLA latent cache by more than an order of magnitude
    // (roughly 40 MB vs 2.3 MB at default config).
    let positions = 1024;
    let cfg = BitNetModelConfig::default();

    // Full cache: every head stores K (nope + rope) and V per position.
    let full_k_dim = cfg.num_attention_heads * (cfg.qk_nope_head_dim + cfg.qk_rope_head_dim);
    let full_v_dim = cfg.num_attention_heads * cfg.v_head_dim;
    let full_total = (full_k_dim + full_v_dim) * 4 * positions;

    // Compressed cache: just the latent vector plus the RoPE slice.
    let compressed_total = (cfg.kv_lora_rank + cfg.qk_rope_head_dim) * 4 * positions;

    assert!(
        full_total > compressed_total * 10,
        "Full ({} bytes) should be >10x compressed ({} bytes)",
        full_total,
        compressed_total
    );
}
|
|
|
|
// =========================================================================
|
|
// End-to-end inference tests with synthetic model
|
|
// =========================================================================
|
|
|
|
/// Build a tiny synthetic model for E2E testing.
|
|
///
|
|
/// Config: 2 layers, hidden_size=8, vocab=16, 2 heads, 2 KV heads, GQA,
|
|
/// 2 experts (top-1), dense layer 0 + MoE layer 1, intermediate_size=4.
|
|
fn build_tiny_model() -> BitNetBackend {
|
|
let hidden = 8;
|
|
let vocab = 16;
|
|
let num_heads = 2;
|
|
let num_kv_heads = 2;
|
|
let head_dim = hidden / num_heads; // 4
|
|
let intermediate = 4;
|
|
let num_experts = 2;
|
|
|
|
// Helper: create a ternary tensor of given shape filled with +1
|
|
let make_ternary = |rows: usize, cols: usize| -> TernaryTensor {
|
|
let ternary_vals: Vec<i8> = (0..rows * cols)
|
|
.map(|i| match i % 3 {
|
|
0 => 1,
|
|
1 => -1,
|
|
_ => 0,
|
|
})
|
|
.collect();
|
|
let packed = pack_ternary(&ternary_vals);
|
|
let block_size = 256;
|
|
let blocks_per_row = (cols + block_size - 1) / block_size;
|
|
TernaryTensor {
|
|
packed_data: packed,
|
|
scales: vec![1.0; rows * blocks_per_row],
|
|
shape: (rows, cols),
|
|
block_size,
|
|
}
|
|
};
|
|
|
|
let make_expert = || ExpertWeights {
|
|
gate_proj: make_ternary(intermediate, hidden),
|
|
up_proj: make_ternary(intermediate, hidden),
|
|
down_proj: make_ternary(hidden, intermediate),
|
|
};
|
|
|
|
let make_gqa_attn = || AttentionWeights {
|
|
is_mla: false,
|
|
q_proj: make_ternary(hidden, hidden),
|
|
k_proj: make_ternary(num_kv_heads * head_dim, hidden),
|
|
v_proj: make_ternary(num_kv_heads * head_dim, hidden),
|
|
o_proj: make_ternary(hidden, hidden),
|
|
q_a: None,
|
|
q_b: None,
|
|
q_a_norm: None,
|
|
kv_a_mqa: None,
|
|
kv_a_norm: None,
|
|
k_b: None,
|
|
v_b: None,
|
|
};
|
|
|
|
// Layer 0: Dense FFN
|
|
let layer0 = TransformerLayer {
|
|
input_norm_weight: vec![1.0; hidden],
|
|
post_attn_norm_weight: vec![1.0; hidden],
|
|
attention: make_gqa_attn(),
|
|
layer_type: LayerType::Dense,
|
|
gate_weight: Vec::new(),
|
|
experts: Vec::new(),
|
|
shared_expert: None,
|
|
dense_ffn: Some(make_expert()),
|
|
};
|
|
|
|
// Layer 1: MoE with 2 experts, top-1
|
|
let layer1 = TransformerLayer {
|
|
input_norm_weight: vec![1.0; hidden],
|
|
post_attn_norm_weight: vec![1.0; hidden],
|
|
attention: make_gqa_attn(),
|
|
layer_type: LayerType::Moe,
|
|
gate_weight: vec![1.0; num_experts * hidden], // [2 experts, 8 hidden]
|
|
experts: vec![make_expert(), make_expert()],
|
|
shared_expert: None,
|
|
dense_ffn: None,
|
|
};
|
|
|
|
let config = BitNetModelConfig {
|
|
num_layers: 2,
|
|
hidden_size: hidden,
|
|
intermediate_size: intermediate,
|
|
vocab_size: vocab,
|
|
num_attention_heads: num_heads,
|
|
num_kv_heads,
|
|
num_experts,
|
|
active_experts: 1,
|
|
moe_intermediate_size: intermediate,
|
|
max_context: 64,
|
|
use_mla: false,
|
|
q_lora_rank: 0,
|
|
kv_lora_rank: 0,
|
|
qk_nope_head_dim: 0,
|
|
qk_rope_head_dim: 0,
|
|
v_head_dim: 0,
|
|
n_shared_experts: 0,
|
|
first_k_dense_replace: 1,
|
|
rope_theta: 10000.0,
|
|
routed_scaling_factor: 1.0,
|
|
};
|
|
|
|
// Build embedding table: [vocab * hidden] with simple deterministic pattern
|
|
let mut embedding = vec![0.0f32; vocab * hidden];
|
|
for tok in 0..vocab {
|
|
for d in 0..hidden {
|
|
embedding[tok * hidden + d] = ((tok * hidden + d) as f32 * 0.01).sin();
|
|
}
|
|
}
|
|
|
|
// LM head: [vocab * hidden] — simple identity-like
|
|
let mut lm_head = vec![0.0f32; vocab * hidden];
|
|
for tok in 0..vocab {
|
|
for d in 0..hidden {
|
|
lm_head[tok * hidden + d] = if d == tok % hidden { 1.0 } else { 0.0 };
|
|
}
|
|
}
|
|
|
|
let final_norm = vec![1.0; hidden];
|
|
|
|
let mut backend = BitNetBackend::new();
|
|
backend.config = Some(config.clone());
|
|
backend.embedding = embedding;
|
|
backend.lm_head = lm_head;
|
|
backend.final_norm_weight = final_norm;
|
|
backend.layers = vec![layer0, layer1];
|
|
backend.kv_caches = vec![LayerKvCache::new(), LayerKvCache::new()];
|
|
backend.mla_caches = vec![CompressedMlaCache::new(), CompressedMlaCache::new()];
|
|
backend.loaded = true;
|
|
backend.scratch.allocate(&config);
|
|
backend.build_rope_tables(
|
|
config.max_context.min(64),
|
|
hidden / num_heads,
|
|
config.rope_theta,
|
|
);
|
|
|
|
backend
|
|
}
|
|
|
|
#[test]
fn test_e2e_forward_produces_logits() {
    // A 3-token prompt must yield one finite logit per vocab entry.
    let backend = build_tiny_model();
    let logits = backend.forward(&[0, 1, 2]).unwrap();
    assert_eq!(logits.len(), 16, "Should produce vocab_size=16 logits");
    for (i, &l) in logits.iter().enumerate() {
        assert!(l.is_finite(), "Logit {} is not finite: {}", i, l);
    }
}
|
|
|
|
#[test]
fn test_e2e_forward_token_with_kv_cache() {
    let mut backend = build_tiny_model();
    backend.reset_cache();

    // Decode three tokens autoregressively; every step yields full-vocab
    // logits, and the final step's logits must all be finite.
    let mut last_logits = Vec::new();
    for (pos, tok) in [0u32, 1, 2].into_iter().enumerate() {
        last_logits = backend.forward_token(tok, pos).unwrap();
        assert_eq!(last_logits.len(), 16);
    }
    assert!(last_logits.iter().all(|l| l.is_finite()));

    // Each layer's KV cache should hold one entry per processed position.
    assert_eq!(backend.kv_caches[0].len(), 3);
    assert_eq!(backend.kv_caches[1].len(), 3);
}
|
|
|
|
#[test]
fn test_e2e_forward_deterministic() {
    // Two identical forward passes must agree elementwise — there is no
    // randomness anywhere in the pipeline.
    let backend = build_tiny_model();
    let first = backend.forward(&[3, 5, 7]).unwrap();
    let second = backend.forward(&[3, 5, 7]).unwrap();

    for (a, b) in first.iter().zip(second.iter()) {
        assert!(
            (a - b).abs() < 1e-6,
            "Forward should be deterministic: {} vs {}",
            a,
            b
        );
    }
}
|
|
|
|
#[test]
fn test_e2e_forward_different_tokens_different_logits() {
    // Distinct input tokens must not map to identical logits; measure the
    // total absolute difference between the two distributions.
    let backend = build_tiny_model();
    let logits_for_0 = backend.forward(&[0]).unwrap();
    let logits_for_1 = backend.forward(&[1]).unwrap();

    let total_diff: f32 = logits_for_0
        .iter()
        .zip(logits_for_1.iter())
        .map(|(a, b)| (a - b).abs())
        .sum();
    assert!(
        total_diff > 1e-6,
        "Different tokens should produce different logits, diff={}",
        total_diff
    );
}
|
|
|
|
#[test]
fn test_e2e_expert_predictor_builds_from_inference() {
    let mut backend = build_tiny_model();
    backend.reset_cache();

    // The predictor is rebuilt every 16 tokens, so 20 decode steps
    // guarantee at least one rebuild from accumulated routing history.
    for pos in 0..20 {
        backend.forward_token(pos as u32 % 16, pos).unwrap();
    }

    let predictor = backend
        .expert_predictor
        .as_ref()
        .expect("Expert predictor should be built after 16+ tokens");
    assert!(
        predictor.total_observations() > 0,
        "Predictor should have observations from routing history"
    );
}
|
|
|
|
#[test]
fn test_e2e_forward_token_reset_cache() {
    let mut backend = build_tiny_model();

    // First sequence fills two cache positions.
    backend.forward_token(0, 0).unwrap();
    backend.forward_token(1, 1).unwrap();
    assert_eq!(backend.kv_caches[0].len(), 2);

    // Reset empties the cache so a new sequence starts at position 0.
    backend.reset_cache();
    assert_eq!(backend.kv_caches[0].len(), 0);

    let logits = backend.forward_token(5, 0).unwrap();
    assert_eq!(logits.len(), 16);
    assert_eq!(backend.kv_caches[0].len(), 1);
}
|
|
|
|
#[test]
fn test_e2e_compressed_kv_toggle() {
    let mut backend = build_tiny_model();

    // Compressed KV is off by default and toggles round-trip cleanly.
    assert!(!backend.compressed_kv_enabled());
    backend.set_compressed_kv(true);
    assert!(backend.compressed_kv_enabled());
    backend.set_compressed_kv(false);
    assert!(!backend.compressed_kv_enabled());
}
|
|
|
|
#[test]
fn test_e2e_scratch_pool_allocated() {
    let backend = build_tiny_model();

    // build_tiny_model calls scratch.allocate(), so the pool is non-empty.
    assert!(
        backend.scratch.memory_bytes() > 0,
        "Scratch pool should be allocated"
    );

    // Buffers sized to at least hidden_size (8) and intermediate_size (4).
    assert!(backend.scratch.buf_hidden_a.len() >= 8);
    assert!(backend.scratch.buf_ffn_gate.len() >= 4);
}
|
|
|
|
// =========================================================================
|
|
// Benchmark-style performance tests
|
|
// =========================================================================
|
|
|
|
#[test]
fn test_bench_forward_token_throughput() {
    let mut backend = build_tiny_model();
    backend.reset_cache();

    // Decode 32 tokens and measure wall-clock throughput.
    let num_tokens = 32;
    let start = std::time::Instant::now();
    for pos in 0..num_tokens {
        backend.forward_token(pos as u32 % 16, pos).unwrap();
    }
    let tokens_per_sec = num_tokens as f64 / start.elapsed().as_secs_f64();

    // Smoke-test bound only: any machine (even a debug build) should
    // comfortably exceed 10 tok/s on this tiny model.
    assert!(
        tokens_per_sec > 10.0,
        "Expected >10 tok/s for tiny model, got {:.1}",
        tokens_per_sec
    );
}
|
|
|
|
#[test]
fn test_bench_tl1_gemv_dispatch_performance() {
    let backend = BitNetBackend::new();

    // 64x64 ternary weight matrix with values cycling +1, -1, 0.
    let vals: Vec<i8> = (0..64 * 64).map(|i| [1i8, -1, 0][i % 3]).collect();
    let weight = TernaryTensor {
        packed_data: pack_ternary(&vals),
        scales: vec![1.0; 64],
        shape: (64, 64),
        block_size: 256,
    };
    let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

    // Time 1000 dispatched GEMVs.
    let iters = 1000;
    let start = std::time::Instant::now();
    for _ in 0..iters {
        let _ = backend.tl1_gemv(&weight, &input, 64, 64);
    }
    let gemvs_per_sec = iters as f64 / start.elapsed().as_secs_f64();

    // Deliberately loose bound (>1K/s) so slow CI machines and debug
    // builds still pass; real hardware does far more.
    assert!(
        gemvs_per_sec > 1000.0,
        "Expected >1K GEMV/s for 64x64, got {:.1}",
        gemvs_per_sec
    );
}
|
|
|
|
#[test]
fn test_bench_rms_norm_performance() {
    // In-place RMSNorm over a realistic hidden dimension (2048).
    let weights = vec![1.0f32; 2048];
    let mut activations: Vec<f32> = (0..2048).map(|i| i as f32 * 0.001).collect();

    let iters = 10000;
    let start = std::time::Instant::now();
    for _ in 0..iters {
        rms_norm_inplace(&mut activations, &weights, 1e-6);
    }
    let norms_per_sec = iters as f64 / start.elapsed().as_secs_f64();

    assert!(
        norms_per_sec > 10000.0,
        "Expected >10K norms/s for dim=2048, got {:.1}",
        norms_per_sec
    );
}
|
|
|
|
#[test]
fn test_bench_softmax_performance() {
    // In-place softmax over a 1024-wide vector.
    let mut values: Vec<f32> = (0..1024).map(|i| i as f32 * 0.01).collect();

    let iters = 10000;
    let start = std::time::Instant::now();
    for _ in 0..iters {
        softmax_inplace(&mut values);
    }
    let ops_per_sec = iters as f64 / start.elapsed().as_secs_f64();

    assert!(
        ops_per_sec > 10000.0,
        "Expected >10K softmax/s for dim=1024, got {:.1}",
        ops_per_sec
    );
}
|
|
|
|
#[test]
fn test_bench_expert_forward_performance() {
    let backend = BitNetBackend::new();
    let config = BitNetModelConfig {
        hidden_size: 64,
        intermediate_size: 32,
        moe_intermediate_size: 32,
        ..Default::default()
    };

    // One packed 32x64 ternary block, cloned into all three projections.
    let vals: Vec<i8> = (0..32 * 64).map(|i| [1i8, -1, 0][i % 3]).collect();
    let packed = pack_ternary(&vals);
    let make_t = |rows, cols| TernaryTensor {
        packed_data: packed.clone(),
        scales: vec![1.0; rows],
        shape: (rows, cols),
        block_size: 256,
    };
    let expert = ExpertWeights {
        gate_proj: make_t(32, 64),
        up_proj: make_t(32, 64),
        down_proj: make_t(64, 32),
    };
    let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();

    // Time 500 full expert passes (gate + up + SwiGLU + down).
    let iters = 500;
    let start = std::time::Instant::now();
    for _ in 0..iters {
        let _ = backend.expert_forward(&input, &expert, &config).unwrap();
    }
    let experts_per_sec = iters as f64 / start.elapsed().as_secs_f64();

    assert!(
        experts_per_sec > 100.0,
        "Expected >100 expert_forward/s for 64→32→64, got {:.1}",
        experts_per_sec
    );
}
|
|
}
|