Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,212 @@
use clap::Args;
use crate::{config::Config, output::{print_benchmark_results, BenchmarkRow}};
use ruvector_attention::{
attention::{ScaledDotProductAttention, MultiHeadAttention},
hyperbolic::HyperbolicAttention,
sparse::{FlashAttention, LinearAttention},
moe::MoEAttention,
};
use indicatif::{ProgressBar, ProgressStyle};
use std::time::Instant;
// CLI options for the `benchmark` subcommand.
// NOTE: the `///` doc comments below are clap help text (runtime output) and
// are intentionally left untouched; options not supplied on the command line
// fall back to `config.benchmark` inside `run`.
#[derive(Args)]
pub struct BenchmarkArgs {
    /// Attention types to benchmark (comma-separated)
    #[arg(short, long, value_delimiter = ',')]
    attention_types: Option<Vec<String>>,
    /// Dimensions to test (comma-separated)
    #[arg(short, long, value_delimiter = ',')]
    dimensions: Option<Vec<usize>>,
    /// Number of iterations per test
    #[arg(short, long)]
    iterations: Option<usize>,
    /// Number of warmup iterations
    #[arg(short, long)]
    warmup: Option<usize>,
    /// Sequence length
    #[arg(short, long, default_value = "128")]
    seq_length: usize,
    // Only consulted by the json/csv output paths; the table printer always
    // writes to stdout.
    /// Output results to file
    #[arg(short, long)]
    output: Option<std::path::PathBuf>,
    // Any string other than "json"/"csv" falls through to the table printer.
    /// Output format (json, csv, table)
    #[arg(short, long, default_value = "table")]
    format: String,
}
/// Execute the `benchmark` subcommand.
///
/// Times every requested attention implementation at every requested
/// dimension, then reports mean / standard deviation / throughput per
/// combination. Options the user did not pass fall back to
/// `config.benchmark`. Results go to stdout or `--output` as JSON, CSV,
/// or (default, and for any unrecognized format string) a table.
pub async fn run(args: BenchmarkArgs, config: &Config) -> anyhow::Result<()> {
    let attention_types = args.attention_types.unwrap_or_else(|| {
        vec![
            "scaled_dot".to_string(),
            "multi_head".to_string(),
            "hyperbolic".to_string(),
            "flash".to_string(),
            "linear".to_string(),
            "moe".to_string(),
        ]
    });
    let dimensions = args.dimensions.unwrap_or_else(|| config.benchmark.dimensions.clone());
    let iterations = args.iterations.unwrap_or(config.benchmark.iterations);
    let warmup = args.warmup.unwrap_or(config.benchmark.warmup);
    // The statistics below divide by `timings.len()`; with zero iterations
    // every row would silently become NaN, so reject that up front.
    if iterations == 0 {
        anyhow::bail!("iterations must be at least 1");
    }
    println!("Running benchmarks...");
    println!("Attention types: {:?}", attention_types);
    println!("Dimensions: {:?}", dimensions);
    println!("Iterations: {}, Warmup: {}", iterations, warmup);
    println!();
    let total_tests = attention_types.len() * dimensions.len();
    let pb = ProgressBar::new(total_tests as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")?
            .progress_chars("##-")
    );
    // One row per (attention type, dimension) pair, so the final size is known.
    let mut results = Vec::with_capacity(total_tests);
    for attention_type in &attention_types {
        for &dim in &dimensions {
            pb.set_message(format!("Testing {} (dim={})", attention_type, dim));
            let timings = benchmark_attention(
                attention_type,
                dim,
                args.seq_length,
                iterations,
                warmup
            )?;
            // Population statistics over the measured runs (milliseconds).
            let mean = timings.iter().sum::<f64>() / timings.len() as f64;
            let variance = timings.iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<f64>() / timings.len() as f64;
            let std_dev = variance.sqrt();
            let throughput = 1000.0 / mean; // operations per second
            results.push(BenchmarkRow {
                attention_type: attention_type.clone(),
                dimension: dim,
                mean_time_ms: mean,
                std_dev_ms: std_dev,
                throughput,
            });
            pb.inc(1);
        }
    }
    pb.finish_with_message("Benchmarks complete!");
    println!();
    match args.format.as_str() {
        "json" => {
            let json = serde_json::to_string_pretty(&results)?;
            if let Some(path) = args.output {
                std::fs::write(path, json)?;
            } else {
                println!("{}", json);
            }
        }
        "csv" => {
            let mut csv = String::from("attention_type,dimension,mean_time_ms,std_dev_ms,throughput\n");
            for row in &results {
                csv.push_str(&format!(
                    "{},{},{},{},{}\n",
                    row.attention_type, row.dimension, row.mean_time_ms, row.std_dev_ms, row.throughput
                ));
            }
            if let Some(path) = args.output {
                std::fs::write(path, csv)?;
            } else {
                println!("{}", csv);
            }
        }
        // Any other value (including the default "table") prints a table.
        _ => {
            print_benchmark_results(results);
        }
    }
    Ok(())
}
/// Time `iterations` runs of one attention implementation on random data.
///
/// Returns one wall-clock timing in milliseconds per measured iteration;
/// `warmup` extra runs are executed first and discarded so allocator/cache
/// effects do not skew the early samples. Each run rebuilds the attention
/// object (see `run_attention`), so construction cost is included.
fn benchmark_attention(
    attention_type: &str,
    dim: usize,
    seq_length: usize,
    iterations: usize,
    warmup: usize,
) -> anyhow::Result<Vec<f64>> {
    // Random `rows x cols` matrix; the same generator serves query, keys
    // and values (the original repeated this block three times).
    fn random_matrix(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|_| (0..cols).map(|_| rand::random::<f32>()).collect())
            .collect()
    }
    let query = random_matrix(seq_length, dim);
    let keys = random_matrix(seq_length, dim);
    let values = random_matrix(seq_length, dim);
    let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
    let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
    // Warmup runs: results intentionally discarded.
    for _ in 0..warmup {
        run_attention(attention_type, dim, &query, &keys_refs, &values_refs)?;
    }
    // Measured runs.
    let mut timings = Vec::with_capacity(iterations);
    for _ in 0..iterations {
        let start = Instant::now();
        run_attention(attention_type, dim, &query, &keys_refs, &values_refs)?;
        timings.push(start.elapsed().as_secs_f64() * 1000.0);
    }
    Ok(timings)
}
/// Construct the named attention implementation and run one forward pass.
///
/// Returns the attention output (one row per query row) or an error for an
/// unrecognized `attention_type` string. The constructor parameters below
/// are fixed benchmarking choices, not user-configurable here.
fn run_attention(
    attention_type: &str,
    dim: usize,
    query: &[Vec<f32>],
    keys: &[&[f32]],
    values: &[&[f32]],
) -> anyhow::Result<Vec<Vec<f32>>> {
    match attention_type {
        "scaled_dot" => {
            // `None` is presumably a default scale/mask option — TODO confirm
            // against ScaledDotProductAttention::new.
            let attention = ScaledDotProductAttention::new(dim, None);
            attention.compute(query, keys, values)
        }
        "multi_head" => {
            // Fixed at 8 heads for benchmarking.
            let attention = MultiHeadAttention::new(dim, 8)?;
            attention.compute(query, keys, values)
        }
        "hyperbolic" => {
            // Fixed curvature of 1.0 for benchmarking.
            let attention = HyperbolicAttention::new(dim, 1.0)?;
            attention.compute(query, keys, values)
        }
        "flash" => {
            // 64 looks like a tile/block size — TODO confirm against
            // FlashAttention::new.
            let attention = FlashAttention::new(dim, 64)?;
            attention.compute(query, keys, values)
        }
        "linear" => {
            let attention = LinearAttention::new(dim)?;
            attention.compute(query, keys, values)
        }
        "moe" => {
            // 4 experts, routing to the top 2 — fixed for benchmarking.
            let attention = MoEAttention::new(dim, 4, 2)?;
            attention.compute(query, keys, values)
        }
        _ => Err(anyhow::anyhow!("Unknown attention type: {}", attention_type)),
    }
}

