Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
61
vendor/ruvector/crates/ruvector-attention/src/sdk/builder.rs
vendored
Normal file
61
vendor/ruvector/crates/ruvector-attention/src/sdk/builder.rs
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
//! Fluent builder API for constructing attention mechanisms.
|
||||
|
||||
use crate::{error::AttentionResult, traits::Attention};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AttentionType {
|
||||
ScaledDot,
|
||||
MultiHead,
|
||||
Flash,
|
||||
Linear,
|
||||
LocalGlobal,
|
||||
Hyperbolic,
|
||||
MoE,
|
||||
}
|
||||
|
||||
pub struct AttentionBuilder {
|
||||
dim: usize,
|
||||
attention_type: AttentionType,
|
||||
}
|
||||
|
||||
impl AttentionBuilder {
|
||||
pub fn new(dim: usize) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
attention_type: AttentionType::ScaledDot,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multi_head(mut self, _heads: usize) -> Self {
|
||||
self.attention_type = AttentionType::MultiHead;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn flash(mut self, _block: usize) -> Self {
|
||||
self.attention_type = AttentionType::Flash;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn dropout(self, _p: f32) -> Self {
|
||||
self
|
||||
}
|
||||
pub fn causal(self, _c: bool) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> AttentionResult<Box<dyn Attention + Send + Sync>> {
|
||||
Ok(Box::new(crate::attention::ScaledDotProductAttention::new(
|
||||
self.dim,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scaled_dot(dim: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim)
|
||||
}
|
||||
pub fn multi_head(dim: usize, heads: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim).multi_head(heads)
|
||||
}
|
||||
pub fn flash(dim: usize, block: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim).flash(block)
|
||||
}
|
||||
11
vendor/ruvector/crates/ruvector-attention/src/sdk/mod.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-attention/src/sdk/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
//! # ruvector-attention SDK
|
||||
//!
|
||||
//! High-level, ergonomic APIs for building attention mechanisms.
|
||||
|
||||
pub mod builder;
|
||||
pub mod pipeline;
|
||||
pub mod presets;
|
||||
|
||||
pub use builder::{flash, multi_head, scaled_dot, AttentionBuilder, AttentionType};
|
||||
pub use pipeline::{AttentionPipeline, NormType, PipelineStage};
|
||||
pub use presets::{for_graphs, for_large_scale, for_sequences, AttentionPreset};
|
||||
57
vendor/ruvector/crates/ruvector-attention/src/sdk/pipeline.rs
vendored
Normal file
57
vendor/ruvector/crates/ruvector-attention/src/sdk/pipeline.rs
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Pipeline API for chaining attention operations.
|
||||
|
||||
use crate::{error::AttentionResult, traits::Attention};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum NormType {
|
||||
LayerNorm,
|
||||
RMSNorm,
|
||||
BatchNorm,
|
||||
}
|
||||
|
||||
pub enum PipelineStage {
|
||||
Attention(Box<dyn Attention + Send + Sync>),
|
||||
Normalize(NormType),
|
||||
}
|
||||
|
||||
pub struct AttentionPipeline {
|
||||
stages: Vec<PipelineStage>,
|
||||
}
|
||||
|
||||
impl AttentionPipeline {
|
||||
pub fn new() -> Self {
|
||||
Self { stages: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn add_attention(mut self, attn: Box<dyn Attention + Send + Sync>) -> Self {
|
||||
self.stages.push(PipelineStage::Attention(attn));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_norm(mut self, norm: NormType) -> Self {
|
||||
self.stages.push(PipelineStage::Normalize(norm));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_dropout(self, _p: f32) -> Self {
|
||||
self
|
||||
}
|
||||
pub fn add_residual(self) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
&self,
|
||||
query: &[f32],
|
||||
_keys: &[&[f32]],
|
||||
_values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
Ok(query.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AttentionPipeline {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
42
vendor/ruvector/crates/ruvector-attention/src/sdk/presets.rs
vendored
Normal file
42
vendor/ruvector/crates/ruvector-attention/src/sdk/presets.rs
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
//! Pre-configured attention presets for common use cases.
|
||||
|
||||
use crate::sdk::builder::AttentionBuilder;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AttentionPreset {
|
||||
Bert,
|
||||
Gpt,
|
||||
Longformer,
|
||||
Performer,
|
||||
FlashOptimized,
|
||||
SwitchTransformer,
|
||||
HyperbolicTree,
|
||||
T5,
|
||||
ViT,
|
||||
SparseTransformer,
|
||||
}
|
||||
|
||||
impl AttentionPreset {
|
||||
pub fn builder(self, dim: usize) -> AttentionBuilder {
|
||||
match self {
|
||||
AttentionPreset::Bert => AttentionBuilder::new(dim).multi_head(12).dropout(0.1),
|
||||
AttentionPreset::Gpt => AttentionBuilder::new(dim)
|
||||
.multi_head(12)
|
||||
.causal(true)
|
||||
.dropout(0.1),
|
||||
_ => AttentionBuilder::new(dim),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_sequences(dim: usize, _max_len: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim)
|
||||
}
|
||||
|
||||
pub fn for_graphs(dim: usize, _hierarchical: bool) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim)
|
||||
}
|
||||
|
||||
pub fn for_large_scale(dim: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim).flash(128)
|
||||
}
|
||||
Reference in New Issue
Block a user