Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
212
vendor/ruvector/crates/ruvector-attention-cli/src/commands/benchmark.rs
vendored
Normal file
212
vendor/ruvector/crates/ruvector-attention-cli/src/commands/benchmark.rs
vendored
Normal file
@@ -0,0 +1,212 @@
|
||||
use clap::Args;
|
||||
use crate::{config::Config, output::{print_benchmark_results, BenchmarkRow}};
|
||||
use ruvector_attention::{
|
||||
attention::{ScaledDotProductAttention, MultiHeadAttention},
|
||||
hyperbolic::HyperbolicAttention,
|
||||
sparse::{FlashAttention, LinearAttention},
|
||||
moe::MoEAttention,
|
||||
};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use std::time::Instant;
|
||||
|
||||
/// CLI arguments for the `benchmark` subcommand.
///
/// Options left unset (`None`) fall back to the matching `Config.benchmark`
/// values at runtime (see `run`).
#[derive(Args)]
pub struct BenchmarkArgs {
    /// Attention types to benchmark (comma-separated)
    #[arg(short, long, value_delimiter = ',')]
    attention_types: Option<Vec<String>>,

    /// Dimensions to test (comma-separated)
    #[arg(short, long, value_delimiter = ',')]
    dimensions: Option<Vec<usize>>,

    /// Number of iterations per test
    #[arg(short, long)]
    iterations: Option<usize>,

    /// Number of warmup iterations
    #[arg(short, long)]
    warmup: Option<usize>,

    /// Sequence length
    #[arg(short, long, default_value = "128")]
    seq_length: usize,

    /// Output results to file
    #[arg(short, long)]
    output: Option<std::path::PathBuf>,