View File

@@ -0,0 +1,183 @@
use clap::Args;
use crate::{config::Config, output::{Output, OutputFormat, OutputDimensions, OutputMetadata}};
use ruvector_attention::{
attention::{ScaledDotProductAttention, MultiHeadAttention},
hyperbolic::HyperbolicAttention,
sparse::{FlashAttention, LinearAttention},
moe::MoEAttention,
};
use std::time::Instant;
// CLI options for the `compute` subcommand (one-shot attention over a file).
// NOTE: the `///` doc comments are clap help text (runtime output) and are
// intentionally left untouched.
#[derive(Args)]
pub struct ComputeArgs {
    /// Input file (JSON/binary/msgpack)
    #[arg(short, long)]
    input: std::path::PathBuf,
    /// Output file (optional, prints to stdout if not specified)
    #[arg(short, long)]
    output: Option<std::path::PathBuf>,
    /// Attention type
    #[arg(short, long, default_value = "scaled_dot")]
    attention_type: AttentionType,
    // The three options below only apply to their respective attention
    // types; they are ignored otherwise.
    /// Number of attention heads (for multi-head attention)
    #[arg(long, default_value = "8")]
    num_heads: usize,
    /// Number of experts (for MoE attention)
    #[arg(long, default_value = "4")]
    num_experts: usize,
    /// Top-k experts (for MoE attention)
    #[arg(long, default_value = "2")]
    top_k: usize,
    /// Curvature (for hyperbolic attention)
    #[arg(long, default_value = "1.0")]
    curvature: f32,
    /// Output format
    #[arg(short, long, default_value = "pretty")]
    format: OutputFormat,
    /// Show detailed metrics
    #[arg(short, long)]
    verbose: bool,
}
// Attention implementations selectable via `--attention-type`. clap's
// ValueEnum maps each variant to its kebab-case CLI name (e.g. `scaled-dot`).
#[derive(Clone, clap::ValueEnum)]
pub enum AttentionType {
    ScaledDot,
    MultiHead,
    Hyperbolic,
    Flash,
    Linear,
    MoE,
}
/// Execute the `compute` subcommand: load input vectors from `--input`,
/// apply the selected attention implementation, and write the result plus
/// timing/size metadata in the chosen output format.
pub async fn run(args: ComputeArgs, config: &Config) -> anyhow::Result<()> {
    let _ = config; // not consulted here yet; kept for parity with other subcommands
    tracing::info!("Loading input from {:?}", args.input);
    let input_data = super::load_input(&args.input)?;
    tracing::info!(
        "Input dimensions: query={:?}, keys={}, values={}",
        input_data.query.len(),
        input_data.keys.len(),
        input_data.values.len()
    );
    // The borrowed key/value views are identical for every attention flavor,
    // so build them once instead of once per match arm (the original
    // duplicated these two calls in all six arms).
    let key_views = input_data.keys_refs();
    let value_views = input_data.values_refs();
    let start = Instant::now();
    let (result, attention_type_str) = match args.attention_type {
        AttentionType::ScaledDot => {
            tracing::info!("Computing scaled dot-product attention");
            let attention = ScaledDotProductAttention::new(input_data.dim, None);
            (attention.compute(&input_data.query, &key_views, &value_views)?, "ScaledDotProduct")
        }
        AttentionType::MultiHead => {
            tracing::info!("Computing multi-head attention with {} heads", args.num_heads);
            let attention = MultiHeadAttention::new(input_data.dim, args.num_heads)?;
            (attention.compute(&input_data.query, &key_views, &value_views)?, "MultiHead")
        }
        AttentionType::Hyperbolic => {
            tracing::info!("Computing hyperbolic attention with curvature={}", args.curvature);
            let attention = HyperbolicAttention::new(input_data.dim, args.curvature)?;
            (attention.compute(&input_data.query, &key_views, &value_views)?, "Hyperbolic")
        }
        AttentionType::Flash => {
            tracing::info!("Computing flash attention");
            // Block size 64 mirrors the benchmark subcommand's fixed choice.
            let attention = FlashAttention::new(input_data.dim, 64)?;
            (attention.compute(&input_data.query, &key_views, &value_views)?, "Flash")
        }
        AttentionType::Linear => {
            tracing::info!("Computing linear attention");
            let attention = LinearAttention::new(input_data.dim)?;
            (attention.compute(&input_data.query, &key_views, &value_views)?, "Linear")
        }
        AttentionType::MoE => {
            tracing::info!(
                "Computing MoE attention with {} experts, top-{}",
                args.num_experts,
                args.top_k
            );
            let attention = MoEAttention::new(
                input_data.dim,
                args.num_experts,
                args.top_k
            )?;
            (attention.compute(&input_data.query, &key_views, &value_views)?, "MixtureOfExperts")
        }
    };
    let elapsed = start.elapsed();
    if args.verbose {
        tracing::info!("Computation completed in {:.2}ms", elapsed.as_secs_f64() * 1000.0);
    }
    // NOTE(review): `num_heads` is reported for every attention type, not
    // only multi-head — confirm downstream consumers expect that.
    let dimensions = OutputDimensions {
        batch_size: 1,
        num_heads: args.num_heads,
        seq_length: input_data.keys.len(),
        embedding_dim: input_data.dim,
    };
    let metadata = OutputMetadata {
        compute_time_ms: elapsed.as_secs_f64() * 1000.0,
        memory_bytes: estimate_memory_usage(&result),
        num_parameters: calculate_parameters(&args, input_data.dim),
    };
    let output = Output::new(attention_type_str, dimensions, result, metadata);
    output.write(args.output.as_deref(), args.format)?;
    tracing::info!("Output written successfully");
    Ok(())
}
/// Approximate heap footprint of the attention output in bytes.
///
/// Counts only the f32 payload of each row; Vec headers and any unused
/// capacity are not included.
fn estimate_memory_usage(result: &[Vec<f32>]) -> usize {
    let mut bytes = 0usize;
    for row in result {
        bytes += row.len() * std::mem::size_of::<f32>();
    }
    bytes
}
/// Rough parameter-count estimate for the selected attention type, used
/// only for the `num_parameters` metadata field.
///
/// NOTE(review): these formulas are heuristics, not derived from the actual
/// ruvector_attention implementations — e.g. standard multi-head attention
/// has dim*dim*3 projection weights total regardless of head count, not per
/// head. Verify against the library before relying on these numbers.
fn calculate_parameters(args: &ComputeArgs, dim: usize) -> usize {
    match args.attention_type {
        AttentionType::ScaledDot => dim * dim * 3, // Q, K, V projections
        AttentionType::MultiHead => dim * dim * 3 * args.num_heads + dim * dim, // + output projection
        AttentionType::Hyperbolic => dim * dim * 3 + dim, // + curvature params
        AttentionType::Flash => dim * dim * 3,
        AttentionType::Linear => dim * dim * 2, // Linear projections
        AttentionType::MoE => dim * dim * 3 * args.num_experts + dim * args.num_experts, // + router
    }
}

