//! Gemma-2 Model Architecture Implementation
//!
//! Google's Gemma-2 features advanced attention mechanisms:
//! - **Logit soft-capping**: Stabilizes attention with `cap * tanh(x / cap)`
//! - **Alternating local/global attention**: Odd layers use sliding window
//! - **GeGLU activation**: Gated Linear Unit with GELU
//! - **GQA**: Grouped Query Attention for memory efficiency
//! - **Large head dimension**: 256 for improved representation
//!
//! ## Model Variants
//!
//! | Model | Hidden Size | Layers | Heads | KV Heads | Context |
//! |-------|-------------|--------|-------|----------|---------|
//! | Gemma-2-2B | 2304 | 26 | 8 | 4 | 8192 |
//! | Gemma-2-9B | 3584 | 42 | 16 | 8 | 8192 |
//! | Gemma-2-27B | 4608 | 46 | 32 | 16 | 8192 |
//!
//! ## Example
//!
//! ```rust,ignore
//! use ruvllm::backends::gemma2::{Gemma2Config, Gemma2Model};
//!
//! let config = Gemma2Config::gemma2_9b();
//! let model = Gemma2Model::new(&config)?;
//!
//! let output = model.forward(&input_ids, &positions, None)?;
//! ```

use crate::error::{Result, RuvLLMError};
use crate::kernels::rope::{precompute_rope_tables_with_config, RopeConfig, RopeTables};
use crate::kernels::{apply_rope_neon, flash_attention_neon, rms_norm_neon, AttentionConfig};

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// Soft-capping value for attention logits
pub const ATTENTION_SOFTCAP: f32 = 50.0;

/// Soft-capping value for final logits
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;

/// Gemma-2 model configuration
#[derive(Debug, Clone)]
pub struct Gemma2Config {
    /// Hidden size (embedding dimension)
    pub hidden_size: usize,
    /// Intermediate size for MLP
    pub intermediate_size: usize,
    /// Number of hidden layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of key-value heads (GQA)
    pub num_kv_heads: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Maximum position embeddings
    pub max_position_embeddings: usize,
    /// RoPE base frequency
    pub rope_theta: f32,
    /// RMS norm epsilon
    pub rms_norm_eps: f32,
    /// Sliding window size for local attention layers
    pub sliding_window: usize,
    /// Head dimension (typically 256 for Gemma-2)
    pub head_dim: usize,
    /// Query pre-attention scaling factor applied to attention scores
    pub query_pre_attn_scalar: f32,
    /// Attention logit soft-capping value
    pub attn_logit_softcapping: f32,
    /// Final logit soft-capping value
    pub final_logit_softcapping: f32,
    /// Whether to use flash attention
    pub use_flash_attention: bool,
    /// BOS token ID
    pub bos_token_id: u32,
    /// EOS token ID
    pub eos_token_id: u32,
}

impl Default for Gemma2Config {
    fn default() -> Self {
        Self::gemma2_9b()
    }
}

impl Gemma2Config {
    /// Gemma-2 2B configuration
    pub fn gemma2_2b() -> Self {
        Self {
            hidden_size: 2304,
            intermediate_size: 9216,
            num_hidden_layers: 26,
            num_attention_heads: 8,
            num_kv_heads: 4,
            vocab_size: 256000,
            max_position_embeddings: 8192,
            rope_theta: 10000.0,
            rms_norm_eps: 1e-6,
            sliding_window: 4096,
            head_dim: 256,
            query_pre_attn_scalar: 256.0_f32.sqrt().recip(),
            attn_logit_softcapping: ATTENTION_SOFTCAP,
            final_logit_softcapping: FINAL_LOGIT_SOFTCAP,
            use_flash_attention: true,
            bos_token_id: 2,
            eos_token_id: 1,
        }
    }

    /// Gemma-2 9B configuration
    pub fn gemma2_9b() -> Self {
        Self {
            hidden_size: 3584,
            intermediate_size: 14336,
            num_hidden_layers: 42,
            num_attention_heads: 16,
            num_kv_heads: 8,
            vocab_size: 256000,
            max_position_embeddings: 8192,
            rope_theta: 10000.0,
            rms_norm_eps: 1e-6,
            sliding_window: 4096,
            head_dim: 256,
            query_pre_attn_scalar: 256.0_f32.sqrt().recip(),
            attn_logit_softcapping: ATTENTION_SOFTCAP,
            final_logit_softcapping: FINAL_LOGIT_SOFTCAP,
            use_flash_attention: true,
            bos_token_id: 2,
            eos_token_id: 1,
        }
    }
    /// Gemma-2 27B configuration
    pub fn gemma2_27b() -> Self {
        Self {
            hidden_size: 4608,
            intermediate_size: 36864,
            num_hidden_layers: 46,
            num_attention_heads: 32,
            num_kv_heads: 16,
            vocab_size: 256000,
            max_position_embeddings: 8192,
            rope_theta: 10000.0,
            rms_norm_eps: 1e-6,
            sliding_window: 4096,
            head_dim: 256,
            query_pre_attn_scalar: 256.0_f32.sqrt().recip(),
            attn_logit_softcapping: ATTENTION_SOFTCAP,
            final_logit_softcapping: FINAL_LOGIT_SOFTCAP,
            use_flash_attention: true,
            bos_token_id: 2,
            eos_token_id: 1,
        }
    }

    /// Get GQA ratio
    pub fn gqa_ratio(&self) -> usize {
        self.num_attention_heads / self.num_kv_heads
    }

    /// Get the attention configuration
    pub fn attention_config(&self) -> AttentionConfig {
        AttentionConfig {
            num_heads: self.num_attention_heads,
            num_kv_heads: self.num_kv_heads,
            head_dim: self.head_dim,
            max_seq_len: self.max_position_embeddings,
            causal: true,
            scale: self.query_pre_attn_scalar,
        }
    }

    /// Get the RoPE configuration
    pub fn rope_config(&self) -> RopeConfig {
        RopeConfig {
            base: self.rope_theta,
            head_dim: self.head_dim,
            max_seq_len: self.max_position_embeddings,
            scaling_factor: 1.0,
            ntk_aware: false,
            original_max_len: self.max_position_embeddings,
        }
    }

    /// Check if a layer uses local (sliding window) attention
    ///
    /// Gemma-2 alternates: even layers = global, odd layers = local
    pub fn is_local_attention_layer(&self, layer_idx: usize) -> bool {
        layer_idx % 2 == 1
    }
}

/// Apply logit soft-capping: `cap * tanh(x / cap)`
///
/// This prevents attention scores from becoming too large,
/// improving training stability and generation quality.
///
/// # Arguments
/// * `x` - Input logits (modified in-place)
/// * `cap` - Soft-capping value (typically 50.0 for attention)
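///
/// # Example
///
/// A minimal usage sketch (`ignore`d; assumes this module is reachable at this path):
///
/// ```ignore
/// use ruvllm::backends::gemma2::logit_soft_cap;
///
/// let mut scores = vec![0.0_f32, 100.0, -100.0];
/// logit_soft_cap(&mut scores, 50.0);
/// // Every value is now strictly inside (-50.0, 50.0).
/// assert!(scores.iter().all(|v| v.abs() < 50.0));
/// ```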
#[inline(always)]
pub fn logit_soft_cap(x: &mut [f32], cap: f32) {
    #[cfg(target_arch = "aarch64")]
    unsafe {
        logit_soft_cap_neon(x, cap);
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        let inv_cap = 1.0 / cap;
        for v in x.iter_mut() {
            *v = cap * (*v * inv_cap).tanh();
        }
    }
}

/// NEON-optimized logit soft-capping
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn logit_soft_cap_neon(x: &mut [f32], cap: f32) {
    let cap_vec = vdupq_n_f32(cap);
    let inv_cap = 1.0 / cap;
    let inv_cap_vec = vdupq_n_f32(inv_cap);

    let ptr = x.as_mut_ptr();
    let len = x.len();
    let mut i = 0;

    // Process 4 elements at a time
    while i + 4 <= len {
        let v = vld1q_f32(ptr.add(i));
        // Compute v / cap
        let scaled = vmulq_f32(v, inv_cap_vec);

        // Compute tanh element-wise for accuracy
        // (tanh(x) ~ x for small x, tanh(x) ~ sign(x) for large x)
        let t0 = vgetq_lane_f32::<0>(scaled).tanh();
        let t1 = vgetq_lane_f32::<1>(scaled).tanh();
        let t2 = vgetq_lane_f32::<2>(scaled).tanh();
        let t3 = vgetq_lane_f32::<3>(scaled).tanh();

        let tanh_vec = vsetq_lane_f32::<3>(
            t3,
            vsetq_lane_f32::<2>(
                t2,
                vsetq_lane_f32::<1>(t1, vsetq_lane_f32::<0>(t0, vdupq_n_f32(0.0))),
            ),
        );

        // Multiply by cap
        let result = vmulq_f32(tanh_vec, cap_vec);
        vst1q_f32(ptr.add(i), result);

        i += 4;
    }

    // Handle remainder
    while i < len {
        x[i] = cap * (x[i] * inv_cap).tanh();
        i += 1;
    }
}

/// Gemma-2 Attention layer with soft-capping and alternating local/global attention
#[derive(Debug)]
pub struct Gemma2Attention {
    /// Query projection weights
    pub q_proj: Vec<f32>,
    /// Key projection weights
    pub k_proj: Vec<f32>,
    /// Value projection weights
    pub v_proj: Vec<f32>,
    /// Output projection weights
    pub o_proj: Vec<f32>,
    /// Configuration
    pub config: Gemma2Config,
    /// Layer index (for alternating attention)
    pub layer_idx: usize,
    /// Precomputed RoPE tables
    pub rope_tables: RopeTables,
}

impl Gemma2Attention {
    /// Create a new Gemma2Attention layer
    pub fn new(config: &Gemma2Config, layer_idx: usize) -> Self {
        let hidden_size = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let num_kv_heads = config.num_kv_heads;
        let head_dim = config.head_dim;

        Self {
            q_proj: vec![0.0; num_heads * head_dim * hidden_size],
            k_proj: vec![0.0; num_kv_heads * head_dim * hidden_size],
            v_proj: vec![0.0; num_kv_heads * head_dim * hidden_size],
            o_proj: vec![0.0; hidden_size * num_heads * head_dim],
            config: config.clone(),
            layer_idx,
            rope_tables: precompute_rope_tables_with_config(&config.rope_config()),
        }
    }

    /// Load weights
    pub fn load_weights(
        &mut self,
        q_proj: &[f32],
        k_proj: &[f32],
        v_proj: &[f32],
        o_proj: &[f32],
    ) -> Result<()> {
        if q_proj.len() != self.q_proj.len()
            || k_proj.len() != self.k_proj.len()
            || v_proj.len() != self.v_proj.len()
            || o_proj.len() != self.o_proj.len()
        {
            return Err(RuvLLMError::Model(
                "Invalid attention weight dimensions".to_string(),
            ));
        }

        self.q_proj.copy_from_slice(q_proj);
        self.k_proj.copy_from_slice(k_proj);
        self.v_proj.copy_from_slice(v_proj);
        self.o_proj.copy_from_slice(o_proj);
        Ok(())
    }

    /// Forward pass with soft-capped attention and alternating local/global attention
    pub fn forward(
        &self,
        hidden_states: &[f32],
        positions: &[usize],
        kv_cache: Option<(&mut Vec<f32>, &mut Vec<f32>)>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();
        let hidden_size = self.config.hidden_size;
        let num_heads = self.config.num_attention_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim;
        let gqa_ratio = self.config.gqa_ratio();

        if hidden_states.len() != seq_len * hidden_size {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Invalid hidden_states shape: expected {}, got {}",
                seq_len * hidden_size,
                hidden_states.len()
            )));
        }

        // Project to Q, K, V
        let mut query = self.linear_transform(
            hidden_states,
            &self.q_proj,
            hidden_size,
            num_heads * head_dim,
        );
        let mut key = self.linear_transform(
            hidden_states,
            &self.k_proj,
            hidden_size,
            num_kv_heads * head_dim,
        );
        let value = self.linear_transform(
            hidden_states,
            &self.v_proj,
            hidden_size,
            num_kv_heads * head_dim,
        );

        // Apply RoPE
        self.apply_rope(&mut query, positions, num_heads);
        self.apply_rope(&mut key, positions, num_kv_heads);

        // Handle KV cache
        let (key_states, value_states) = if let Some((k_cache, v_cache)) = kv_cache {
            k_cache.extend_from_slice(&key);
            v_cache.extend_from_slice(&value);
            (k_cache.as_slice(), v_cache.as_slice())
        } else {
            (key.as_slice(), value.as_slice())
        };

        let kv_len = key_states.len() / (num_kv_heads * head_dim);

        // Determine if this layer uses local (sliding window) or global attention
        let is_local = self.config.is_local_attention_layer(self.layer_idx);
        let effective_window = if is_local {
            Some(self.config.sliding_window)
        } else {
            None
        };

        // Compute attention with soft-capping
        let scale = self.config.query_pre_attn_scalar;
        let mut output = vec![0.0; seq_len * num_heads * head_dim];

        for h in 0..num_heads {
            let kv_head = h / gqa_ratio;

            for t in 0..seq_len {
                // Extract query for this head and position
                let q_offset = (t * num_heads + h) * head_dim;
                let q_slice = &query[q_offset..q_offset + head_dim];

                // Determine attention range based on local/global
                let (start_pos, end_pos) = if let Some(window) = effective_window {
                    let pos = positions[t];
                    let start = pos.saturating_sub(window);
                    (start, kv_len)
                } else {
                    (0, kv_len)
                };
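                // For local layers the window only bounds how far back this
                // position can look; the causal mask applied below still
                // prevents attending to positions after the current token.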
                // Extract keys and values for this KV head
                let effective_kv_len = end_pos - start_pos;
                let mut k_slice = Vec::with_capacity(effective_kv_len * head_dim);
                let mut v_slice = Vec::with_capacity(effective_kv_len * head_dim);

                for kv_t in start_pos..end_pos {
                    let kv_offset = (kv_t * num_kv_heads + kv_head) * head_dim;
                    k_slice.extend_from_slice(&key_states[kv_offset..kv_offset + head_dim]);
                    v_slice.extend_from_slice(&value_states[kv_offset..kv_offset + head_dim]);
                }

                // Compute attention scores
                let mut scores = self.compute_attention_scores(q_slice, &k_slice, scale);

                // Apply soft-capping to attention logits
                logit_soft_cap(&mut scores, self.config.attn_logit_softcapping);

                // Apply causal mask (for positions after current)
                let current_pos = positions[t];
                for (i, score) in scores.iter_mut().enumerate() {
                    let kv_pos = start_pos + i;
                    if kv_pos > current_pos {
                        *score = f32::NEG_INFINITY;
                    }
                }

                // Softmax
                let attn_weights = self.softmax(&scores);

                // Weighted sum of values
                let mut head_output = vec![0.0; head_dim];
                for (i, &weight) in attn_weights.iter().enumerate() {
                    for d in 0..head_dim {
                        head_output[d] += weight * v_slice[i * head_dim + d];
                    }
                }

                // Write output
                let out_offset = (t * num_heads + h) * head_dim;
                output[out_offset..out_offset + head_dim].copy_from_slice(&head_output);
            }
        }

        // Output projection
        let output =
            self.linear_transform(&output, &self.o_proj, num_heads * head_dim, hidden_size);

        Ok(output)
    }

    /// Compute attention scores with proper scaling
    fn compute_attention_scores(&self, query: &[f32], keys: &[f32], scale: f32) -> Vec<f32> {
        let head_dim = query.len();
        let kv_len = keys.len() / head_dim;
        let mut scores = vec![0.0; kv_len];

        for t in 0..kv_len {
            let k_offset = t * head_dim;
            let mut score = 0.0;
            for d in 0..head_dim {
                score += query[d] * keys[k_offset + d];
            }
            scores[t] = score * scale;
        }

        scores
    }

    /// Softmax normalization
    fn softmax(&self, x: &[f32]) -> Vec<f32> {
        let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_vals: Vec<f32> = x.iter().map(|&v| (v - max_val).exp()).collect();
        let sum: f32 = exp_vals.iter().sum();
        exp_vals.iter().map(|&v| v / sum).collect()
    }

    /// Apply RoPE to query or key tensors
    fn apply_rope(&self, x: &mut [f32], positions: &[usize], num_heads: usize) {
        let head_dim = self.config.head_dim;
        let seq_len = positions.len();

        for h in 0..num_heads {
            for t in 0..seq_len {
                let offset = (t * num_heads + h) * head_dim;
                let mut head_vec = x[offset..offset + head_dim].to_vec();
                apply_rope_neon(
                    &mut head_vec,
                    &[positions[t]],
                    head_dim,
                    self.config.rope_theta,
                );
                x[offset..offset + head_dim].copy_from_slice(&head_vec);
            }
        }
    }

    /// Linear transformation
    fn linear_transform(
        &self,
        input: &[f32],
        weights: &[f32],
        in_dim: usize,
        out_dim: usize,
    ) -> Vec<f32> {
        let batch_size = input.len() / in_dim;
        let mut output = vec![0.0; batch_size * out_dim];

        for b in 0..batch_size {
            for o in 0..out_dim {
                let mut sum = 0.0;
                for i in 0..in_dim {
                    sum += input[b * in_dim + i] * weights[o * in_dim + i];
                }
                output[b * out_dim + o] = sum;
            }
        }

        output
    }
}

/// Gemma-2 MLP layer with GeGLU activation
///
/// GeGLU combines gating with GELU activation:
/// ```text
/// MLP(x) = down_proj(GELU(gate_proj(x)) * up_proj(x))
/// ```
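///
/// # Example
///
/// A minimal sketch (`ignore`d; weights are zero-initialized by `new`, so the
/// output stays all zeros until real weights are loaded):
///
/// ```ignore
/// let config = Gemma2Config::gemma2_2b();
/// let mlp = Gemma2MLP::new(&config);
/// let hidden = vec![0.0_f32; config.hidden_size];
/// let out = mlp.forward(&hidden)?;
/// assert_eq!(out.len(), config.hidden_size);
/// ```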
#[derive(Debug)]
pub struct Gemma2MLP {
    /// Gate projection weights
    pub gate_proj: Vec<f32>,
    /// Up projection weights
    pub up_proj: Vec<f32>,
    /// Down projection weights
    pub down_proj: Vec<f32>,
    /// Hidden size
    pub hidden_size: usize,
    /// Intermediate size
    pub intermediate_size: usize,
}

impl Gemma2MLP {
    /// Create a new Gemma2MLP layer
    pub fn new(config: &Gemma2Config) -> Self {
        Self {
            gate_proj: vec![0.0; config.intermediate_size * config.hidden_size],
            up_proj: vec![0.0; config.intermediate_size * config.hidden_size],
            down_proj: vec![0.0; config.hidden_size * config.intermediate_size],
            hidden_size: config.hidden_size,
            intermediate_size: config.intermediate_size,
        }
    }

    /// Load weights
    pub fn load_weights(
        &mut self,
        gate_proj: &[f32],
        up_proj: &[f32],
        down_proj: &[f32],
    ) -> Result<()> {
        let gate_up_size = self.intermediate_size * self.hidden_size;
        let down_size = self.hidden_size * self.intermediate_size;

        if gate_proj.len() != gate_up_size
            || up_proj.len() != gate_up_size
            || down_proj.len() != down_size
        {
            return Err(RuvLLMError::Model(
                "Invalid MLP weight dimensions".to_string(),
            ));
        }

        self.gate_proj.copy_from_slice(gate_proj);
        self.up_proj.copy_from_slice(up_proj);
        self.down_proj.copy_from_slice(down_proj);
        Ok(())
    }

    /// Forward pass with GeGLU activation
    pub fn forward(&self, hidden_states: &[f32]) -> Result<Vec<f32>> {
        // Gate projection + GELU
        let gate = self.linear(
            hidden_states,
            &self.gate_proj,
            self.hidden_size,
            self.intermediate_size,
        );
        let gate_activated = self.gelu(&gate);

        // Up projection
        let up = self.linear(
            hidden_states,
            &self.up_proj,
            self.hidden_size,
            self.intermediate_size,
        );

        // Element-wise multiply (gating)
        let hidden: Vec<f32> = gate_activated
            .iter()
            .zip(up.iter())
            .map(|(g, u)| g * u)
            .collect();

        // Down projection
        let output = self.linear(
            &hidden,
            &self.down_proj,
            self.intermediate_size,
            self.hidden_size,
        );

        Ok(output)
    }

    /// Linear transformation
    fn linear(&self, input: &[f32], weights: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
        let batch_size = input.len() / in_dim;
        let mut output = vec![0.0; batch_size * out_dim];

        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.linear_neon(input, weights, &mut output, batch_size, in_dim, out_dim);
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            for b in 0..batch_size {
                for o in 0..out_dim {
                    let mut sum = 0.0;
                    for i in 0..in_dim {
                        sum += input[b * in_dim + i] * weights[o * in_dim + i];
                    }
                    output[b * out_dim + o] = sum;
                }
            }
        }

        output
    }

    /// NEON-optimized linear transformation
    #[cfg(target_arch = "aarch64")]
    unsafe fn linear_neon(
        &self,
        input: &[f32],
        weights: &[f32],
        output: &mut [f32],
        batch_size: usize,
        in_dim: usize,
        out_dim: usize,
    ) {
        let in_ptr: *const f32 = input.as_ptr();
        let w_ptr: *const f32 = weights.as_ptr();
        let out_ptr: *mut f32 = output.as_mut_ptr();

        for b in 0..batch_size {
            for o in 0..out_dim {
                let mut acc = vdupq_n_f32(0.0);
                let mut i = 0;

                while i + 4 <= in_dim {
                    let x = vld1q_f32(in_ptr.add(b * in_dim + i));
                    let w = vld1q_f32(w_ptr.add(o * in_dim + i));
                    acc = vfmaq_f32(acc, x, w);
                    i += 4;
                }

                let mut sum = vaddvq_f32(acc);
                while i < in_dim {
                    sum += *in_ptr.add(b * in_dim + i) * *w_ptr.add(o * in_dim + i);
                    i += 1;
                }

                *out_ptr.add(b * out_dim + o) = sum;
            }
        }
    }

    /// GELU activation: `x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
    ///
    /// Uses the vectorized NEON implementation from the activations module
    /// for ~3.2x speedup over the previous scalar-in-vector approach.
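    ///
    /// Reference values for the tanh approximation (also checked in the tests below):
    ///
    /// ```text
    /// GELU(0.0)  ≈ 0.0
    /// GELU(1.0)  ≈ 0.841
    /// GELU(-1.0) ≈ -0.159
    /// ```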
    fn gelu(&self, x: &[f32]) -> Vec<f32> {
        crate::kernels::gelu_vec(x)
    }
}

/// Gemma-2 Decoder Layer
#[derive(Debug)]
pub struct Gemma2DecoderLayer {
    /// Self attention
    pub self_attn: Gemma2Attention,
    /// MLP
    pub mlp: Gemma2MLP,
    /// Input layer norm weights
    pub input_layernorm: Vec<f32>,
    /// Post-attention layer norm weights
    pub post_attention_layernorm: Vec<f32>,
    /// Pre-feedforward layer norm weights
    pub pre_feedforward_layernorm: Vec<f32>,
    /// Post-feedforward layer norm weights
    pub post_feedforward_layernorm: Vec<f32>,
    /// RMS norm epsilon
    pub rms_norm_eps: f32,
    /// Hidden size
    pub hidden_size: usize,
}

impl Gemma2DecoderLayer {
    /// Create a new decoder layer
    pub fn new(config: &Gemma2Config, layer_idx: usize) -> Self {
        Self {
            self_attn: Gemma2Attention::new(config, layer_idx),
            mlp: Gemma2MLP::new(config),
            input_layernorm: vec![1.0; config.hidden_size],
            post_attention_layernorm: vec![1.0; config.hidden_size],
            pre_feedforward_layernorm: vec![1.0; config.hidden_size],
            post_feedforward_layernorm: vec![1.0; config.hidden_size],
            rms_norm_eps: config.rms_norm_eps,
            hidden_size: config.hidden_size,
        }
    }

    /// Forward pass
    pub fn forward(
        &self,
        hidden_states: &[f32],
        positions: &[usize],
        kv_cache: Option<(&mut Vec<f32>, &mut Vec<f32>)>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();

        // Pre-norm for attention
        let mut normed = hidden_states.to_vec();
        for t in 0..seq_len {
            let offset = t * self.hidden_size;
            let slice = &mut normed[offset..offset + self.hidden_size];
            rms_norm_neon(slice, &self.input_layernorm, self.rms_norm_eps);
        }

        // Self attention
        let attn_output = self.self_attn.forward(&normed, positions, kv_cache)?;

        // Post-attention norm
        let mut attn_normed = attn_output.clone();
        for t in 0..seq_len {
            let offset = t * self.hidden_size;
            let slice = &mut attn_normed[offset..offset + self.hidden_size];
            rms_norm_neon(slice, &self.post_attention_layernorm, self.rms_norm_eps);
        }

        // Residual connection
        let mut hidden: Vec<f32> = hidden_states
            .iter()
            .zip(attn_normed.iter())
            .map(|(h, a)| h + a)
            .collect();

        // Pre-feedforward norm
        let mut ff_normed = hidden.clone();
        for t in 0..seq_len {
            let offset = t * self.hidden_size;
            let slice = &mut ff_normed[offset..offset + self.hidden_size];
            rms_norm_neon(slice, &self.pre_feedforward_layernorm, self.rms_norm_eps);
        }

        // MLP
        let mlp_output = self.mlp.forward(&ff_normed)?;

        // Post-feedforward norm
        let mut mlp_normed = mlp_output.clone();
        for t in 0..seq_len {
            let offset = t * self.hidden_size;
            let slice = &mut mlp_normed[offset..offset + self.hidden_size];
            rms_norm_neon(slice, &self.post_feedforward_layernorm, self.rms_norm_eps);
        }

        // Residual connection
        for (h, m) in hidden.iter_mut().zip(mlp_normed.iter()) {
            *h += m;
        }

        Ok(hidden)
    }
}

/// Complete Gemma-2 Model
#[derive(Debug)]
pub struct Gemma2Model {
    /// Model configuration
    pub config: Gemma2Config,
    /// Token embeddings
    pub embed_tokens: Vec<f32>,
    /// Decoder layers
    pub layers: Vec<Gemma2DecoderLayer>,
    /// Final layer norm
    pub norm: Vec<f32>,
    /// LM head weights
    pub lm_head: Option<Vec<f32>>,
    /// Whether lm_head is tied to embeddings
    pub tie_word_embeddings: bool,
}

impl Gemma2Model {
    /// Create a new Gemma-2 model
    pub fn new(config: &Gemma2Config) -> Result<Self> {
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for i in 0..config.num_hidden_layers {
            layers.push(Gemma2DecoderLayer::new(config, i));
        }

        Ok(Self {
            config: config.clone(),
            embed_tokens: vec![0.0; config.vocab_size * config.hidden_size],
            layers,
            norm: vec![1.0; config.hidden_size],
            lm_head: None,
            tie_word_embeddings: true,
        })
    }

    /// Forward pass
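    ///
    /// `positions` holds the absolute position of each token in `input_ids`.
    /// When `kv_caches` is provided, each layer's keys and values are appended
    /// to its cache entry, so repeated calls can decode incrementally.
    ///
    /// A minimal decoding sketch (`ignore`d; assumes weights have been loaded):
    ///
    /// ```ignore
    /// let mut kv_caches: Vec<(Vec<f32>, Vec<f32>)> = Vec::new();
    /// // Prefill the prompt...
    /// let logits = model.forward(&prompt_ids, &positions, Some(&mut kv_caches))?;
    /// // ...then decode one token at a time, reusing the caches.
    /// let next = model.forward(&[next_id], &[positions.len()], Some(&mut kv_caches))?;
    /// ```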
    pub fn forward(
        &self,
        input_ids: &[u32],
        positions: &[usize],
        mut kv_caches: Option<&mut Vec<(Vec<f32>, Vec<f32>)>>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();

        if input_ids.len() != seq_len {
            return Err(RuvLLMError::InvalidOperation(format!(
                "input_ids length {} != positions length {}",
                input_ids.len(),
                seq_len
            )));
        }

        // Token embeddings (scaled by sqrt(hidden_size) for Gemma)
        let embed_scale = (self.config.hidden_size as f32).sqrt();
        let mut hidden_states = Vec::with_capacity(seq_len * self.config.hidden_size);

        for &token_id in input_ids {
            let offset = (token_id as usize) * self.config.hidden_size;
            if offset + self.config.hidden_size > self.embed_tokens.len() {
                return Err(RuvLLMError::InvalidOperation(format!(
                    "Token ID {} out of vocabulary bounds",
                    token_id
                )));
            }
            for i in 0..self.config.hidden_size {
                hidden_states.push(self.embed_tokens[offset + i] * embed_scale);
            }
        }

        // Process through decoder layers
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let kv_cache = kv_caches.as_mut().map(|caches| {
                while caches.len() <= layer_idx {
                    caches.push((Vec::new(), Vec::new()));
                }
                let (k, v) = &mut caches[layer_idx];
                (k, v)
            });

            hidden_states = layer.forward(&hidden_states, positions, kv_cache)?;
        }

        // Final norm
        for t in 0..seq_len {
            let offset = t * self.config.hidden_size;
            let slice = &mut hidden_states[offset..offset + self.config.hidden_size];
            rms_norm_neon(slice, &self.norm, self.config.rms_norm_eps);
        }

        // LM head
        let lm_weights = if self.tie_word_embeddings {
            &self.embed_tokens
        } else {
            self.lm_head
                .as_ref()
                .ok_or_else(|| RuvLLMError::InvalidOperation("No LM head weights".to_string()))?
        };

        // Compute logits with soft-capping
        let mut logits = vec![0.0; seq_len * self.config.vocab_size];
        for t in 0..seq_len {
            for v in 0..self.config.vocab_size {
                let mut sum = 0.0;
                for h in 0..self.config.hidden_size {
                    sum += hidden_states[t * self.config.hidden_size + h]
                        * lm_weights[v * self.config.hidden_size + h];
                }
                logits[t * self.config.vocab_size + v] = sum;
            }

            // Apply final logit soft-capping
            let logit_slice =
                &mut logits[t * self.config.vocab_size..(t + 1) * self.config.vocab_size];
            logit_soft_cap(logit_slice, self.config.final_logit_softcapping);
        }

        Ok(logits)
    }

    /// Generate Gemma-2 chat template format
    ///
    /// Gemma-2 uses `<start_of_turn>{role}\n{content}<end_of_turn>\n` for each
    /// message, followed by `<start_of_turn>model\n` to prompt the model's reply.
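    ///
    /// # Example
    ///
    /// A usage sketch matching the implementation below (`ignore`d in doctests):
    ///
    /// ```ignore
    /// let messages = vec![("user".to_string(), "Hello!".to_string())];
    /// let prompt = Gemma2Model::apply_chat_template(&messages);
    /// assert_eq!(
    ///     prompt,
    ///     "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n"
    /// );
    /// ```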
    pub fn apply_chat_template(messages: &[(String, String)]) -> String {
        let mut result = String::new();

        for (role, content) in messages {
            result.push_str(&format!(
                "<start_of_turn>{}\n{}<end_of_turn>\n",
                role, content
            ));
        }

        result.push_str("<start_of_turn>model\n");
        result
    }

    /// Load model weights from GGUF format
    #[cfg(feature = "candle")]
    pub fn from_gguf(_path: &std::path::Path) -> Result<Self> {
        Err(RuvLLMError::NotFound(
            "GGUF loading not yet implemented for Gemma-2".to_string(),
        ))
    }

    /// Load model weights from safetensors format
    #[cfg(feature = "candle")]
    pub fn from_safetensors(_path: &std::path::Path) -> Result<Self> {
        Err(RuvLLMError::NotFound(
            "Safetensors loading not yet implemented for Gemma-2".to_string(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gemma2_config() {
        let config = Gemma2Config::gemma2_9b();
        assert_eq!(config.hidden_size, 3584);
        assert_eq!(config.num_hidden_layers, 42);
        assert_eq!(config.head_dim, 256);
        assert_eq!(config.gqa_ratio(), 2);
    }

    #[test]
    fn test_gemma2_config_2b() {
        let config = Gemma2Config::gemma2_2b();
        assert_eq!(config.hidden_size, 2304);
        assert_eq!(config.num_hidden_layers, 26);
        assert_eq!(config.gqa_ratio(), 2);
    }

    #[test]
    fn test_local_attention_alternation() {
        let config = Gemma2Config::gemma2_9b();
        assert!(!config.is_local_attention_layer(0)); // Global
        assert!(config.is_local_attention_layer(1)); // Local
        assert!(!config.is_local_attention_layer(2)); // Global
        assert!(config.is_local_attention_layer(3)); // Local
    }

    #[test]
    fn test_logit_soft_cap() {
        let mut x = vec![0.0, 10.0, -10.0, 100.0, -100.0];
        logit_soft_cap(&mut x, 50.0);

        // cap * tanh(x / cap)
        // tanh(0) = 0
        assert!((x[0]).abs() < 1e-5);
        // tanh(10/50) ~ 0.1974, so 50 * 0.1974 ~ 9.87
        assert!((x[1] - 9.866).abs() < 0.1);
        // tanh(-10/50) ~ -0.1974
        assert!((x[2] - (-9.866)).abs() < 0.1);
        // tanh(100/50) = tanh(2) ~ 0.964, so 50 * 0.964 ~ 48.2
        assert!((x[3] - 48.2).abs() < 0.5);
        // Should be bounded by cap
        assert!(x[3].abs() < 50.0);
        assert!(x[4].abs() < 50.0);
    }

    #[test]
    fn test_gemma2_mlp_gelu() {
        let config = Gemma2Config::gemma2_2b();
        let mlp = Gemma2MLP::new(&config);

        // Test GELU activation
        let input = vec![0.0, 1.0, -1.0, 2.0];
        let output = mlp.gelu(&input);

        // GELU(0) = 0
        assert!((output[0]).abs() < 1e-5);
        // GELU(1) ~ 0.841
        assert!((output[1] - 0.841).abs() < 0.01);
        // GELU(-1) ~ -0.159
        assert!((output[2] - (-0.159)).abs() < 0.01);
    }

    #[test]
    fn test_gemma2_model_creation() {
        let config = Gemma2Config::gemma2_2b();
        let model = Gemma2Model::new(&config).unwrap();

        assert_eq!(model.layers.len(), 26);
        assert_eq!(
            model.embed_tokens.len(),
            config.vocab_size * config.hidden_size
        );
    }

    #[test]
    fn test_chat_template() {
        let messages = vec![
            ("user".to_string(), "Hello!".to_string()),
            ("model".to_string(), "Hi there!".to_string()),
            ("user".to_string(), "How are you?".to_string()),
        ];

        let template = Gemma2Model::apply_chat_template(&messages);
        assert!(template.contains("<start_of_turn>user"));
        assert!(template.contains("<start_of_turn>model"));
        assert!(template.contains("<end_of_turn>"));
        assert!(template.ends_with("<start_of_turn>model\n"));
    }

    #[test]
    fn test_attention_config() {
        let config = Gemma2Config::gemma2_9b();
        let attn_config = config.attention_config();

        assert_eq!(attn_config.num_heads, 16);
        assert_eq!(attn_config.num_kv_heads, 8);
        assert_eq!(attn_config.head_dim, 256);
        assert!(attn_config.causal);
    }
}
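// Supplementary sanity checks sketching how the configuration helpers above are
// expected to behave; they exercise only APIs defined in this module.
#[cfg(test)]
mod config_invariant_tests {
    use super::*;

    #[test]
    fn test_gemma2_27b_config() {
        let config = Gemma2Config::gemma2_27b();
        assert_eq!(config.hidden_size, 4608);
        assert_eq!(config.num_hidden_layers, 46);
        assert_eq!(config.gqa_ratio(), 2);
    }

    #[test]
    fn test_rope_config_matches_model_config() {
        let config = Gemma2Config::gemma2_9b();
        let rope = config.rope_config();
        assert_eq!(rope.head_dim, config.head_dim);
        assert_eq!(rope.max_seq_len, config.max_position_embeddings);
    }

    #[test]
    fn test_attention_rejects_mismatched_weights() {
        let config = Gemma2Config::gemma2_2b();
        let mut attn = Gemma2Attention::new(&config, 0);
        // Deliberately wrong weight sizes must be rejected by load_weights.
        let bad = vec![0.0_f32; 8];
        assert!(attn.load_weights(&bad, &bad, &bad, &bad).is_err());
    }
}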