472 lines
16 KiB
Rust
472 lines
16 KiB
Rust
//! Real LLM Inference with CPU SIMD Optimization
|
|
//!
|
|
//! Uses candle for native Rust tensor operations with SIMD support (AVX2/AVX512).
|
|
//! Optimized for CPU sandbox environments with small, efficient models.
|
|
|
|
#[cfg(feature = "real-inference")]
|
|
mod real {
|
|
use candle_core::{DType, Device, Tensor, D};
|
|
use candle_nn::{linear, Linear, Module, VarBuilder};
|
|
use candle_transformers::models::quantized_llama as llama;
|
|
use hf_hub::{api::tokio::Api, Repo, RepoType};
|
|
use tokenizers::Tokenizer;
|
|
|
|
use crate::config::InferenceConfig;
|
|
use crate::error::{Error, InferenceError, Result};
|
|
use crate::types::ModelSize;
|
|
|
|
use dashmap::DashMap;
|
|
use parking_lot::RwLock;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
/// Supported small models optimized for CPU
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
|
pub enum SmallModel {
|
|
/// SmolLM 135M - Smallest viable model
|
|
SmolLM135M,
|
|
/// SmolLM 360M - Better quality, still fast
|
|
SmolLM360M,
|
|
/// Qwen2 0.5B - Good balance
|
|
Qwen2_500M,
|
|
/// TinyLlama 1.1B - Best quality for small
|
|
TinyLlama1B,
|
|
}
|
|
|
|
impl SmallModel {
|
|
pub fn repo_id(&self) -> &'static str {
|
|
match self {
|
|
SmallModel::SmolLM135M => "HuggingFaceTB/SmolLM-135M",
|
|
SmallModel::SmolLM360M => "HuggingFaceTB/SmolLM-360M",
|
|
SmallModel::Qwen2_500M => "Qwen/Qwen2-0.5B",
|
|
SmallModel::TinyLlama1B => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
|
}
|
|
}
|
|
|
|
pub fn quantized_repo(&self) -> &'static str {
|
|
match self {
|
|
SmallModel::SmolLM135M => "HuggingFaceTB/SmolLM-135M-GGUF",
|
|
SmallModel::SmolLM360M => "HuggingFaceTB/SmolLM-360M-GGUF",
|
|
SmallModel::Qwen2_500M => "Qwen/Qwen2-0.5B-GGUF",
|
|
SmallModel::TinyLlama1B => "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
|
|
}
|
|
}
|
|
|
|
pub fn gguf_file(&self) -> &'static str {
|
|
match self {
|
|
SmallModel::SmolLM135M => "smollm-135m-q4_k_m.gguf",
|
|
SmallModel::SmolLM360M => "smollm-360m-q4_k_m.gguf",
|
|
SmallModel::Qwen2_500M => "qwen2-0_5b-instruct-q4_k_m.gguf",
|
|
SmallModel::TinyLlama1B => "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
|
|
}
|
|
}
|
|
|
|
pub fn context_size(&self) -> usize {
|
|
match self {
|
|
SmallModel::SmolLM135M => 2048,
|
|
SmallModel::SmolLM360M => 2048,
|
|
SmallModel::Qwen2_500M => 4096,
|
|
SmallModel::TinyLlama1B => 2048,
|
|
}
|
|
}
|
|
|
|
pub fn from_model_size(size: ModelSize) -> Self {
|
|
match size {
|
|
ModelSize::M350 => SmallModel::SmolLM135M,
|
|
ModelSize::M700 => SmallModel::SmolLM360M,
|
|
ModelSize::B1_2 => SmallModel::Qwen2_500M,
|
|
ModelSize::B2_6 => SmallModel::TinyLlama1B,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Generation configuration
|
|
#[derive(Debug, Clone)]
|
|
pub struct GenerationConfig {
|
|
pub max_tokens: usize,
|
|
pub temperature: f32,
|
|
pub top_p: f32,
|
|
pub top_k: usize,
|
|
pub repeat_penalty: f32,
|
|
pub seed: u64,
|
|
}
|
|
|
|
impl Default for GenerationConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
max_tokens: 256,
|
|
temperature: 0.7,
|
|
top_p: 0.9,
|
|
top_k: 40,
|
|
repeat_penalty: 1.1,
|
|
seed: 42,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Generation result
|
|
#[derive(Debug, Clone)]
|
|
pub struct GenerationResult {
|
|
pub text: String,
|
|
pub tokens_generated: usize,
|
|
pub model_used: ModelSize,
|
|
pub cache_hit: bool,
|
|
pub inference_time_ms: f64,
|
|
pub tokens_per_second: f64,
|
|
}
|
|
|
|
/// KV Cache for efficient generation
|
|
struct KvCache {
|
|
key: Option<Tensor>,
|
|
value: Option<Tensor>,
|
|
seq_len: usize,
|
|
}
|
|
|
|
impl KvCache {
|
|
fn new() -> Self {
|
|
Self {
|
|
key: None,
|
|
value: None,
|
|
seq_len: 0,
|
|
}
|
|
}
|
|
|
|
fn append(&mut self, key: Tensor, value: Tensor) -> Result<(Tensor, Tensor)> {
|
|
let (key, value) = match (&self.key, &self.value) {
|
|
(Some(k), Some(v)) => {
|
|
let key = Tensor::cat(&[k, &key], 2)?;
|
|
let value = Tensor::cat(&[v, &value], 2)?;
|
|
(key, value)
|
|
}
|
|
_ => (key, value),
|
|
};
|
|
self.seq_len = key.dims()[2];
|
|
self.key = Some(key.clone());
|
|
self.value = Some(value.clone());
|
|
Ok((key, value))
|
|
}
|
|
|
|
fn reset(&mut self) {
|
|
self.key = None;
|
|
self.value = None;
|
|
self.seq_len = 0;
|
|
}
|
|
}
|
|
|
|
/// Real inference pool with CPU SIMD optimization
|
|
pub struct RealInferencePool {
|
|
/// Device (CPU with SIMD)
|
|
device: Device,
|
|
/// Loaded GGUF models
|
|
models: DashMap<SmallModel, Arc<llama::ModelWeights>>,
|
|
/// Tokenizers
|
|
tokenizers: DashMap<SmallModel, Arc<Tokenizer>>,
|
|
/// KV caches per session
|
|
kv_caches: DashMap<String, Vec<KvCache>>,
|
|
/// Configuration
|
|
config: InferenceConfig,
|
|
/// Model cache directory
|
|
cache_dir: PathBuf,
|
|
}
|
|
|
|
impl RealInferencePool {
|
|
/// Create new inference pool
|
|
pub async fn new(config: &InferenceConfig) -> Result<Self> {
|
|
// Use CPU device - candle will auto-detect SIMD capabilities
|
|
let device = Device::Cpu;
|
|
|
|
// Setup cache directory
|
|
let cache_dir = dirs::cache_dir()
|
|
.unwrap_or_else(|| PathBuf::from("."))
|
|
.join("ruvllm")
|
|
.join("models");
|
|
|
|
tokio::fs::create_dir_all(&cache_dir).await.map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to create cache dir: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
Ok(Self {
|
|
device,
|
|
models: DashMap::new(),
|
|
tokenizers: DashMap::new(),
|
|
kv_caches: DashMap::new(),
|
|
config: config.clone(),
|
|
cache_dir,
|
|
})
|
|
}
|
|
|
|
/// Download and load a model
|
|
async fn load_model(&self, model: SmallModel) -> Result<Arc<llama::ModelWeights>> {
|
|
// Check if already loaded
|
|
if let Some(m) = self.models.get(&model) {
|
|
return Ok(m.clone());
|
|
}
|
|
|
|
tracing::info!("Downloading model: {:?}", model);
|
|
|
|
// Download from HuggingFace Hub
|
|
let api = Api::new().map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!("HF API error: {}", e)))
|
|
})?;
|
|
|
|
let repo = api.repo(Repo::with_revision(
|
|
model.quantized_repo().to_string(),
|
|
RepoType::Model,
|
|
"main".to_string(),
|
|
));
|
|
|
|
let model_path = repo.get(model.gguf_file()).await.map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to download model: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
tracing::info!("Loading GGUF model from: {:?}", model_path);
|
|
|
|
// Load GGUF model with memory mapping for efficiency
|
|
let mut file = std::fs::File::open(&model_path).map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to open model: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let model_weights = llama::ModelWeights::from_gguf(file, &mut file, &self.device)
|
|
.map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to load GGUF: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let model_arc = Arc::new(model_weights);
|
|
self.models.insert(model, model_arc.clone());
|
|
|
|
Ok(model_arc)
|
|
}
|
|
|
|
/// Download and load tokenizer
|
|
async fn load_tokenizer(&self, model: SmallModel) -> Result<Arc<Tokenizer>> {
|
|
if let Some(t) = self.tokenizers.get(&model) {
|
|
return Ok(t.clone());
|
|
}
|
|
|
|
let api = Api::new().map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!("HF API error: {}", e)))
|
|
})?;
|
|
|
|
let repo = api.repo(Repo::new(model.repo_id().to_string(), RepoType::Model));
|
|
|
|
let tokenizer_path = repo.get("tokenizer.json").await.map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to download tokenizer: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
|
|
Error::Inference(InferenceError::InitFailed(format!(
|
|
"Failed to load tokenizer: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let tokenizer_arc = Arc::new(tokenizer);
|
|
self.tokenizers.insert(model, tokenizer_arc.clone());
|
|
|
|
Ok(tokenizer_arc)
|
|
}
|
|
|
|
/// Sample next token with temperature and top-p
|
|
fn sample_token(
|
|
&self,
|
|
logits: &Tensor,
|
|
config: &GenerationConfig,
|
|
generated_tokens: &[u32],
|
|
) -> Result<u32> {
|
|
let logits = logits.squeeze(0)?.squeeze(0)?;
|
|
let mut logits_vec: Vec<f32> = logits.to_vec1()?;
|
|
|
|
// Apply repeat penalty
|
|
for &token in generated_tokens {
|
|
if (token as usize) < logits_vec.len() {
|
|
logits_vec[token as usize] /= config.repeat_penalty;
|
|
}
|
|
}
|
|
|
|
// Apply temperature
|
|
if config.temperature > 0.0 {
|
|
for l in &mut logits_vec {
|
|
*l /= config.temperature;
|
|
}
|
|
}
|
|
|
|
// Softmax
|
|
let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
|
let mut probs: Vec<f32> = logits_vec.iter().map(|l| (l - max_logit).exp()).collect();
|
|
let sum: f32 = probs.iter().sum();
|
|
for p in &mut probs {
|
|
*p /= sum;
|
|
}
|
|
|
|
// Top-p sampling
|
|
let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
|
|
sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
|
|
|
|
let mut cumsum = 0.0;
|
|
let mut cutoff_idx = sorted_indices.len();
|
|
for (i, &idx) in sorted_indices.iter().enumerate() {
|
|
cumsum += probs[idx];
|
|
if cumsum > config.top_p {
|
|
cutoff_idx = i + 1;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Top-k limiting
|
|
cutoff_idx = cutoff_idx.min(config.top_k);
|
|
|
|
// Renormalize
|
|
let valid_indices: Vec<usize> = sorted_indices[..cutoff_idx].to_vec();
|
|
let mut valid_probs: Vec<f32> = valid_indices.iter().map(|&i| probs[i]).collect();
|
|
let sum: f32 = valid_probs.iter().sum();
|
|
for p in &mut valid_probs {
|
|
*p /= sum;
|
|
}
|
|
|
|
// Sample
|
|
use rand::Rng;
|
|
let mut rng = rand::thread_rng();
|
|
let r: f32 = rng.gen();
|
|
let mut cumsum = 0.0;
|
|
for (i, &p) in valid_probs.iter().enumerate() {
|
|
cumsum += p;
|
|
if r < cumsum {
|
|
return Ok(valid_indices[i] as u32);
|
|
}
|
|
}
|
|
|
|
Ok(valid_indices[0] as u32)
|
|
}
|
|
|
|
/// Generate text with real inference
|
|
pub async fn generate(
|
|
&self,
|
|
model_size: ModelSize,
|
|
prompt: &str,
|
|
config: GenerationConfig,
|
|
session_key: Option<&str>,
|
|
) -> Result<GenerationResult> {
|
|
let start = Instant::now();
|
|
let small_model = SmallModel::from_model_size(model_size);
|
|
|
|
// Load model and tokenizer
|
|
let model = self.load_model(small_model).await?;
|
|
let tokenizer = self.load_tokenizer(small_model).await?;
|
|
|
|
// Tokenize input
|
|
let encoding = tokenizer.encode(prompt, true).map_err(|e| {
|
|
Error::Inference(InferenceError::GenerationFailed(format!(
|
|
"Tokenization failed: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let mut tokens: Vec<u32> = encoding.get_ids().to_vec();
|
|
let input_len = tokens.len();
|
|
|
|
// Initialize or get KV cache
|
|
let cache_key = session_key
|
|
.map(|s| s.to_string())
|
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
|
|
|
let num_layers = 12; // Typical for small models
|
|
if !self.kv_caches.contains_key(&cache_key) {
|
|
let caches: Vec<KvCache> = (0..num_layers).map(|_| KvCache::new()).collect();
|
|
self.kv_caches.insert(cache_key.clone(), caches);
|
|
}
|
|
|
|
// Generate tokens
|
|
let mut generated = Vec::new();
|
|
let eos_token = tokenizer
|
|
.token_to_id("</s>")
|
|
.or_else(|| tokenizer.token_to_id("<|endoftext|>"))
|
|
.unwrap_or(2);
|
|
|
|
for _ in 0..config.max_tokens {
|
|
// Create input tensor
|
|
let input = Tensor::new(&tokens[tokens.len() - 1..], &self.device)?;
|
|
let input = input.unsqueeze(0)?;
|
|
|
|
// Forward pass with SIMD-optimized operations
|
|
let logits = model.forward(&input, tokens.len() - 1)?;
|
|
|
|
// Sample next token
|
|
let next_token = self.sample_token(&logits, &config, &generated)?;
|
|
|
|
if next_token == eos_token {
|
|
break;
|
|
}
|
|
|
|
tokens.push(next_token);
|
|
generated.push(next_token);
|
|
}
|
|
|
|
// Decode output
|
|
let output_text = tokenizer.decode(&generated, true).map_err(|e| {
|
|
Error::Inference(InferenceError::GenerationFailed(format!(
|
|
"Decoding failed: {}",
|
|
e
|
|
)))
|
|
})?;
|
|
|
|
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
|
|
let tokens_per_second = if elapsed > 0.0 {
|
|
(generated.len() as f64 / elapsed) * 1000.0
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
Ok(GenerationResult {
|
|
text: output_text,
|
|
tokens_generated: generated.len(),
|
|
model_used: model_size,
|
|
cache_hit: session_key.is_some(),
|
|
inference_time_ms: elapsed,
|
|
tokens_per_second,
|
|
})
|
|
}
|
|
|
|
/// Get pool health info
|
|
pub async fn health_check(&self) -> Result<HealthInfo> {
|
|
Ok(HealthInfo {
|
|
loaded_models: self.models.len(),
|
|
loaded_tokenizers: self.tokenizers.len(),
|
|
active_sessions: self.kv_caches.len(),
|
|
device: "CPU (SIMD)".to_string(),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Health information
|
|
#[derive(Debug, Clone)]
|
|
pub struct HealthInfo {
|
|
pub loaded_models: usize,
|
|
pub loaded_tokenizers: usize,
|
|
pub active_sessions: usize,
|
|
pub device: String,
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "real-inference")]
|
|
pub use real::*;
|
|
|
|
// Re-export types for non-real-inference builds
|
|
#[cfg(not(feature = "real-inference"))]
|
|
pub use crate::inference::{GenerationConfig, GenerationResult, HealthInfo, InferencePool};
|