View File

@@ -0,0 +1,166 @@
use clap::Args;
use crate::config::Config;
use serde::{Deserialize, Serialize};
// CLI options for the `convert` subcommand (re-encode a data file between
// formats). NOTE: the `///` doc comments are clap help text (runtime
// output) and are intentionally left untouched.
#[derive(Args)]
pub struct ConvertArgs {
    /// Input file
    #[arg(short, long)]
    input: std::path::PathBuf,
    /// Output file
    #[arg(short, long)]
    output: std::path::PathBuf,
    // Auto-detection only tries JSON/MessagePack/bincode; CSV and NPY must
    // be requested explicitly via --from.
    /// Input format (auto-detect if not specified)
    #[arg(long)]
    from: Option<DataFormat>,
    /// Output format
    #[arg(long)]
    to: DataFormat,
    // Only honored by the JSON writer; other formats ignore it.
    /// Pretty print output (for text formats)
    #[arg(short, long)]
    pretty: bool,
}
// Supported on-disk encodings for `convert`. clap's ValueEnum maps each
// variant to its kebab-case CLI name. NPY is declared but both its reader
// and writer are currently unimplemented stubs.
#[derive(Clone, clap::ValueEnum)]
pub enum DataFormat {
    Json,
    Binary,
    MsgPack,
    Csv,
    Npy,
}
// Interchange payload for `convert`: a 2-D matrix of f32 values. Rows are
// not required to be equal length by this type itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Data {
    values: Vec<Vec<f32>>,
}
/// Entry point for the `convert` subcommand: read `--input`, decode it
/// (auto-detecting the format unless `--from` is given), and re-encode it
/// to `--output` in the `--to` format.
pub fn run(args: ConvertArgs, _config: &Config) -> anyhow::Result<()> {
    tracing::info!("Converting from {:?} to {:?}", args.input, args.output);
    let raw = std::fs::read(&args.input)?;
    let parsed = parse_data(&raw, args.from.as_ref())?;
    write_data(&args.output, &parsed, &args.to, args.pretty)?;
    tracing::info!("Conversion complete");
    Ok(())
}
/// Decode raw bytes into `Data`.
///
/// With an explicit `format` the matching decoder is used directly; CSV and
/// NPY are only reachable this way. Without one, JSON, MessagePack and
/// bincode are tried in that order — bincode last, since it accepts almost
/// any byte stream and would otherwise shadow the text formats.
fn parse_data(content: &[u8], format: Option<&DataFormat>) -> anyhow::Result<Data> {
    let Some(fmt) = format else {
        // Auto-detect: the first decoder that succeeds wins.
        return serde_json::from_slice::<Data>(content)
            .ok()
            .or_else(|| rmp_serde::from_slice::<Data>(content).ok())
            .or_else(|| bincode::deserialize::<Data>(content).ok())
            .ok_or_else(|| anyhow::anyhow!("Failed to auto-detect format"));
    };
    match fmt {
        DataFormat::Json => Ok(serde_json::from_slice(content)?),
        DataFormat::Binary => Ok(bincode::deserialize(content)?),
        DataFormat::MsgPack => Ok(rmp_serde::from_slice(content)?),
        DataFormat::Csv => parse_csv(content),
        DataFormat::Npy => parse_npy(content),
    }
}
/// Encode `data` in the requested `format` and write it to `path`.
///
/// `pretty` only affects the JSON encoder; every other format ignores it.
/// NPY currently fails because `to_npy` is a stub.
fn write_data(
    path: &std::path::Path,
    data: &Data,
    format: &DataFormat,
    pretty: bool,
) -> anyhow::Result<()> {
    // Encode fully in memory first, then write with a single fs call.
    let encoded: Vec<u8> = match format {
        DataFormat::Json if pretty => serde_json::to_string_pretty(data)?.into_bytes(),
        DataFormat::Json => serde_json::to_string(data)?.into_bytes(),
        DataFormat::Binary => bincode::serialize(data)?,
        DataFormat::MsgPack => rmp_serde::to_vec(data)?,
        DataFormat::Csv => to_csv(data)?.into_bytes(),
        DataFormat::Npy => to_npy(data)?,
    };
    std::fs::write(path, encoded)?;
    Ok(())
}
/// Parse CSV bytes into `Data`, skipping the header line.
///
/// Fix over the original: `to_csv` writes a synthetic leading `row` index
/// column, which the old parser ingested as data, so csv -> csv round trips
/// grew an extra column. When the header's first field is `row` (i.e. the
/// file matches `to_csv`'s layout) that index column is now stripped.
/// Cells that fail to parse as f32 are still silently dropped, matching the
/// original best-effort behavior for foreign CSVs.
fn parse_csv(content: &[u8]) -> anyhow::Result<Data> {
    let text = std::str::from_utf8(content)?;
    let mut lines = text.lines();
    // Inspect (and consume) the header to detect our own layout.
    let skip_index_column = lines
        .next()
        .and_then(|header| header.split(',').next())
        .map(|first| first.trim() == "row")
        .unwrap_or(false);
    let mut values = Vec::new();
    for line in lines {
        let fields = line
            .split(',')
            .skip(if skip_index_column { 1 } else { 0 });
        let row: Vec<f32> = fields.filter_map(|s| s.trim().parse().ok()).collect();
        if !row.is_empty() {
            values.push(row);
        }
    }
    Ok(Data { values })
}
/// Serialize `Data` as CSV with a `row,col_0,...` header and a leading
/// row-index column. The header width is taken from the first row.
/// Always succeeds in practice; the `Result` mirrors the other writers.
fn to_csv(data: &Data) -> anyhow::Result<String> {
    use std::fmt::Write;
    let mut out = String::new();
    if let Some(first_row) = data.values.first() {
        out.push_str("row");
        for col in 0..first_row.len() {
            write!(out, ",col_{}", col)?;
        }
        out.push('\n');
    }
    for (index, row) in data.values.iter().enumerate() {
        write!(out, "{}", index)?;
        for value in row {
            write!(out, ",{}", value)?;
        }
        out.push('\n');
    }
    Ok(out)
}
/// Stub: NPY decoding is not implemented and always returns an error.
/// A real implementation should use a dedicated NPY crate rather than
/// hand-rolling the header format.
fn parse_npy(_content: &[u8]) -> anyhow::Result<Data> {
    // Simplified NPY parsing (real implementation would use a proper NPY library)
    Err(anyhow::anyhow!("NPY parsing not yet implemented"))
}
/// Stub: NPY encoding is not implemented and always returns an error.
/// Kept so `write_data` can match exhaustively over `DataFormat`.
fn to_npy(_data: &Data) -> anyhow::Result<Vec<u8>> {
    // Simplified NPY writing (real implementation would use a proper NPY library)
    Err(anyhow::anyhow!("NPY writing not yet implemented"))
}