    /// Output format (json, csv, table)
    #[arg(short, long, default_value = "table")]
    format: String,
}
|
||||
|
||||
pub async fn run(args: BenchmarkArgs, config: &Config) -> anyhow::Result<()> {
|
||||
let attention_types = args.attention_types.unwrap_or_else(|| {
|
||||
vec![
|
||||
"scaled_dot".to_string(),
|
||||
"multi_head".to_string(),
|
||||
"hyperbolic".to_string(),
|
||||
"flash".to_string(),
|
||||
"linear".to_string(),
|
||||
"moe".to_string(),
|
||||
]
|
||||
});
|
||||
|
||||
let dimensions = args.dimensions.unwrap_or_else(|| config.benchmark.dimensions.clone());
|
||||
let iterations = args.iterations.unwrap_or(config.benchmark.iterations);
|
||||
let warmup = args.warmup.unwrap_or(config.benchmark.warmup);
|
||||
|
||||
println!("Running benchmarks...");
|
||||
println!("Attention types: {:?}", attention_types);
|
||||
println!("Dimensions: {:?}", dimensions);
|
||||
println!("Iterations: {}, Warmup: {}", iterations, warmup);
|
||||
println!();
|
||||
|
||||
let total_tests = attention_types.len() * dimensions.len();
|
||||
let pb = ProgressBar::new(total_tests as u64);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")?
|
||||
.progress_chars("##-")
|
||||
);
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for attention_type in &attention_types {
|
||||
for &dim in &dimensions {
|
||||
pb.set_message(format!("Testing {} (dim={})", attention_type, dim));
|
||||
|
||||
let timings = benchmark_attention(
|
||||
attention_type,
|
||||
dim,
|
||||
args.seq_length,
|
||||
iterations,
|
||||
warmup
|
||||
)?;
|
||||
|
||||
let mean = timings.iter().sum::<f64>() / timings.len() as f64;
|
||||
let variance = timings.iter()
|
||||
.map(|&x| (x - mean).powi(2))
|
||||
.sum::<f64>() / timings.len() as f64;
|
||||
let std_dev = variance.sqrt();
|
||||
let throughput = 1000.0 / mean; // operations per second
|
||||
|
||||
results.push(BenchmarkRow {
|
||||
attention_type: attention_type.clone(),
|
||||
dimension: dim,
|
||||
mean_time_ms: mean,
|
||||
std_dev_ms: std_dev,
|
||||
throughput,
|
||||
});
|
||||
|
||||
pb.inc(1);
|
||||
}
|
||||
}
|
||||
|
||||
pb.finish_with_message("Benchmarks complete!");
|
||||
println!();
|
||||
|
||||
match args.format.as_str() {
|
||||
"json" => {
|
||||
let json = serde_json::to_string_pretty(&results)?;
|
||||
if let Some(path) = args.output {
|
||||
std::fs::write(path, json)?;
|
||||
} else {
|
||||
println!("{}", json);
|
||||
}
|
||||
}
|
||||
"csv" => {
|
||||
let mut csv = String::from("attention_type,dimension,mean_time_ms,std_dev_ms,throughput\n");
|
||||
for row in &results {
|
||||
csv.push_str(&format!(
|
||||
"{},{},{},{},{}\n",
|
||||
row.attention_type, row.dimension, row.mean_time_ms, row.std_dev_ms, row.throughput
|
||||
));
|
||||
}
|
||||
if let Some(path) = args.output {
|
||||
std::fs::write(path, csv)?;
|
||||
} else {
|
||||
println!("{}", csv);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
print_benchmark_results(results);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn benchmark_attention(
|
||||
attention_type: &str,
|
||||
dim: usize,
|
||||
seq_length: usize,
|
||||
iterations: usize,
|
||||
warmup: usize,
|
||||
) -> anyhow::Result<Vec<f64>> {
|
||||
// Generate random test data
|
||||
let query: Vec<Vec<f32>> = (0..seq_length)
|
||||
.map(|_| (0..dim).map(|_| rand::random::<f32>()).collect())
|
||||
.collect();
|
||||
let keys: Vec<Vec<f32>> = (0..seq_length)
|
||||
.map(|_| (0..dim).map(|_| rand::random::<f32>()).collect())
|
||||
.collect();
|
||||
let values: Vec<Vec<f32>> = (0..seq_length)
|
||||
.map(|_| (0..dim).map(|_| rand::random::<f32>()).collect())
|
||||
.collect();
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
// Warmup
|
||||
for _ in 0..warmup {
|
||||
run_attention(attention_type, dim, &query, &keys_refs, &values_refs)?;
|
||||
}
|
||||
|
||||
// Actual benchmark
|
||||
let mut timings = Vec::new();
|
||||
for _ in 0..iterations {
|
||||
let start = Instant::now();
|
||||
run_attention(attention_type, dim, &query, &keys_refs, &values_refs)?;
|
||||
let elapsed = start.elapsed();
|
||||
timings.push(elapsed.as_secs_f64() * 1000.0);
|
||||
}
|
||||
|
||||
Ok(timings)
|
||||
}
|
||||
|
||||
fn run_attention(
|
||||
attention_type: &str,
|
||||
dim: usize,
|
||||
query: &[Vec<f32>],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> anyhow::Result<Vec<Vec<f32>>> {
|
||||
match attention_type {
|
||||
"scaled_dot" => {
|
||||
let attention = ScaledDotProductAttention::new(dim, None);
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
"multi_head" => {
|
||||
let attention = MultiHeadAttention::new(dim, 8)?;
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
"hyperbolic" => {
|
||||
let attention = HyperbolicAttention::new(dim, 1.0)?;
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
"flash" => {
|
||||
let attention = FlashAttention::new(dim, 64)?;
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
"linear" => {
|
||||
let attention = LinearAttention::new(dim)?;
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
"moe" => {
|
||||
let attention = MoEAttention::new(dim, 4, 2)?;
|
||||
attention.compute(query, keys, values)
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("Unknown attention type: {}", attention_type)),
|
||||
}
|
||||
}
|
||||
183
vendor/ruvector/crates/ruvector-attention-cli/src/commands/compute.rs
vendored
Normal file
183
vendor/ruvector/crates/ruvector-attention-cli/src/commands/compute.rs
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
use clap::Args;
|
||||
use crate::{config::Config, output::{Output, OutputFormat, OutputDimensions, OutputMetadata}};
|
||||
use ruvector_attention::{
|
||||
attention::{ScaledDotProductAttention, MultiHeadAttention},
|
||||
hyperbolic::HyperbolicAttention,
|
||||
sparse::{FlashAttention, LinearAttention},
|
||||
moe::MoEAttention,
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
/// CLI arguments for the `compute` subcommand.
///
/// The variant-specific knobs (`num_heads`, `num_experts`, `top_k`,
/// `curvature`) are only consulted by the matching attention type; the rest
/// are ignored for other variants.
#[derive(Args)]
pub struct ComputeArgs {
    /// Input file (JSON/binary/msgpack)
    #[arg(short, long)]
    input: std::path::PathBuf,

    /// Output file (optional, prints to stdout if not specified)
    #[arg(short, long)]
    output: Option<std::path::PathBuf>,

    /// Attention type
    #[arg(short, long, default_value = "scaled_dot")]
    attention_type: AttentionType,

    /// Number of attention heads (for multi-head attention)
    #[arg(long, default_value = "8")]
    num_heads: usize,

    /// Number of experts (for MoE attention)
    #[arg(long, default_value = "4")]
    num_experts: usize,

    /// Top-k experts (for MoE attention)
    #[arg(long, default_value = "2")]
    top_k: usize,

    /// Curvature (for hyperbolic attention)
    #[arg(long, default_value = "1.0")]
    curvature: f32,

    /// Output format
    #[arg(short, long, default_value = "pretty")]
    format: OutputFormat,

    /// Show detailed metrics
    #[arg(short, long)]
    verbose: bool,
}
|
||||
|
||||
/// Attention variants selectable via `--attention-type` on the CLI.
///
/// `ValueEnum` lets clap parse these from kebab-case command-line strings.
#[derive(Clone, clap::ValueEnum)]
pub enum AttentionType {
    ScaledDot,
    MultiHead,
    Hyperbolic,
    Flash,
    Linear,
    MoE,
}
|
||||
|
||||
pub async fn run(args: ComputeArgs, config: &Config) -> anyhow::Result<()> {
|
||||
tracing::info!("Loading input from {:?}", args.input);
|
||||
let input_data = super::load_input(&args.input)?;
|
||||
|
||||
tracing::info!(
|
||||
"Input dimensions: query={:?}, keys={}, values={}",
|
||||
input_data.query.len(),
|
||||
input_data.keys.len(),
|
||||
input_data.values.len()
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let (result, attention_type_str) = match args.attention_type {
|
||||
AttentionType::ScaledDot => {
|
||||
tracing::info!("Computing scaled dot-product attention");
|
||||
let attention = ScaledDotProductAttention::new(input_data.dim, None);
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "ScaledDotProduct")
|
||||
}
|
||||
AttentionType::MultiHead => {
|
||||
tracing::info!("Computing multi-head attention with {} heads", args.num_heads);
|
||||
let attention = MultiHeadAttention::new(input_data.dim, args.num_heads)?;
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "MultiHead")
|
||||
}
|
||||
AttentionType::Hyperbolic => {
|
||||
tracing::info!("Computing hyperbolic attention with curvature={}", args.curvature);
|
||||
let attention = HyperbolicAttention::new(input_data.dim, args.curvature)?;
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "Hyperbolic")
|
||||
}
|
||||
AttentionType::Flash => {
|
||||
tracing::info!("Computing flash attention");
|
||||
let attention = FlashAttention::new(input_data.dim, 64)?;
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "Flash")
|
||||
}
|
||||
AttentionType::Linear => {
|
||||
tracing::info!("Computing linear attention");
|
||||
let attention = LinearAttention::new(input_data.dim)?;
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "Linear")
|
||||
}
|
||||
AttentionType::MoE => {
|
||||
tracing::info!(
|
||||
"Computing MoE attention with {} experts, top-{}",
|
||||
args.num_experts,
|
||||
args.top_k
|
||||
);
|
||||
let attention = MoEAttention::new(
|
||||
input_data.dim,
|
||||
args.num_experts,
|
||||
args.top_k
|
||||
)?;
|
||||
let result = attention.compute(
|
||||
&input_data.query,
|
||||
&input_data.keys_refs(),
|
||||
&input_data.values_refs()
|
||||
)?;
|
||||
(result, "MixtureOfExperts")
|
||||
}
|
||||
};
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
if args.verbose {
|
||||
tracing::info!("Computation completed in {:.2}ms", elapsed.as_secs_f64() * 1000.0);
|
||||
}
|
||||
|
||||
let dimensions = OutputDimensions {
|
||||
batch_size: 1,
|
||||
num_heads: args.num_heads,
|
||||
seq_length: input_data.keys.len(),
|
||||
embedding_dim: input_data.dim,
|
||||
};
|
||||
|
||||
let metadata = OutputMetadata {
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
memory_bytes: estimate_memory_usage(&result),
|
||||
num_parameters: calculate_parameters(&args, input_data.dim),
|
||||
};
|
||||
|
||||
let output = Output::new(attention_type_str, dimensions, result, metadata);
|
||||
output.write(args.output.as_deref(), args.format)?;
|
||||
|
||||
tracing::info!("Output written successfully");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Approximate the heap footprint of an attention result: total number of
/// f32 elements times the size of one f32 (Vec capacity overhead is ignored).
fn estimate_memory_usage(result: &[Vec<f32>]) -> usize {
    let total_elems: usize = result.iter().map(|row| row.len()).sum();
    total_elems * std::mem::size_of::<f32>()
}
|
||||
|
||||
fn calculate_parameters(args: &ComputeArgs, dim: usize) -> usize {
|
||||
match args.attention_type {
|
||||
AttentionType::ScaledDot => dim * dim * 3, // Q, K, V projections
|
||||
AttentionType::MultiHead => dim * dim * 3 * args.num_heads + dim * dim, // + output projection
|
||||
AttentionType::Hyperbolic => dim * dim * 3 + dim, // + curvature params
|
||||
AttentionType::Flash => dim * dim * 3,
|
||||
AttentionType::Linear => dim * dim * 2, // Linear projections
|
||||
AttentionType::MoE => dim * dim * 3 * args.num_experts + dim * args.num_experts, // + router
|
||||
}
|
||||
}
|
||||
166
vendor/ruvector/crates/ruvector-attention-cli/src/commands/convert.rs
vendored
Normal file
166
vendor/ruvector/crates/ruvector-attention-cli/src/commands/convert.rs
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
use clap::Args;
|
||||
use crate::config::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CLI arguments for the `convert` subcommand (format-to-format file
/// conversion).
#[derive(Args)]
pub struct ConvertArgs {
    /// Input file
    #[arg(short, long)]
    input: std::path::PathBuf,

    /// Output file
    #[arg(short, long)]
    output: std::path::PathBuf,

    /// Input format (auto-detect if not specified)
    #[arg(long)]
    from: Option<DataFormat>,

    /// Output format
    #[arg(long)]
    to: DataFormat,

    /// Pretty print output (for text formats)
    #[arg(short, long)]
    pretty: bool,
}
|
||||
|
||||
/// Serialization formats supported by `convert`.
///
/// Note: `Npy` is declared but both its reader and writer are stubs that
/// return "not yet implemented" errors (see `parse_npy` / `to_npy`).
#[derive(Clone, clap::ValueEnum)]
pub enum DataFormat {
    Json,
    Binary,
    MsgPack,
    Csv,
    Npy,
}
|
||||
|
||||
/// Common interchange shape for all conversion formats: a row-major 2-D
/// float payload. Rows are not required to have equal length.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Data {
    values: Vec<Vec<f32>>,
}
|
||||
|
||||
pub fn run(args: ConvertArgs, _config: &Config) -> anyhow::Result<()> {
|
||||
tracing::info!("Converting from {:?} to {:?}", args.input, args.output);
|
||||
|
||||
// Read input
|
||||
let content = std::fs::read(&args.input)?;
|
||||
let data = parse_data(&content, args.from.as_ref())?;
|
||||
|
||||
// Write output
|
||||
write_data(&args.output, &data, &args.to, args.pretty)?;
|
||||
|
||||
tracing::info!("Conversion complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_data(content: &[u8], format: Option<&DataFormat>) -> anyhow::Result<Data> {
|
||||
if let Some(fmt) = format {
|
||||
match fmt {
|
||||
DataFormat::Json => Ok(serde_json::from_slice(content)?),
|
||||
DataFormat::Binary => Ok(bincode::deserialize(content)?),
|
||||
DataFormat::MsgPack => Ok(rmp_serde::from_slice(content)?),
|
||||
DataFormat::Csv => parse_csv(content),
|
||||
DataFormat::Npy => parse_npy(content),
|
||||
}
|
||||
} else {
|
||||
// Auto-detect
|
||||
if let Ok(data) = serde_json::from_slice::<Data>(content) {
|
||||
return Ok(data);
|
||||
}
|
||||
if let Ok(data) = rmp_serde::from_slice::<Data>(content) {
|
||||
return Ok(data);
|
||||
}
|
||||
if let Ok(data) = bincode::deserialize::<Data>(content) {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Failed to auto-detect format"))
|
||||
}
|
||||
}
|
||||
|
||||
fn write_data(
|
||||
path: &std::path::Path,
|
||||
data: &Data,
|
||||
format: &DataFormat,
|
||||
pretty: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
match format {
|
||||
DataFormat::Json => {
|
||||
let content = if pretty {
|
||||
serde_json::to_string_pretty(data)?
|
||||
} else {
|
||||
serde_json::to_string(data)?
|
||||
};
|
||||
std::fs::write(path, content)?;
|
||||
}
|
||||
DataFormat::Binary => {
|
||||
let bytes = bincode::serialize(data)?;
|
||||
std::fs::write(path, bytes)?;
|
||||
}
|
||||
DataFormat::MsgPack => {
|
||||
let bytes = rmp_serde::to_vec(data)?;
|
||||
std::fs::write(path, bytes)?;
|
||||
}
|
||||
DataFormat::Csv => {
|
||||
let csv = to_csv(data)?;
|
||||
std::fs::write(path, csv)?;
|
||||
}
|
||||
DataFormat::Npy => {
|
||||
let npy = to_npy(data)?;
|
||||
std::fs::write(path, npy)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_csv(content: &[u8]) -> anyhow::Result<Data> {
|
||||
let text = std::str::from_utf8(content)?;
|
||||
let mut values = Vec::new();
|
||||
|
||||
for line in text.lines().skip(1) { // Skip header
|
||||
let row: Vec<f32> = line.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
if !row.is_empty() {
|
||||
values.push(row);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Data { values })
|
||||
}
|
||||
|
||||
/// Render `data` as CSV with a `row,col_0,col_1,…` header (width taken from
/// the first row) and a leading row-index column on every data line.
fn to_csv(data: &Data) -> anyhow::Result<String> {
    let mut out = String::new();

    // Header, sized from the first row; empty data emits no header at all.
    if let Some(first_row) = data.values.first() {
        let header: Vec<String> = std::iter::once("row".to_string())
            .chain((0..first_row.len()).map(|i| format!("col_{}", i)))
            .collect();
        out.push_str(&header.join(","));
        out.push('\n');
    }

    // Data lines: index first, then each cell.
    for (i, row) in data.values.iter().enumerate() {
        let line: Vec<String> = std::iter::once(i.to_string())
            .chain(row.iter().map(|val| val.to_string()))
            .collect();
        out.push_str(&line.join(","));
        out.push('\n');
    }

    Ok(out)
}
|
||||
|
||||
/// Stub NPY reader — always errors; a real implementation would use a
/// dedicated NPY crate.
fn parse_npy(_content: &[u8]) -> anyhow::Result<Data> {
    anyhow::bail!("NPY parsing not yet implemented")
}
|
||||
|
||||
/// Stub NPY writer — always errors; a real implementation would use a
/// dedicated NPY crate.
fn to_npy(_data: &Data) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("NPY writing not yet implemented")
}
|
||||
66
vendor/ruvector/crates/ruvector-attention-cli/src/commands/mod.rs
vendored
Normal file
66
vendor/ruvector/crates/ruvector-attention-cli/src/commands/mod.rs
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
pub mod compute;
|
||||
pub mod benchmark;
|
||||
pub mod convert;
|
||||
pub mod serve;
|
||||
pub mod repl;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Q/K/V payload shared by the `compute` and `repl` commands.
///
/// Rows are f32 vectors; `dim` is carried explicitly rather than derived
/// from the row lengths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputData {
    pub query: Vec<Vec<f32>>,
    pub keys: Vec<Vec<f32>>,
    pub values: Vec<Vec<f32>>,
    pub dim: usize,
}
|
||||
|
||||
impl InputData {
|
||||
pub fn keys_refs(&self) -> Vec<&[f32]> {
|
||||
self.keys.iter().map(|k| k.as_slice()).collect()
|
||||
}
|
||||
|
||||
pub fn values_refs(&self) -> Vec<&[f32]> {
|
||||
self.values.iter().map(|v| v.as_slice()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_input(path: &std::path::Path) -> anyhow::Result<InputData> {
|
||||
let content = std::fs::read(path)?;
|
||||
|
||||
// Try to parse as JSON first
|
||||
if let Ok(data) = serde_json::from_slice::<InputData>(&content) {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
// Try MessagePack
|
||||
if let Ok(data) = rmp_serde::from_slice::<InputData>(&content) {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
// Try bincode
|
||||
if let Ok(data) = bincode::deserialize::<InputData>(&content) {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Failed to parse input file"))
|
||||
}
|
||||
|
||||
pub fn save_output(path: &std::path::Path, data: &[Vec<f32>], format: &str) -> anyhow::Result<()> {
|
||||
match format {
|
||||
"json" => {
|
||||
let json = serde_json::to_string_pretty(data)?;
|
||||
std::fs::write(path, json)?;
|
||||
}
|
||||
"msgpack" => {
|
||||
let bytes = rmp_serde::to_vec(data)?;
|
||||
std::fs::write(path, bytes)?;
|
||||
}
|
||||
"binary" => {
|
||||
let bytes = bincode::serialize(data)?;
|
||||
std::fs::write(path, bytes)?;
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unsupported format: {}", format)),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
220
vendor/ruvector/crates/ruvector-attention-cli/src/commands/repl.rs
vendored
Normal file
220
vendor/ruvector/crates/ruvector-attention-cli/src/commands/repl.rs
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
use clap::Args;
|
||||
use crate::config::Config;
|
||||
use rustyline::{Editor, error::ReadlineError, history::DefaultHistory};
|
||||
use ruvector_attention::{
|
||||
attention::{ScaledDotProductAttention, MultiHeadAttention},
|
||||
hyperbolic::HyperbolicAttention,
|
||||
sparse::{FlashAttention, LinearAttention},
|
||||
moe::MoEAttention,
|
||||
};
|
||||
|
||||
/// CLI arguments for the `repl` subcommand.
#[derive(Args)]
pub struct ReplArgs {
    /// Initial dimension
    #[arg(short, long, default_value = "512")]
    dim: usize,
}
|
||||
|
||||
/// One parsed REPL input line (see `parse_command`).
///
/// `Unknown` carries either an unrecognized command word or a usage message
/// for a recognized command that was missing its argument.
enum Command {
    Help,
    Load(String),
    Compute(ComputeArgs),
    SetType(String),
    SetDim(usize),
    Config,
    Quit,
    Unknown(String),
}
|
||||
|
||||
/// Q/K/V payload for a single REPL `compute` invocation.
///
/// Distinct from the clap-derived `ComputeArgs` in the `compute` command
/// module; this one is internal to the REPL.
struct ComputeArgs {
    query: Vec<Vec<f32>>,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}
|
||||
|
||||
/// Mutable session state for the REPL loop.
struct ReplState {
    /// Cloned application config (source of attention defaults).
    config: Config,
    /// Current embedding dimension; updated by `dim` and by `load`.
    dim: usize,
    /// Current attention variant identifier (e.g. "scaled_dot").
    attention_type: String,
    /// Most recently loaded Q/K/V, if any (set by `load`).
    last_query: Option<Vec<Vec<f32>>>,
    last_keys: Option<Vec<Vec<f32>>>,
    last_values: Option<Vec<Vec<f32>>>,
}
|
||||
|
||||
impl ReplState {
    /// Build initial session state from the app config and starting dimension.
    /// Infallible today; the Result return leaves room for future validation.
    fn new(config: &Config, dim: usize) -> anyhow::Result<Self> {
        Ok(Self {
            config: config.clone(),
            dim,
            attention_type: config.attention.default_type.clone(),
            last_query: None,
            last_keys: None,
            last_values: None,
        })
    }

    /// Load Q/K/V from a file into the session, adopting the file's `dim`.
    fn load(&mut self, path: &str) -> anyhow::Result<()> {
        let data = super::load_input(&std::path::Path::new(path))?;
        self.last_query = Some(data.query);
        self.last_keys = Some(data.keys);
        self.last_values = Some(data.values);
        self.dim = data.dim;
        println!("Loaded data from {}", path);
        Ok(())
    }

    /// Run the currently selected attention variant on `args`' Q/K/V.
    ///
    /// NOTE(review): this uses the Q/K/V passed in `args`, not the
    /// `last_query`/`last_keys`/`last_values` stored by `load` — confirm
    /// whether `compute` was meant to consume the loaded data.
    fn compute(&self, args: &ComputeArgs) -> anyhow::Result<Vec<Vec<f32>>> {
        let keys_refs: Vec<&[f32]> = args.keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = args.values.iter().map(|v| v.as_slice()).collect();

        match self.attention_type.as_str() {
            "scaled_dot" => {
                let attention = ScaledDotProductAttention::new(self.dim, None);
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "multi_head" => {
                // Head count comes from config, unlike the compute command
                // where it is a CLI flag.
                let attention = MultiHeadAttention::new(self.dim, self.config.attention.default_heads)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "hyperbolic" => {
                let attention = HyperbolicAttention::new(self.dim, 1.0)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "flash" => {
                let attention = FlashAttention::new(self.dim, 64)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "linear" => {
                let attention = LinearAttention::new(self.dim)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "moe" => {
                // Fixed 4 experts / top-2 in the REPL (not configurable here).
                let attention = MoEAttention::new(self.dim, 4, 2)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            _ => Err(anyhow::anyhow!("Unknown attention type: {}", self.attention_type)),
        }
    }

    /// Switch the active attention variant. The string is not validated here;
    /// an invalid name surfaces as an error on the next `compute`.
    fn set_attention_type(&mut self, attention_type: String) {
        self.attention_type = attention_type;
        println!("Attention type set to: {}", self.attention_type);
    }

    /// Set the embedding dimension used by subsequent computations.
    fn set_dim(&mut self, dim: usize) {
        self.dim = dim;
        println!("Dimension set to: {}", self.dim);
    }

    /// Borrow the session's config (for the `config` REPL command).
    fn config(&self) -> &Config {
        &self.config
    }
}
|
||||
|
||||
/// Run the interactive REPL: read lines with rustyline, parse each into a
/// `Command`, and apply it against a mutable `ReplState`.
///
/// Ctrl-C / Ctrl-D (Interrupted / Eof) exit the loop cleanly; any other
/// readline error propagates to the caller.
pub async fn run(args: ReplArgs, config: &Config) -> anyhow::Result<()> {
    let mut rl = Editor::<(), DefaultHistory>::new()?;

    println!("RuVector Attention REPL v{}", env!("CARGO_PKG_VERSION"));
    println!("Type 'help' for commands, 'quit' to exit\n");

    let mut state = ReplState::new(config, args.dim)?;

    loop {
        match rl.readline("attention> ") {
            Ok(line) => {
                // Blank input is ignored entirely (not even added to history).
                if line.trim().is_empty() {
                    continue;
                }

                rl.add_history_entry(&line)?;

                match parse_command(&line) {
                    Command::Help => print_help(),
                    Command::Load(path) => {
                        // Load failures are reported but do not end the session.
                        if let Err(e) = state.load(&path) {
                            eprintln!("Error loading file: {}", e);
                        }
                    }
                    Command::Compute(args) => {
                        match state.compute(&args) {
                            Ok(result) => {
                                // Print shape plus a small preview (≤5 values).
                                println!("Result shape: {}x{}", result.len(), result.first().map(|r| r.len()).unwrap_or(0));
                                println!("First row (first 5 values): {:?}",
                                    result.first().map(|r| &r[..5.min(r.len())]));
                            }
                            Err(e) => eprintln!("Error computing attention: {}", e),
                        }
                    }
                    Command::SetType(t) => state.set_attention_type(t),
                    Command::SetDim(d) => state.set_dim(d),
                    Command::Config => println!("{:#?}", state.config()),
                    Command::Quit => break,
                    Command::Unknown(cmd) => println!("Unknown command: '{}'. Type 'help' for available commands.", cmd),
                }
            }
            // Ctrl-C or Ctrl-D ends the session normally.
            Err(ReadlineError::Interrupted) | Err(ReadlineError::Eof) => break,
            Err(err) => return Err(err.into()),
        }
    }

    println!("Goodbye!");
    Ok(())
}
|
||||
|
||||
fn parse_command(line: &str) -> Command {
|
||||
let parts: Vec<&str> = line.trim().split_whitespace().collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return Command::Unknown(String::new());
|
||||
}
|
||||
|
||||
match parts[0] {
|
||||
"help" | "h" | "?" => Command::Help,
|
||||
"load" => {
|
||||
if parts.len() > 1 {
|
||||
Command::Load(parts[1].to_string())
|
||||
} else {
|
||||
Command::Unknown("load requires a file path".to_string())
|
||||
}
|
||||
}
|
||||
"compute" => {
|
||||
// For simplicity, use random data
|
||||
let query = vec![vec![0.1, 0.2, 0.3]];
|
||||
let keys = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
|
||||
let values = vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]];
|
||||
Command::Compute(ComputeArgs { query, keys, values })
|
||||
}
|
||||
"type" => {
|
||||
if parts.len() > 1 {
|
||||
Command::SetType(parts[1].to_string())
|
||||
} else {
|
||||
Command::Unknown("type requires an attention type".to_string())
|
||||
}
|
||||
}
|
||||
"dim" => {
|
||||
if parts.len() > 1 {
|
||||
if let Ok(d) = parts[1].parse() {
|
||||
Command::SetDim(d)
|
||||
} else {
|
||||
Command::Unknown("dim requires a number".to_string())
|
||||
}
|
||||
} else {
|
||||
Command::Unknown("dim requires a dimension".to_string())
|
||||
}
|
||||
}
|
||||
"config" => Command::Config,
|
||||
"quit" | "exit" | "q" => Command::Quit,
|
||||
_ => Command::Unknown(parts[0].to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Print the REPL command reference to stdout.
/// (Keep this list in sync with `parse_command`.)
fn print_help() {
    println!("Available commands:");
    println!(" help - Show this help message");
    println!(" load <file> - Load input data from file");
    println!(" compute - Compute attention with loaded data");
    println!(" type <type> - Set attention type (scaled_dot, multi_head, hyperbolic, flash, linear, moe)");
    println!(" dim <size> - Set dimension size");
    println!(" config - Show current configuration");
    println!(" quit - Exit REPL");
}
|
||||
290
vendor/ruvector/crates/ruvector-attention-cli/src/commands/serve.rs
vendored
Normal file
290
vendor/ruvector/crates/ruvector-attention-cli/src/commands/serve.rs
vendored
Normal file
@@ -0,0 +1,290 @@
|
||||
use clap::Args;
|
||||
use crate::config::Config;
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router, Json, extract::State,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use ruvector_attention::{
|
||||
attention::{ScaledDotProductAttention, MultiHeadAttention},
|
||||
hyperbolic::HyperbolicAttention,
|
||||
sparse::{FlashAttention, LinearAttention},
|
||||
moe::MoEAttention,
|
||||
};
|
||||
|
||||
/// CLI arguments for the `serve` subcommand (HTTP API server).
#[derive(Args)]
pub struct ServeArgs {
    /// Host address
    #[arg(short = 'H', long, default_value = "0.0.0.0")]
    host: String,

    /// Port number
    #[arg(short, long, default_value = "8080")]
    port: u16,

    /// Enable CORS
    #[arg(long)]
    cors: bool,
}
|
||||
|
||||
/// Shared state handed to every axum handler (wrapped in an `Arc`).
/// Currently only carries the application config; handlers in this chunk do
/// not yet read it.
struct ServerState {
    config: Config,
}
|
||||
|
||||
/// JSON request body shared by all attention endpoints.
///
/// The optional knobs default to `None` when omitted; each endpoint reads
/// only the ones relevant to its variant and substitutes its own defaults.
#[derive(Debug, Deserialize)]
struct AttentionRequest {
    query: Vec<Vec<f32>>,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
    #[serde(default)]
    num_heads: Option<usize>,
    #[serde(default)]
    num_experts: Option<usize>,
    #[serde(default)]
    top_k: Option<usize>,
    #[serde(default)]
    curvature: Option<f32>,
}
|
||||
|
||||
/// JSON success response: the attention output plus wall time and metadata.
#[derive(Debug, Serialize)]
struct AttentionResponse {
    result: Vec<Vec<f32>>,
    compute_time_ms: f64,
    metadata: ResponseMetadata,
}
|
||||
|
||||
/// Descriptive metadata attached to each attention response.
#[derive(Debug, Serialize)]
struct ResponseMetadata {
    /// Variant label, e.g. "ScaledDotProduct" or "MultiHead(8)".
    attention_type: String,
    /// (number of query rows, embedding dimension).
    dimensions: (usize, usize),
}
|
||||
|
||||
/// JSON error body returned alongside a non-2xx status code.
#[derive(Debug, Serialize)]
struct ErrorResponse {
    error: String,
}
|
||||
|
||||
/// Run the `serve` subcommand: expose each attention variant as a POST
/// endpoint plus a `/health` probe, then block serving until shutdown.
///
/// Note: the default bind address is 0.0.0.0 (all interfaces).
pub async fn run(args: ServeArgs, config: &Config) -> anyhow::Result<()> {
    let state = Arc::new(ServerState {
        config: config.clone(),
    });

    // Handlers flash_attention / linear_attention / moe_attention /
    // batch_compute are defined later in this file.
    let mut app = Router::new()
        .route("/health", get(health))
        .route("/attention/scaled_dot", post(scaled_dot_attention))
        .route("/attention/multi_head", post(multi_head_attention))
        .route("/attention/hyperbolic", post(hyperbolic_attention))
        .route("/attention/flash", post(flash_attention))
        .route("/attention/linear", post(linear_attention))
        .route("/attention/moe", post(moe_attention))
        .route("/batch", post(batch_compute))
        .with_state(state);

    // Opt-in fully permissive CORS (any origin/method/header).
    if args.cors {
        app = app.layer(CorsLayer::permissive());
    }

    let addr = format!("{}:{}", args.host, args.port);
    tracing::info!("Starting server at http://{}", addr);

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;

    Ok(())
}
|
||||
|
||||
async fn health() -> impl IntoResponse {
|
||||
Json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /attention/scaled_dot — compute scaled dot-product attention on the
/// request's Q/K/V and return the result with timing metadata.
///
/// `error_response` (defined elsewhere in this file) maps an error string to
/// the HTTP error tuple used on failure.
async fn scaled_dot_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let start = std::time::Instant::now();

    // Dimension is inferred from the first query row; 0 if the query is empty.
    let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
    let attention = ScaledDotProductAttention::new(dim, None);

    let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
    let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();

    let result = attention.compute(&req.query, &keys_refs, &values_refs)
        .map_err(|e| error_response(e.to_string()))?;

    let elapsed = start.elapsed();

    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: elapsed.as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: "ScaledDotProduct".to_string(),
            dimensions: (req.query.len(), dim),
        },
    }))
}
|
||||
|
||||
async fn multi_head_attention(
|
||||
State(_state): State<Arc<ServerState>>,
|
||||
Json(req): Json<AttentionRequest>,
|
||||
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
|
||||
let num_heads = req.num_heads.unwrap_or(8);
|
||||
let attention = MultiHeadAttention::new(dim, num_heads)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention.compute(&req.query, &keys_refs, &values_refs)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(Json(AttentionResponse {
|
||||
result,
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
metadata: ResponseMetadata {
|
||||
attention_type: format!("MultiHead({})", num_heads),
|
||||
dimensions: (req.query.len(), dim),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
async fn hyperbolic_attention(
|
||||
State(_state): State<Arc<ServerState>>,
|
||||
Json(req): Json<AttentionRequest>,
|
||||
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
|
||||
let curvature = req.curvature.unwrap_or(1.0);
|
||||
let attention = HyperbolicAttention::new(dim, curvature)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention.compute(&req.query, &keys_refs, &values_refs)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(Json(AttentionResponse {
|
||||
result,
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
metadata: ResponseMetadata {
|
||||
attention_type: "Hyperbolic".to_string(),
|
||||
dimensions: (req.query.len(), dim),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
async fn flash_attention(
|
||||
State(_state): State<Arc<ServerState>>,
|
||||
Json(req): Json<AttentionRequest>,
|
||||
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
|
||||
let attention = FlashAttention::new(dim, 64)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention.compute(&req.query, &keys_refs, &values_refs)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(Json(AttentionResponse {
|
||||
result,
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
metadata: ResponseMetadata {
|
||||
attention_type: "Flash".to_string(),
|
||||
dimensions: (req.query.len(), dim),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
async fn linear_attention(
|
||||
State(_state): State<Arc<ServerState>>,
|
||||
Json(req): Json<AttentionRequest>,
|
||||
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
|
||||
let attention = LinearAttention::new(dim)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention.compute(&req.query, &keys_refs, &values_refs)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(Json(AttentionResponse {
|
||||
result,
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
metadata: ResponseMetadata {
|
||||
attention_type: "Linear".to_string(),
|
||||
dimensions: (req.query.len(), dim),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
async fn moe_attention(
|
||||
State(_state): State<Arc<ServerState>>,
|
||||
Json(req): Json<AttentionRequest>,
|
||||
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let dim = req.query.first().map(|q| q.len()).unwrap_or(0);
|
||||
let num_experts = req.num_experts.unwrap_or(4);
|
||||
let top_k = req.top_k.unwrap_or(2);
|
||||
let attention = MoEAttention::new(dim, num_experts, top_k)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = req.keys.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = req.values.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let result = attention.compute(&req.query, &keys_refs, &values_refs)
|
||||
.map_err(|e| error_response(e.to_string()))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(Json(AttentionResponse {
|
||||
result,
|
||||
compute_time_ms: elapsed.as_secs_f64() * 1000.0,
|
||||
metadata: ResponseMetadata {
|
||||
attention_type: format!("MoE({}/{})", top_k, num_experts),
|
||||
dimensions: (req.query.len(), dim),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /batch — batch computation endpoint.
///
/// Not implemented yet: every request receives the uniform 500 error payload.
/// The request body is accepted as arbitrary JSON so a schema can be added
/// later without breaking the route.
async fn batch_compute(
    State(_state): State<Arc<ServerState>>,
    Json(_req): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    Err(error_response("Batch compute not yet implemented".to_string()))
}
|
||||
|
||||
fn error_response(message: String) -> (StatusCode, Json<ErrorResponse>) {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse { error: message }),
|
||||
)
|
||||
}
|
||||
109
vendor/ruvector/crates/ruvector-attention-cli/src/config.rs
vendored
Normal file
109
vendor/ruvector/crates/ruvector-attention-cli/src/config.rs
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// Top-level CLI configuration, loaded from TOML via `load_config` and
/// falling back to `Config::default()` when no file is found.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub attention: AttentionSettings,
    pub server: ServerSettings,
    pub output: OutputSettings,
    pub benchmark: BenchmarkSettings,
}

/// Defaults applied when a command does not specify attention parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionSettings {
    pub default_dim: usize,
    pub default_heads: usize,
    // Name of the default attention mechanism (e.g. "scaled_dot").
    pub default_type: String,
    pub dropout: f32,
    pub max_sequence_length: usize,
}

/// HTTP server binding and request limits for the `serve` subcommand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
    pub host: String,
    pub port: u16,
    pub max_batch_size: usize,
    pub timeout_ms: u64,
    pub enable_cors: bool,
    pub enable_metrics: bool,
}

/// Formatting preferences for printed or saved results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSettings {
    // Default output format name (e.g. "json", "csv", "table").
    pub format: String,
    pub pretty: bool,
    // Decimal places used when rendering floating-point values.
    pub precision: usize,
    pub color: bool,
}

/// Iteration counts and problem sizes for the `benchmark` subcommand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSettings {
    pub iterations: usize,
    pub warmup: usize,
    pub sample_size: usize,
    pub dimensions: Vec<usize>,
}
|
||||
|
||||
impl Default for Config {
    /// Built-in defaults used when no configuration file is present.
    fn default() -> Self {
        Self {
            attention: AttentionSettings {
                default_dim: 512,
                default_heads: 8,
                default_type: "scaled_dot".to_string(),
                dropout: 0.1,
                max_sequence_length: 4096,
            },
            server: ServerSettings {
                // Bind on all interfaces by default.
                host: "0.0.0.0".to_string(),
                port: 8080,
                max_batch_size: 32,
                // 30 s request timeout.
                timeout_ms: 30000,
                enable_cors: true,
                enable_metrics: true,
            },
            output: OutputSettings {
                format: "json".to_string(),
                pretty: true,
                precision: 4,
                color: true,
            },
            benchmark: BenchmarkSettings {
                iterations: 100,
                warmup: 10,
                sample_size: 10,
                dimensions: vec![128, 256, 512, 1024],
            },
        }
    }
}
|
||||
|
||||
pub fn load_config(path: Option<&Path>) -> anyhow::Result<Config> {
|
||||
if let Some(p) = path {
|
||||
let content = std::fs::read_to_string(p)?;
|
||||
Ok(toml::from_str(&content)?)
|
||||
} else {
|
||||
// Try default locations
|
||||
let default_paths = [
|
||||
"ruvector-attention.toml",
|
||||
"config/ruvector-attention.toml",
|
||||
"~/.config/ruvector-attention.toml",
|
||||
];
|
||||
|
||||
for path in &default_paths {
|
||||
if let Ok(content) = std::fs::read_to_string(path) {
|
||||
return Ok(toml::from_str(&content)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
    /// Serializes this configuration as pretty-printed TOML and writes it to
    /// `path`, overwriting any existing file.
    ///
    /// # Errors
    /// Fails when TOML serialization or the filesystem write fails.
    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
        let content = toml::to_string_pretty(self)?;
        std::fs::write(path, content)?;
        Ok(())
    }
}
|
||||
72
vendor/ruvector/crates/ruvector-attention-cli/src/main.rs
vendored
Normal file
72
vendor/ruvector/crates/ruvector-attention-cli/src/main.rs
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
mod commands;
|
||||
mod config;
|
||||
mod output;
|
||||
mod server;
|
||||
|
||||
// Command-line interface definition. NOTE: plain `//` comments are used on
// fields because clap turns `///` doc comments into user-visible help text;
// the existing `///` strings below are part of the CLI's output and are
// preserved byte-for-byte.
#[derive(Parser)]
#[command(name = "ruvector-attention")]
#[command(author = "rUv <ruv@ruv.io>")]
#[command(version)]
#[command(about = "High-performance attention mechanisms CLI", long_about = None)]
struct Cli {
    // The subcommand to run (compute / benchmark / convert / serve / repl).
    #[command(subcommand)]
    command: Commands,

    /// Path to configuration file
    #[arg(short, long, global = true)]
    config: Option<std::path::PathBuf>,

    /// Log level (trace, debug, info, warn, error)
    #[arg(short, long, global = true, default_value = "info")]
    log_level: String,
}
|
||||
|
||||
// One variant per CLI subcommand; each wraps its own clap `Args` struct.
// The `///` comments are clap help text — part of CLI output, kept verbatim.
#[derive(Subcommand)]
enum Commands {
    /// Compute attention over input data
    Compute(commands::compute::ComputeArgs),

    /// Run performance benchmarks
    Benchmark(commands::benchmark::BenchmarkArgs),

    /// Convert between data formats
    Convert(commands::convert::ConvertArgs),

    /// Start HTTP server
    Serve(commands::serve::ServeArgs),

    /// Interactive REPL
    Repl(commands::repl::ReplArgs),
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_target(false)
|
||||
.with_thread_ids(false)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
)
|
||||
.with(tracing_subscriber::EnvFilter::new(&cli.log_level))
|
||||
.init();
|
||||
|
||||
// Load configuration
|
||||
let config = config::load_config(cli.config.as_deref())?;
|
||||
|
||||
// Execute command
|
||||
match cli.command {
|
||||
Commands::Compute(args) => commands::compute::run(args, &config).await,
|
||||
Commands::Benchmark(args) => commands::benchmark::run(args, &config).await,
|
||||
Commands::Convert(args) => commands::convert::run(args, &config),
|
||||
Commands::Serve(args) => commands::serve::run(args, &config).await,
|
||||
Commands::Repl(args) => commands::repl::run(args, &config).await,
|
||||
}
|
||||
}
|
||||
170
vendor/ruvector/crates/ruvector-attention-cli/src/output.rs
vendored
Normal file
170
vendor/ruvector/crates/ruvector-attention-cli/src/output.rs
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use tabled::{settings::Style, Table, Tabled};
|
||||
|
||||
// Supported serialization formats for results. Plain `//` comments are used
// here because `clap::ValueEnum` turns `///` doc comments into user-visible
// help text. Note: Binary and MsgPack require an output path (raw bytes
// cannot be printed to stdout — see `Output::write`).
#[derive(Clone, clap::ValueEnum)]
pub enum OutputFormat {
    Pretty,
    Json,
    Binary,
    Csv,
    MsgPack,
    Table,
}
|
||||
|
||||
/// Full result payload for one attention computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionOutput {
    pub attention_type: String,
    pub dimensions: OutputDimensions,
    pub scores: Vec<Vec<f32>>,
    pub metadata: OutputMetadata,
}

/// Shape information for the computed attention scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputDimensions {
    pub batch_size: usize,
    pub num_heads: usize,
    pub seq_length: usize,
    pub embedding_dim: usize,
}

/// Timing and resource statistics gathered during computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputMetadata {
    pub compute_time_ms: f64,
    pub memory_bytes: usize,
    pub num_parameters: usize,
}

/// One row of the benchmark summary table (rendered via `tabled`).
#[derive(Debug, Clone, Tabled)]
pub struct BenchmarkRow {
    pub attention_type: String,
    pub dimension: usize,
    pub mean_time_ms: f64,
    pub std_dev_ms: f64,
    pub throughput: f64,
}
|
||||
|
||||
/// Owns an `AttentionOutput` and knows how to render it in each
/// `OutputFormat` (see the `write` / `to_*` methods below).
pub struct Output {
    data: AttentionOutput,
}
|
||||
|
||||
impl Output {
|
||||
pub fn new(
|
||||
attention_type: impl Into<String>,
|
||||
dimensions: OutputDimensions,
|
||||
scores: Vec<Vec<f32>>,
|
||||
metadata: OutputMetadata,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: AttentionOutput {
|
||||
attention_type: attention_type.into(),
|
||||
dimensions,
|
||||
scores,
|
||||
metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&self, path: Option<&Path>, format: OutputFormat) -> Result<()> {
|
||||
let content = match format {
|
||||
OutputFormat::Pretty => self.to_pretty()?,
|
||||
OutputFormat::Json => serde_json::to_string_pretty(&self.data)?,
|
||||
OutputFormat::Binary => {
|
||||
if let Some(p) = path {
|
||||
std::fs::write(p, bincode::serialize(&self.data)?)?;
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Binary format requires output path"));
|
||||
}
|
||||
}
|
||||
OutputFormat::Csv => self.to_csv()?,
|
||||
OutputFormat::MsgPack => {
|
||||
if let Some(p) = path {
|
||||
let data = rmp_serde::to_vec(&self.data)?;
|
||||
std::fs::write(p, data)?;
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("MessagePack format requires output path"));
|
||||
}
|
||||
}
|
||||
OutputFormat::Table => self.to_table()?,
|
||||
};
|
||||
|
||||
if let Some(p) = path {
|
||||
std::fs::write(p, content)?;
|
||||
} else {
|
||||
println!("{}", content);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_pretty(&self) -> Result<String> {
|
||||
let mut output = String::new();
|
||||
output.push_str(&format!("Attention Type: {}\n", self.data.attention_type));
|
||||
output.push_str(&format!("Dimensions:\n"));
|
||||
output.push_str(&format!(" Batch Size: {}\n", self.data.dimensions.batch_size));
|
||||
output.push_str(&format!(" Num Heads: {}\n", self.data.dimensions.num_heads));
|
||||
output.push_str(&format!(" Sequence Length: {}\n", self.data.dimensions.seq_length));
|
||||
output.push_str(&format!(" Embedding Dim: {}\n", self.data.dimensions.embedding_dim));
|
||||
output.push_str(&format!("\nMetadata:\n"));
|
||||
output.push_str(&format!(" Compute Time: {:.2}ms\n", self.data.metadata.compute_time_ms));
|
||||
output.push_str(&format!(" Memory Usage: {} bytes\n", self.data.metadata.memory_bytes));
|
||||
output.push_str(&format!(" Parameters: {}\n", self.data.metadata.num_parameters));
|
||||
|
||||
if !self.data.scores.is_empty() {
|
||||
output.push_str(&format!("\nAttention Scores (first 5x5):\n"));
|
||||
for (i, row) in self.data.scores.iter().take(5).enumerate() {
|
||||
output.push_str(&format!(" Row {}: ", i));
|
||||
for val in row.iter().take(5) {
|
||||
output.push_str(&format!("{:.4} ", val));
|
||||
}
|
||||
output.push_str("\n");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn to_csv(&self) -> Result<String> {
|
||||
let mut csv = String::new();
|
||||
csv.push_str("row,col,value\n");
|
||||
|
||||
for (i, row) in self.data.scores.iter().enumerate() {
|
||||
for (j, val) in row.iter().enumerate() {
|
||||
csv.push_str(&format!("{},{},{}\n", i, j, val));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(csv)
|
||||
}
|
||||
|
||||
fn to_table(&self) -> Result<String> {
|
||||
let rows: Vec<Vec<String>> = self.data.scores.iter()
|
||||
.take(10)
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.take(10)
|
||||
.map(|v| format!("{:.4}", v))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut table_str = String::from("Attention Scores:\n");
|
||||
for row in rows {
|
||||
table_str.push_str(&row.join(" | "));
|
||||
table_str.push('\n');
|
||||
}
|
||||
|
||||
Ok(table_str)
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints benchmark rows to stdout as a `tabled` table in the "modern" style.
pub fn print_benchmark_results(results: Vec<BenchmarkRow>) {
    let table = Table::new(results)
        .with(Style::modern())
        .to_string();

    println!("{}", table);
}
|
||||
7
vendor/ruvector/crates/ruvector-attention-cli/src/server/handlers.rs
vendored
Normal file
7
vendor/ruvector/crates/ruvector-attention-cli/src/server/handlers.rs
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
// Additional server handlers can be added here
|
||||
// This module can be extended with:
|
||||
// - WebSocket support for streaming attention computation
|
||||
// - Server-sent events for progress updates
|
||||
// - Additional API endpoints for batch processing
|
||||
// - Authentication and rate limiting
|
||||
// - Metrics and monitoring endpoints
|
||||
2
vendor/ruvector/crates/ruvector-attention-cli/src/server/mod.rs
vendored
Normal file
2
vendor/ruvector/crates/ruvector-attention-cli/src/server/mod.rs
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// Server module for additional HTTP server functionality
// (currently only the placeholder `handlers` submodule).
pub mod handlers;
|
||||
Reference in New Issue
Block a user