Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
42
crates/ruvector-attention/src/sdk/presets.rs
Normal file
42
crates/ruvector-attention/src/sdk/presets.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
//! Pre-configured attention presets for common use cases.
|
||||
|
||||
use crate::sdk::builder::AttentionBuilder;
|
||||
|
||||
/// Named attention configurations that can be turned into an
/// [`AttentionBuilder`] via [`AttentionPreset::builder`].
///
/// The variants are plain tags without payload, so the type is `Copy` and can
/// be passed by value freely; it is also `Hash`, so it can be used as a key in
/// hash-based collections. What each preset actually configures is decided
/// entirely by [`AttentionPreset::builder`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AttentionPreset {
    Bert,
    Gpt,
    Longformer,
    Performer,
    FlashOptimized,
    SwitchTransformer,
    HyperbolicTree,
    T5,
    ViT,
    SparseTransformer,
}
|
||||
|
||||
impl AttentionPreset {
|
||||
pub fn builder(self, dim: usize) -> AttentionBuilder {
|
||||
match self {
|
||||
AttentionPreset::Bert => AttentionBuilder::new(dim).multi_head(12).dropout(0.1),
|
||||
AttentionPreset::Gpt => AttentionBuilder::new(dim)
|
||||
.multi_head(12)
|
||||
.causal(true)
|
||||
.dropout(0.1),
|
||||
_ => AttentionBuilder::new(dim),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_sequences(dim: usize, _max_len: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim)
|
||||
}
|
||||
|
||||
pub fn for_graphs(dim: usize, _hierarchical: bool) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim)
|
||||
}
|
||||
|
||||
pub fn for_large_scale(dim: usize) -> AttentionBuilder {
|
||||
AttentionBuilder::new(dim).flash(128)
|
||||
}
|
||||
Reference in New Issue
Block a user