View File

@@ -0,0 +1,66 @@
pub mod compute;
pub mod benchmark;
pub mod convert;
pub mod serve;
pub mod repl;
use serde::{Deserialize, Serialize};
/// Attention input as loaded from disk: row-major query/keys/values
/// matrices plus the embedding dimension.
///
/// NOTE(review): `dim` is stored independently of the row widths — nothing
/// here validates that every row actually has `dim` elements; confirm the
/// attention implementations check this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputData {
    pub query: Vec<Vec<f32>>,
    pub keys: Vec<Vec<f32>>,
    pub values: Vec<Vec<f32>>,
    pub dim: usize,
}
impl InputData {
    /// Borrow every key row as a slice, in the shape `compute` expects.
    pub fn keys_refs(&self) -> Vec<&[f32]> {
        self.keys.iter().map(Vec::as_slice).collect()
    }
    /// Borrow every value row as a slice, in the shape `compute` expects.
    pub fn values_refs(&self) -> Vec<&[f32]> {
        self.values.iter().map(Vec::as_slice).collect()
    }
}
/// Load attention input from `path`, trying JSON, then MessagePack, then
/// bincode. bincode is tried last because it will decode many unrelated
/// byte streams and would otherwise mask the text formats.
///
/// Fix over the original: both failure paths now name the offending file
/// (a bare io::Error / "Failed to parse input file" gave the user nothing
/// to act on).
pub fn load_input(path: &std::path::Path) -> anyhow::Result<InputData> {
    let content = std::fs::read(path)
        .map_err(|e| anyhow::anyhow!("Failed to read {:?}: {}", path, e))?;
    // Try to parse as JSON first
    if let Ok(data) = serde_json::from_slice::<InputData>(&content) {
        return Ok(data);
    }
    // Try MessagePack
    if let Ok(data) = rmp_serde::from_slice::<InputData>(&content) {
        return Ok(data);
    }
    // Try bincode
    if let Ok(data) = bincode::deserialize::<InputData>(&content) {
        return Ok(data);
    }
    Err(anyhow::anyhow!(
        "Failed to parse input file {:?} (tried JSON, MessagePack, bincode)",
        path
    ))
}
/// Serialize an attention result matrix to `path` in the named `format`
/// ("json", "msgpack" or "binary"); any other string is an error.
pub fn save_output(path: &std::path::Path, data: &[Vec<f32>], format: &str) -> anyhow::Result<()> {
    // Encode in memory first, then write once.
    let encoded: Vec<u8> = match format {
        "json" => serde_json::to_string_pretty(data)?.into_bytes(),
        "msgpack" => rmp_serde::to_vec(data)?,
        "binary" => bincode::serialize(data)?,
        other => return Err(anyhow::anyhow!("Unsupported format: {}", other)),
    };
    std::fs::write(path, encoded)?;
    Ok(())
}

View File

@@ -0,0 +1,220 @@
use clap::Args;
use crate::config::Config;
use rustyline::{Editor, error::ReadlineError, history::DefaultHistory};
use ruvector_attention::{
attention::{ScaledDotProductAttention, MultiHeadAttention},
hyperbolic::HyperbolicAttention,
sparse::{FlashAttention, LinearAttention},
moe::MoEAttention,
};
// CLI options for the interactive REPL subcommand. NOTE: the `///` doc
// comment is clap help text (runtime output) and is left untouched; the
// dimension can be changed later with the `dim` REPL command.
#[derive(Args)]
pub struct ReplArgs {
    /// Initial dimension
    #[arg(short, long, default_value = "512")]
    dim: usize,
}
/// One parsed REPL input line (see `parse_command` for the grammar).
/// `Unknown` carries either the unrecognized word or a usage hint when a
/// known command was missing its argument.
enum Command {
    Help,
    Load(String),
    Compute(ComputeArgs),
    SetType(String),
    SetDim(usize),
    Config,
    Quit,
    Unknown(String),
}
/// Payload for the REPL's `compute` command: explicit query/keys/values
/// matrices. Currently always filled with fixed sample data by
/// `parse_command`, never with data loaded via `load`.
struct ComputeArgs {
    query: Vec<Vec<f32>>,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}
/// Mutable session state for the REPL loop.
///
/// NOTE(review): the `last_*` fields are written by the `load` command but
/// never read by `compute` — confirm whether that wiring is still pending.
struct ReplState {
    config: Config,
    dim: usize,
    attention_type: String,
    last_query: Option<Vec<Vec<f32>>>,
    last_keys: Option<Vec<Vec<f32>>>,
    last_values: Option<Vec<Vec<f32>>>,
}
impl ReplState {
    /// Snapshot the CLI configuration and seed the session defaults.
    fn new(config: &Config, dim: usize) -> anyhow::Result<Self> {
        Ok(Self {
            config: config.clone(),
            dim,
            attention_type: config.attention.default_type.clone(),
            last_query: None,
            last_keys: None,
            last_values: None,
        })
    }
    /// Load query/keys/values from `path` and adopt the file's dimension.
    ///
    /// NOTE(review): `compute` below operates only on caller-supplied args
    /// and never reads the `last_*` fields populated here — confirm whether
    /// the REPL's `compute` command is meant to use loaded data.
    fn load(&mut self, path: &str) -> anyhow::Result<()> {
        // `Path::new` already yields `&Path`; the original's extra `&`
        // produced a needless `&&Path` that only compiled via deref coercion.
        let data = super::load_input(std::path::Path::new(path))?;
        self.last_query = Some(data.query);
        self.last_keys = Some(data.keys);
        self.last_values = Some(data.values);
        self.dim = data.dim;
        println!("Loaded data from {}", path);
        Ok(())
    }
    /// Run one attention pass using the session's current type and dimension.
    fn compute(&self, args: &ComputeArgs) -> anyhow::Result<Vec<Vec<f32>>> {
        let keys_refs: Vec<&[f32]> = args.keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = args.values.iter().map(|v| v.as_slice()).collect();
        match self.attention_type.as_str() {
            "scaled_dot" => {
                let attention = ScaledDotProductAttention::new(self.dim, None);
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "multi_head" => {
                let attention = MultiHeadAttention::new(self.dim, self.config.attention.default_heads)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "hyperbolic" => {
                // Curvature is fixed at 1.0 in the REPL; only the CLI exposes it.
                let attention = HyperbolicAttention::new(self.dim, 1.0)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "flash" => {
                let attention = FlashAttention::new(self.dim, 64)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "linear" => {
                let attention = LinearAttention::new(self.dim)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            "moe" => {
                // 4 experts / top-2 routing, fixed in the REPL.
                let attention = MoEAttention::new(self.dim, 4, 2)?;
                attention.compute(&args.query, &keys_refs, &values_refs)
            }
            _ => Err(anyhow::anyhow!("Unknown attention type: {}", self.attention_type)),
        }
    }
    /// Switch the attention implementation used by `compute`.
    fn set_attention_type(&mut self, attention_type: String) {
        self.attention_type = attention_type;
        println!("Attention type set to: {}", self.attention_type);
    }
    /// Change the embedding dimension used when constructing attention.
    fn set_dim(&mut self, dim: usize) {
        self.dim = dim;
        println!("Dimension set to: {}", self.dim);
    }
    /// Read-only access to the loaded configuration (for the `config` command).
    fn config(&self) -> &Config {
        &self.config
    }
}
/// Run the interactive REPL: read lines with rustyline, parse each into a
/// `Command`, and dispatch against the mutable session state until the user
/// quits or sends Ctrl-C / Ctrl-D.
pub async fn run(args: ReplArgs, config: &Config) -> anyhow::Result<()> {
    let mut rl = Editor::<(), DefaultHistory>::new()?;
    println!("RuVector Attention REPL v{}", env!("CARGO_PKG_VERSION"));
    println!("Type 'help' for commands, 'quit' to exit\n");
    let mut state = ReplState::new(config, args.dim)?;
    loop {
        match rl.readline("attention> ") {
            Ok(line) => {
                // Blank lines are ignored and kept out of history.
                if line.trim().is_empty() {
                    continue;
                }
                rl.add_history_entry(&line)?;
                match parse_command(&line) {
                    Command::Help => print_help(),
                    Command::Load(path) => {
                        // Errors are reported but do not end the session.
                        if let Err(e) = state.load(&path) {
                            eprintln!("Error loading file: {}", e);
                        }
                    }
                    Command::Compute(args) => {
                        // NOTE(review): `parse_command` always supplies fixed
                        // sample data here; data loaded via `load` is not
                        // used — confirm intended.
                        match state.compute(&args) {
                            Ok(result) => {
                                println!("Result shape: {}x{}", result.len(), result.first().map(|r| r.len()).unwrap_or(0));
                                // Preview is clamped to at most 5 values.
                                println!("First row (first 5 values): {:?}",
                                    result.first().map(|r| &r[..5.min(r.len())]));
                            }
                            Err(e) => eprintln!("Error computing attention: {}", e),
                        }
                    }
                    Command::SetType(t) => state.set_attention_type(t),
                    Command::SetDim(d) => state.set_dim(d),
                    Command::Config => println!("{:#?}", state.config()),
                    Command::Quit => break,
                    Command::Unknown(cmd) => println!("Unknown command: '{}'. Type 'help' for available commands.", cmd),
                }
            }
            // Ctrl-C / Ctrl-D end the session cleanly.
            Err(ReadlineError::Interrupted) | Err(ReadlineError::Eof) => break,
            Err(err) => return Err(err.into()),
        }
    }
    println!("Goodbye!");
    Ok(())
}
/// Translate one REPL input line into a `Command`.
///
/// The first whitespace-separated word selects the command; at most one
/// argument word is consumed. Anything unrecognized (or a known command
/// missing its argument) becomes `Command::Unknown` with a usage hint.
fn parse_command(line: &str) -> Command {
    let mut words = line.trim().split_whitespace();
    let Some(head) = words.next() else {
        return Command::Unknown(String::new());
    };
    let arg = words.next();
    match head {
        "help" | "h" | "?" => Command::Help,
        "load" => match arg {
            Some(path) => Command::Load(path.to_string()),
            None => Command::Unknown("load requires a file path".to_string()),
        },
        "compute" => {
            // For simplicity, fixed sample data (3-dim, two key/value rows).
            Command::Compute(ComputeArgs {
                query: vec![vec![0.1, 0.2, 0.3]],
                keys: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
                values: vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]],
            })
        }
        "type" => match arg {
            Some(t) => Command::SetType(t.to_string()),
            None => Command::Unknown("type requires an attention type".to_string()),
        },
        "dim" => match arg {
            None => Command::Unknown("dim requires a dimension".to_string()),
            Some(word) => match word.parse() {
                Ok(d) => Command::SetDim(d),
                Err(_) => Command::Unknown("dim requires a number".to_string()),
            },
        },
        "config" => Command::Config,
        "quit" | "exit" | "q" => Command::Quit,
        other => Command::Unknown(other.to_string()),
    }
}
/// Print the REPL command reference, one line per command.
fn print_help() {
    const HELP_LINES: [&str; 8] = [
        "Available commands:",
        "  help          - Show this help message",
        "  load <file>   - Load input data from file",
        "  compute       - Compute attention with loaded data",
        "  type <type>   - Set attention type (scaled_dot, multi_head, hyperbolic, flash, linear, moe)",
        "  dim <size>    - Set dimension size",
        "  config        - Show current configuration",
        "  quit          - Exit REPL",
    ];
    for line in HELP_LINES {
        println!("{}", line);
    }
}

View File

@@ -0,0 +1,290 @@
use clap::Args;
use crate::config::Config;
use axum::{
routing::{get, post},
Router, Json, extract::State,
http::StatusCode,
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower_http::cors::CorsLayer;
use ruvector_attention::{
attention::{ScaledDotProductAttention, MultiHeadAttention},
hyperbolic::HyperbolicAttention,
sparse::{FlashAttention, LinearAttention},
moe::MoEAttention,
};
// CLI options for the HTTP server subcommand. NOTE: the `///` doc comments
// are clap help text (runtime output) and are left untouched. The host flag
// uses uppercase `-H` because lowercase `-h` is taken by clap's --help.
#[derive(Args)]
pub struct ServeArgs {
    /// Host address
    #[arg(short = 'H', long, default_value = "0.0.0.0")]
    host: String,
    /// Port number
    #[arg(short, long, default_value = "8080")]
    port: u16,
    // Enables a fully permissive CORS layer; intended for local development.
    /// Enable CORS
    #[arg(long)]
    cors: bool,
}
/// Shared application state handed to every axum handler via `State`.
/// Currently holds only a clone of the CLI configuration; no handler reads
/// it yet (they all bind it as `_state`).
struct ServerState {
    config: Config,
}
/// JSON request body shared by all /attention/* endpoints.
/// The optional tuning fields are each honored only by the matching
/// endpoint (heads by multi_head, experts/top_k by moe, curvature by
/// hyperbolic) and ignored elsewhere.
#[derive(Debug, Deserialize)]
struct AttentionRequest {
    query: Vec<Vec<f32>>,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
    #[serde(default)]
    num_heads: Option<usize>,
    #[serde(default)]
    num_experts: Option<usize>,
    #[serde(default)]
    top_k: Option<usize>,
    #[serde(default)]
    curvature: Option<f32>,
}
/// JSON success body for the /attention/* endpoints: the output matrix plus
/// wall-clock compute time and descriptive metadata.
#[derive(Debug, Serialize)]
struct AttentionResponse {
    result: Vec<Vec<f32>>,
    compute_time_ms: f64,
    metadata: ResponseMetadata,
}
/// Descriptive metadata attached to every attention response.
/// `dimensions` is (number of query rows, embedding dim inferred from the
/// first query row).
#[derive(Debug, Serialize)]
struct ResponseMetadata {
    attention_type: String,
    dimensions: (usize, usize),
}
/// JSON error body; always paired with an HTTP 500 by `error_response`.
#[derive(Debug, Serialize)]
struct ErrorResponse {
    error: String,
}
/// Start the HTTP attention service and block until the server exits.
///
/// Routes: /health plus one POST endpoint per attention flavor and a /batch
/// stub. The optional CORS layer is added after the routes, which in axum
/// wraps everything registered above it.
pub async fn run(args: ServeArgs, config: &Config) -> anyhow::Result<()> {
    let state = Arc::new(ServerState {
        config: config.clone(),
    });
    let mut app = Router::new()
        .route("/health", get(health))
        .route("/attention/scaled_dot", post(scaled_dot_attention))
        .route("/attention/multi_head", post(multi_head_attention))
        .route("/attention/hyperbolic", post(hyperbolic_attention))
        .route("/attention/flash", post(flash_attention))
        .route("/attention/linear", post(linear_attention))
        .route("/attention/moe", post(moe_attention))
        .route("/batch", post(batch_compute))
        .with_state(state);
    if args.cors {
        // Permissive CORS allows any origin/method/header — development only.
        app = app.layer(CorsLayer::permissive());
    }
    let addr = format!("{}:{}", args.host, args.port);
    tracing::info!("Starting server at http://{}", addr);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Liveness probe: reports service status and the crate version.
async fn health() -> impl IntoResponse {
    let body = serde_json::json!({
        "status": "healthy",
        "version": env!("CARGO_PKG_VERSION")
    });
    Json(body)
}
/// POST /attention/scaled_dot — scaled dot-product attention over the
/// request's query/keys/values. The embedding dimension is inferred from
/// the first query row (0 when the query is empty). Failures map to 500.
async fn scaled_dot_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention = ScaledDotProductAttention::new(dim, None);
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: "ScaledDotProduct".to_string(),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /attention/multi_head — multi-head attention; `num_heads` defaults
/// to 8 when absent from the request. Construction and compute failures
/// both map to 500.
async fn multi_head_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let num_heads = req.num_heads.unwrap_or(8);
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention =
        MultiHeadAttention::new(dim, num_heads).map_err(|e| error_response(e.to_string()))?;
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: format!("MultiHead({})", num_heads),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /attention/hyperbolic — hyperbolic attention; `curvature` defaults
/// to 1.0 when absent from the request. Failures map to 500.
async fn hyperbolic_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let curvature = req.curvature.unwrap_or(1.0);
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention =
        HyperbolicAttention::new(dim, curvature).map_err(|e| error_response(e.to_string()))?;
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: "Hyperbolic".to_string(),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /attention/flash — flash attention with a fixed block size of 64
/// (not exposed via the request body). Failures map to 500.
async fn flash_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention = FlashAttention::new(dim, 64).map_err(|e| error_response(e.to_string()))?;
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: "Flash".to_string(),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /attention/linear — linear attention; takes no extra tuning
/// parameters from the request. Failures map to 500.
async fn linear_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention = LinearAttention::new(dim).map_err(|e| error_response(e.to_string()))?;
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: "Linear".to_string(),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /attention/moe — mixture-of-experts attention; `num_experts`
/// defaults to 4 and `top_k` to 2 when absent. Failures map to 500.
async fn moe_attention(
    State(_state): State<Arc<ServerState>>,
    Json(req): Json<AttentionRequest>,
) -> Result<Json<AttentionResponse>, (StatusCode, Json<ErrorResponse>)> {
    let started = std::time::Instant::now();
    let dim = req.query.first().map_or(0, |row| row.len());
    let num_experts = req.num_experts.unwrap_or(4);
    let top_k = req.top_k.unwrap_or(2);
    let key_views: Vec<&[f32]> = req.keys.iter().map(Vec::as_slice).collect();
    let value_views: Vec<&[f32]> = req.values.iter().map(Vec::as_slice).collect();
    let attention =
        MoEAttention::new(dim, num_experts, top_k).map_err(|e| error_response(e.to_string()))?;
    let result = attention
        .compute(&req.query, &key_views, &value_views)
        .map_err(|e| error_response(e.to_string()))?;
    Ok(Json(AttentionResponse {
        result,
        compute_time_ms: started.elapsed().as_secs_f64() * 1000.0,
        metadata: ResponseMetadata {
            attention_type: format!("MoE({}/{})", top_k, num_experts),
            dimensions: (req.query.len(), dim),
        },
    }))
}
/// POST /batch — placeholder endpoint: the request body is accepted but
/// ignored, and the response is always a 500 "not yet implemented" error.
async fn batch_compute(
    State(_state): State<Arc<ServerState>>,
    Json(_req): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    Err(error_response("Batch compute not yet implemented".to_string()))
}
fn error_response(message: String) -> (StatusCode, Json<ErrorResponse>) {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse { error: message }),
)
}