Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,344 @@
//! CLI command implementations
use crate::cli::{
export_csv, export_json, format_error, format_search_results, format_stats, format_success,
ProgressTracker,
};
use crate::config::Config;
use anyhow::{Context, Result};
use colored::*;
use ruvector_core::{
types::{DbOptions, SearchQuery, VectorEntry},
VectorDB,
};
use std::path::{Path, PathBuf};
use std::time::Instant;
/// Create a new vector database at `path` with the given dimensionality.
///
/// Options are derived from `config`, with the storage path and dimensions
/// overridden by the arguments. Progress is reported on stdout.
pub fn create_database(path: &str, dimensions: usize, config: &Config) -> Result<()> {
    let options = {
        let mut opts = config.to_db_options();
        opts.storage_path = path.to_string();
        opts.dimensions = dimensions;
        opts
    };
    println!(
        "{}",
        format_success(&format!("Creating database at: {}", path))
    );
    println!(" Dimensions: {}", dimensions.to_string().cyan());
    println!(" Distance metric: {:?}", options.distance_metric);
    // Constructing the VectorDB is what materializes the database on disk.
    let _db = VectorDB::new(options).context("Failed to create database")?;
    println!("{}", format_success("Database created successfully!"));
    Ok(())
}
/// Insert vectors from a file into an existing database.
///
/// `format` selects the parser ("json", "csv" or "npy"). Entries are
/// inserted in batches of `config.cli.batch_size`, optionally with a
/// progress bar, and a throughput summary is printed at the end.
pub fn insert_vectors(
    db_path: &str,
    input_file: &str,
    format: &str,
    config: &Config,
    show_progress: bool,
) -> Result<()> {
    // Open the target database with the configured options.
    let mut db_options = config.to_db_options();
    db_options.storage_path = db_path.to_string();
    let db = VectorDB::new(db_options).context("Failed to open database")?;
    // Dispatch to the parser matching the requested input format.
    let entries = match format {
        "json" => parse_json_file(input_file)?,
        "csv" => parse_csv_file(input_file)?,
        "npy" => parse_npy_file(input_file)?,
        other => return Err(anyhow::anyhow!("Unsupported format: {}", other)),
    };
    let total = entries.len();
    println!(
        "{}",
        format_success(&format!("Loaded {} vectors from {}", total, input_file))
    );
    let start = Instant::now();
    let tracker = ProgressTracker::new();
    // `bool::then` only builds the bar when progress display was requested.
    let progress = show_progress.then(|| tracker.create_bar(total as u64, "Inserting vectors..."));
    // Insert batch by batch, advancing the bar after each chunk.
    let mut done = 0usize;
    for batch in entries.chunks(config.cli.batch_size) {
        db.insert_batch(batch.to_vec())
            .context("Failed to insert batch")?;
        done += batch.len();
        if let Some(bar) = progress.as_ref() {
            bar.set_position(done as u64);
        }
    }
    if let Some(bar) = progress {
        bar.finish_with_message("Insertion complete!");
    }
    let elapsed = start.elapsed();
    println!(
        "{}",
        format_success(&format!(
            "Inserted {} vectors in {:.2}s ({:.0} vectors/sec)",
            total,
            elapsed.as_secs_f64(),
            total as f64 / elapsed.as_secs_f64()
        ))
    );
    Ok(())
}
/// Search the database for the `k` nearest neighbours of `query_vector`.
///
/// Prints the formatted result list (optionally with vector previews) and
/// the elapsed search time in milliseconds.
pub fn search_vectors(
    db_path: &str,
    query_vector: Vec<f32>,
    k: usize,
    config: &Config,
    show_vectors: bool,
) -> Result<()> {
    let mut options = config.to_db_options();
    options.storage_path = db_path.to_string();
    let db = VectorDB::new(options).context("Failed to open database")?;
    let start = Instant::now();
    let query = SearchQuery {
        vector: query_vector,
        k,
        filter: None,
        ef_search: None,
    };
    let results = db.search(query).context("Failed to search")?;
    let elapsed = start.elapsed();
    println!("{}", format_search_results(&results, show_vectors));
    let timing = format!(
        "Search completed in {:.2}ms",
        elapsed.as_secs_f64() * 1000.0
    );
    println!("\n{}", timing.dimmed());
    Ok(())
}
/// Print database statistics and, when configured, the HNSW index settings.
pub fn show_info(db_path: &str, config: &Config) -> Result<()> {
    let mut options = config.to_db_options();
    options.storage_path = db_path.to_string();
    let db = VectorDB::new(options).context("Failed to open database")?;
    let count = db.len().context("Failed to get count")?;
    let dims = db.options().dimensions;
    let metric_name = format!("{:?}", db.options().distance_metric);
    println!("{}", format_stats(count, dims, &metric_name));
    // HNSW parameters are only shown when an index is actually configured.
    if let Some(hnsw) = &db.options().hnsw_config {
        println!("{}", "HNSW Configuration:".bold().green());
        println!(" M: {}", hnsw.m.to_string().cyan());
        println!(
            " ef_construction: {}",
            hnsw.ef_construction.to_string().cyan()
        );
        println!(" ef_search: {}", hnsw.ef_search.to_string().cyan());
    }
    Ok(())
}
/// Run a quick benchmark
///
/// Generates `num_queries` random query vectors matching the database's
/// dimensionality, performs a short warm-up, then measures throughput and
/// average latency for k=10 searches.
pub fn run_benchmark(db_path: &str, config: &Config, num_queries: usize) -> Result<()> {
    let mut db_options = config.to_db_options();
    db_options.storage_path = db_path.to_string();
    let db = VectorDB::new(db_options).context("Failed to open database")?;
    let dimensions = db.options().dimensions;
    println!("{}", "Running benchmark...".bold().green());
    println!(" Queries: {}", num_queries.to_string().cyan());
    println!(" Dimensions: {}", dimensions.to_string().cyan());
    // Generate random query vectors
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let queries: Vec<Vec<f32>> = (0..num_queries)
        .map(|_| (0..dimensions).map(|_| rng.gen()).collect())
        .collect();
    // Warm-up: prime caches/index structures; results and errors ignored.
    for query in queries.iter().take(10) {
        let _ = db.search(SearchQuery {
            vector: query.clone(),
            k: 10,
            filter: None,
            ef_search: None,
        });
    }
    // Benchmark
    let start = Instant::now();
    for query in &queries {
        db.search(SearchQuery {
            vector: query.clone(),
            k: 10,
            filter: None,
            ef_search: None,
        })
        .context("Search failed")?;
    }
    let elapsed = start.elapsed();
    let qps = num_queries as f64 / elapsed.as_secs_f64();
    let avg_latency = elapsed.as_secs_f64() * 1000.0 / num_queries as f64;
    println!("\n{}", "Benchmark Results:".bold().green());
    println!(" Total time: {:.2}s", elapsed.as_secs_f64());
    // BUG FIX: `{:.0}`/`{:.2}` were previously applied to the result of
    // `x.to_string().cyan()` (a ColoredString), so the numeric rounding was
    // never performed. Format the number first, then colorize.
    println!(" Queries per second: {}", format!("{:.0}", qps).cyan());
    println!(" Average latency: {}ms", format!("{:.2}", avg_latency).cyan());
    Ok(())
}
/// Export database to file
///
/// Currently unimplemented: always returns an error after validating that
/// the database can be opened. `output_file` and `format` describe the
/// intended export target once `VectorDB::all_ids()` exists.
pub fn export_database(
    db_path: &str,
    output_file: &str,
    format: &str,
    config: &Config,
) -> Result<()> {
    let mut db_options = config.to_db_options();
    db_options.storage_path = db_path.to_string();
    // Open the database up front so an invalid path fails before the
    // not-implemented error. Underscore-named: the handle is otherwise
    // unused until all_ids() lands. (`format` is likewise unused for now.)
    let _db = VectorDB::new(db_options).context("Failed to open database")?;
    println!(
        "{}",
        format_success(&format!("Exporting database to: {}", output_file))
    );
    // Export is currently limited - would need to add all_ids() method to VectorDB.
    // TODO: Implement when VectorDB exposes all_ids():
    //   let ids = db.all_ids()?;
    //   let tracker = ProgressTracker::new();
    //   let pb = tracker.create_bar(ids.len() as u64, "Exporting vectors...");
    //   ... then dispatch on `format` to export_json()/export_csv().
    Err(anyhow::anyhow!(
        "Export functionality requires VectorDB::all_ids() method. This will be implemented in a future update."
    ))
}
/// Import data from another vector database system.
///
/// All backends are still pending; each recognized `source` currently
/// returns a "not yet implemented" error, and anything else is rejected
/// as unsupported.
pub fn import_from_external(
    db_path: &str,
    source: &str,
    source_path: &str,
    config: &Config,
) -> Result<()> {
    println!(
        "{}",
        format_success(&format!("Importing from {} database", source))
    );
    // The match itself is the function's result: every arm yields an Err.
    match source {
        "faiss" => Err(anyhow::anyhow!("FAISS import not yet implemented")),
        "pinecone" => Err(anyhow::anyhow!("Pinecone import not yet implemented")),
        "weaviate" => Err(anyhow::anyhow!("Weaviate import not yet implemented")),
        other => Err(anyhow::anyhow!("Unsupported source: {}", other)),
    }
}
// Helper functions
/// Read a JSON file containing an array of vector entries.
fn parse_json_file(path: &str) -> Result<Vec<VectorEntry>> {
    let text = std::fs::read_to_string(path).context("Failed to read JSON file")?;
    let entries = serde_json::from_str(&text).context("Failed to parse JSON")?;
    Ok(entries)
}
/// Read vector entries from a CSV file with columns `id,vector,metadata`.
///
/// A blank (or absent) id column yields `None`; the vector column holds a
/// JSON array of floats; the metadata column is optional JSON.
fn parse_csv_file(path: &str) -> Result<Vec<VectorEntry>> {
    let mut reader = csv::Reader::from_path(path).context("Failed to open CSV file")?;
    let mut entries = Vec::new();
    for row in reader.records() {
        let record = row.context("Failed to read CSV record")?;
        // Column 0: optional id (blank or missing means none assigned).
        let id = match record.get(0) {
            Some(s) if !s.is_empty() => Some(s.to_string()),
            _ => None,
        };
        // Column 1: required vector, encoded as a JSON array.
        let vector_field = record.get(1).context("Missing vector column")?;
        let vector: Vec<f32> =
            serde_json::from_str(vector_field).context("Failed to parse vector")?;
        // Column 2: optional JSON metadata object.
        let metadata = match record.get(2) {
            Some(s) if !s.is_empty() => {
                Some(serde_json::from_str(s).context("Failed to parse metadata")?)
            }
            _ => None,
        };
        entries.push(VectorEntry {
            id,
            vector,
            metadata,
        });
    }
    Ok(entries)
}
/// Read a 2-D `.npy` array of f32 and turn each row into a vector entry.
///
/// Rows receive synthetic ids of the form `vec_<row-index>`.
fn parse_npy_file(path: &str) -> Result<Vec<VectorEntry>> {
    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;
    let file = std::fs::File::open(path).context("Failed to open NPY file")?;
    let matrix: Array2<f32> = Array2::read_npy(file).context("Failed to read NPY file")?;
    let mut entries = Vec::new();
    for (index, row) in matrix.outer_iter().enumerate() {
        entries.push(VectorEntry {
            id: Some(format!("vec_{}", index)),
            vector: row.to_vec(),
            metadata: None,
        });
    }
    Ok(entries)
}

View File

@@ -0,0 +1,179 @@
//! Output formatting utilities
use colored::*;
use ruvector_core::types::{SearchResult, VectorEntry};
use serde_json;
/// Format search results for terminal display.
///
/// Each hit is rendered as a numbered entry with its score, pretty-printed
/// metadata when present, and (when `show_vectors` is set) a preview of the
/// first five vector components.
pub fn format_search_results(results: &[SearchResult], show_vectors: bool) -> String {
    let mut out = String::new();
    for (rank, hit) in results.iter().enumerate() {
        out.push_str(&format!("\n{}. {}\n", rank + 1, hit.id.bold()));
        out.push_str(&format!(" Score: {:.4}\n", hit.score));
        // Metadata is only shown when present and non-empty.
        match &hit.metadata {
            Some(meta) if !meta.is_empty() => {
                let rendered =
                    serde_json::to_string_pretty(meta).unwrap_or_else(|_| "{}".to_string());
                out.push_str(&format!(" Metadata: {}\n", rendered));
            }
            _ => {}
        }
        if show_vectors {
            if let Some(values) = &hit.vector {
                let head: Vec<f32> = values.iter().take(5).copied().collect();
                out.push_str(&format!(" Vector (first 5): {:?}...\n", head));
            }
        }
    }
    out
}
/// Format database statistics as a colored multi-line summary.
pub fn format_stats(count: usize, dimensions: usize, metric: &str) -> String {
    let heading = "Database Statistics".bold().green();
    format!(
        "\n{}\n Vectors: {}\n Dimensions: {}\n Distance Metric: {}\n",
        heading,
        count.to_string().cyan(),
        dimensions.to_string().cyan(),
        metric.cyan()
    )
}
/// Format error message
pub fn format_error(msg: &str) -> String {
format!("{} {}", "Error:".red().bold(), msg)
}
/// Format a success message with a green check-mark prefix.
///
/// NOTE(review): the prefix literal was previously an empty string (which
/// produced a stray leading space) — almost certainly a "✓" glyph lost to
/// an encoding issue. Restored here.
pub fn format_success(msg: &str) -> String {
    format!("{} {}", "✓".green().bold(), msg)
}
/// Format a warning message with a yellow "Warning:" prefix.
pub fn format_warning(msg: &str) -> String {
    let prefix = "Warning:".yellow().bold();
    format!("{} {}", prefix, msg)
}
/// Format an informational message with a blue info-glyph prefix.
///
/// NOTE(review): the prefix literal was previously an empty string (which
/// produced a stray leading space) — presumably an info glyph lost to an
/// encoding issue; "ℹ" restored here. Confirm against upstream.
pub fn format_info(msg: &str) -> String {
    format!("{} {}", "ℹ".blue().bold(), msg)
}
/// Serialize vector entries to a pretty-printed JSON string.
pub fn export_json(entries: &[VectorEntry]) -> anyhow::Result<String> {
    match serde_json::to_string_pretty(entries) {
        Ok(json) => Ok(json),
        Err(e) => Err(anyhow::anyhow!("Failed to serialize to JSON: {}", e)),
    }
}
/// Serialize vector entries to a CSV string with columns `id,vector,metadata`.
///
/// The vector and metadata columns contain JSON-encoded values; a missing id
/// becomes an empty field.
pub fn export_csv(entries: &[VectorEntry]) -> anyhow::Result<String> {
    let mut writer = csv::Writer::from_writer(vec![]);
    // Header row first, then one record per entry.
    writer.write_record(&["id", "vector", "metadata"])?;
    for entry in entries {
        let id_field = entry.id.as_deref().unwrap_or("");
        let vector_field = serde_json::to_string(&entry.vector)?;
        let metadata_field = serde_json::to_string(&entry.metadata)?;
        writer.write_record(&[id_field, &vector_field, &metadata_field])?;
    }
    writer.flush()?;
    let bytes = writer.into_inner()?;
    String::from_utf8(bytes).map_err(|e| anyhow::anyhow!("Failed to convert CSV to string: {}", e))
}
// Graph-specific formatting functions
/// Format a graph node as `id (Label1:Label2)` with an optional property list.
pub fn format_graph_node(
    id: &str,
    labels: &[String],
    properties: &serde_json::Map<String, serde_json::Value>,
) -> String {
    let mut text = format!("{} ({})\n", id.bold(), labels.join(":").cyan());
    if !properties.is_empty() {
        text.push_str(" Properties:\n");
        for (name, value) in properties {
            text.push_str(&format!(" {}: {}\n", name.yellow(), value));
        }
    }
    text
}
/// Format graph relationship for display
///
/// Renders `start -[TYPE]-> end` plus an optional property listing.
/// `_id` (the relationship's identifier) is accepted for signature
/// compatibility but is not currently part of the rendered output —
/// underscore-prefixed to silence the unused-parameter warning.
pub fn format_graph_relationship(
    _id: &str,
    rel_type: &str,
    start_node: &str,
    end_node: &str,
    properties: &serde_json::Map<String, serde_json::Value>,
) -> String {
    let mut output = String::new();
    output.push_str(&format!(
        "{} -[{}]-> {}\n",
        start_node.cyan(),
        rel_type.yellow(),
        end_node.cyan()
    ));
    if !properties.is_empty() {
        output.push_str(" Properties:\n");
        for (key, value) in properties {
            output.push_str(&format!(" {}: {}\n", key.yellow(), value));
        }
    }
    output
}
/// Render graph query results as an ASCII table (headers styled bold yellow).
pub fn format_graph_table(headers: &[String], rows: &[Vec<String>]) -> String {
    use prettytable::{Cell, Row, Table};
    let mut table = Table::new();
    // Header row with the "Fyb" (yellow, bold) style spec.
    let header_row = Row::new(
        headers
            .iter()
            .map(|h| Cell::new(h).style_spec("Fyb"))
            .collect(),
    );
    table.add_row(header_row);
    for values in rows {
        table.add_row(Row::new(values.iter().map(|v| Cell::new(v)).collect()));
    }
    table.to_string()
}
/// Format graph statistics as a colored multi-line summary.
pub fn format_graph_stats(
    node_count: usize,
    rel_count: usize,
    label_count: usize,
    rel_type_count: usize,
) -> String {
    let title = "Graph Statistics".bold().green();
    format!(
        "\n{}\n Nodes: {}\n Relationships: {}\n Labels: {}\n Relationship Types: {}\n",
        title,
        node_count.to_string().cyan(),
        rel_count.to_string().cyan(),
        label_count.to_string().cyan(),
        rel_type_count.to_string().cyan()
    )
}

View File

@@ -0,0 +1,552 @@
//! Graph database command implementations
use crate::cli::{format_error, format_info, format_success, ProgressTracker};
use crate::config::Config;
use anyhow::{Context, Result};
use colored::*;
use std::io::{self, BufRead, Write};
use std::path::Path;
use std::time::Instant;
/// Graph database subcommands
// NOTE: the `///` doc comments on variants and fields below are consumed by
// clap and become the generated CLI help text — they are user-facing strings,
// not internal documentation.
// Convention: every subcommand takes its database path via `-b`/`--db`
// (short 'b', not 'd'), defaulting to "./ruvector-graph.db".
#[derive(clap::Subcommand, Debug)]
pub enum GraphCommands {
    /// Create a new graph database
    Create {
        /// Database file path
        #[arg(short, long, default_value = "./ruvector-graph.db")]
        path: String,
        /// Graph name
        #[arg(short, long, default_value = "default")]
        name: String,
        /// Enable property indexing
        #[arg(long)]
        indexed: bool,
    },
    /// Execute a Cypher query
    Query {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Cypher query to execute
        #[arg(short = 'q', long)]
        cypher: String,
        /// Output format (table, json, csv)
        #[arg(long, default_value = "table")]
        format: String,
        /// Show execution plan
        #[arg(long)]
        explain: bool,
    },
    /// Interactive Cypher shell (REPL)
    Shell {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Enable multiline mode
        #[arg(long)]
        multiline: bool,
    },
    /// Import data from file
    Import {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Input file path
        #[arg(short = 'i', long)]
        input: String,
        /// Input format (csv, json, cypher)
        #[arg(long, default_value = "json")]
        format: String,
        /// Graph name
        #[arg(short = 'g', long, default_value = "default")]
        graph: String,
        /// Skip errors and continue
        #[arg(long)]
        skip_errors: bool,
    },
    /// Export graph data to file
    Export {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Output file path
        #[arg(short = 'o', long)]
        output: String,
        /// Output format (json, csv, cypher, graphml)
        #[arg(long, default_value = "json")]
        format: String,
        /// Graph name
        #[arg(short = 'g', long, default_value = "default")]
        graph: String,
    },
    /// Show graph database information
    Info {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Show detailed statistics
        #[arg(long)]
        detailed: bool,
    },
    /// Run graph benchmarks
    Benchmark {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Number of queries to run
        #[arg(short = 'n', long, default_value = "1000")]
        queries: usize,
        /// Benchmark type (traverse, pattern, aggregate)
        #[arg(short = 't', long, default_value = "traverse")]
        bench_type: String,
    },
    /// Start HTTP/gRPC server
    Serve {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
        db: String,
        /// Server host
        #[arg(long, default_value = "127.0.0.1")]
        host: String,
        /// HTTP port
        #[arg(long, default_value = "8080")]
        http_port: u16,
        /// gRPC port
        #[arg(long, default_value = "50051")]
        grpc_port: u16,
        /// Enable GraphQL endpoint
        #[arg(long)]
        graphql: bool,
    },
}
/// Create a new graph database at `path`.
///
/// Placeholder implementation: currently only ensures the parent directory
/// exists. Real graph storage is pending ruvector-neo4j integration.
pub fn create_graph(path: &str, name: &str, indexed: bool, config: &Config) -> Result<()> {
    println!(
        "{}",
        format_success(&format!("Creating graph database at: {}", path))
    );
    println!(" Graph name: {}", name.cyan());
    let indexing_label = if indexed {
        "enabled".green()
    } else {
        "disabled".dimmed()
    };
    println!(" Property indexing: {}", indexing_label);
    // TODO: Integrate with ruvector-neo4j when available.
    // Until then, just make sure the containing directory exists.
    let parent = Path::new(path).parent().unwrap_or(Path::new("."));
    std::fs::create_dir_all(parent)?;
    println!("{}", format_success("Graph database created successfully!"));
    println!(
        "{}",
        format_info("Use 'ruvector graph shell' to start interactive mode")
    );
    Ok(())
}
/// Execute a Cypher query and print the result in the requested format.
///
/// Placeholder implementation: actual execution awaits ruvector-neo4j
/// integration, so the formatters currently receive an empty result set.
pub fn execute_query(
    db_path: &str,
    cypher: &str,
    format: &str,
    explain: bool,
    config: &Config,
) -> Result<()> {
    if explain {
        println!("{}", "Query Execution Plan:".bold().cyan());
        println!("{}", format_info("EXPLAIN mode - showing query plan"));
    }
    let started = Instant::now();
    // TODO: Integrate with ruvector-neo4j Neo4jGraph implementation.
    println!("{}", format_success("Executing Cypher query..."));
    println!(" Query: {}", cypher.dimmed());
    let elapsed = started.elapsed();
    // Note: only the table format gets a leading blank line.
    match format {
        "table" => println!("\n{}", format_graph_results_table(&[], cypher)),
        "json" => println!("{}", format_graph_results_json(&[])?),
        "csv" => println!("{}", format_graph_results_csv(&[])?),
        other => return Err(anyhow::anyhow!("Unsupported output format: {}", other)),
    }
    println!(
        "\n{}",
        format!("Query completed in {:.2}ms", elapsed.as_secs_f64() * 1000.0).dimmed()
    );
    Ok(())
}
/// Interactive Cypher shell (REPL)
///
/// Reads lines from stdin until `:exit`/`:quit`/`:q`. Input accumulates in
/// `query_buffer`: in single-line mode every non-command line executes
/// immediately; in multiline mode lines accumulate until one ends with ';'
/// (or an empty line is entered while the buffer is non-empty).
pub fn run_shell(db_path: &str, multiline: bool, config: &Config) -> Result<()> {
    println!("{}", "RuVector Graph Shell".bold().green());
    println!("Database: {}", db_path.cyan());
    println!(
        "Type {} to exit, {} for help\n",
        ":exit".yellow(),
        ":help".yellow()
    );
    let stdin = io::stdin();
    let mut stdout = io::stdout();
    let mut query_buffer = String::new();
    loop {
        // Print prompt — a continuation prompt when mid-way through a
        // multiline query, the normal prompt otherwise.
        if multiline && !query_buffer.is_empty() {
            print!("{}", " ... ".dimmed());
        } else {
            print!("{}", "cypher> ".green().bold());
        }
        stdout.flush()?;
        // Read line
        let mut line = String::new();
        stdin.lock().read_line(&mut line)?;
        let line = line.trim();
        // Handle special commands. These are only recognized when the
        // (trimmed) line matches exactly, and they take precedence over any
        // partially-entered query in the buffer.
        match line {
            ":exit" | ":quit" | ":q" => {
                println!("{}", format_success("Goodbye!"));
                break;
            }
            ":help" | ":h" => {
                print_shell_help();
                continue;
            }
            ":clear" => {
                query_buffer.clear();
                println!("{}", format_info("Query buffer cleared"));
                continue;
            }
            "" => {
                if !multiline || query_buffer.is_empty() {
                    continue;
                }
                // In multiline mode, empty line executes query
                // (falls through to the execution step below).
            }
            _ => {
                // Lines are joined with a single space, so statements split
                // across lines remain one query.
                query_buffer.push_str(line);
                query_buffer.push(' ');
                if multiline && !line.ends_with(';') {
                    continue; // Continue reading in multiline mode
                }
            }
        }
        // Execute query (reached from the "" and catch-all arms above).
        // Trailing semicolons are stripped before dispatch; output format is
        // fixed to "table" in the shell.
        let query = query_buffer.trim().trim_end_matches(';');
        if !query.is_empty() {
            match execute_query(db_path, query, "table", false, config) {
                Ok(_) => {}
                // Errors are printed but never abort the REPL loop.
                Err(e) => println!("{}", format_error(&e.to_string())),
            }
        }
        query_buffer.clear();
    }
    Ok(())
}
/// Import graph data from a file.
///
/// Placeholder implementation: validates `format` and prints status; the
/// real parsing/loading awaits ruvector-neo4j integration.
pub fn import_graph(
    db_path: &str,
    input_file: &str,
    format: &str,
    graph_name: &str,
    skip_errors: bool,
    config: &Config,
) -> Result<()> {
    println!(
        "{}",
        format_success(&format!("Importing graph data from: {}", input_file))
    );
    println!(" Format: {}", format.cyan());
    println!(" Graph: {}", graph_name.cyan());
    let skip_label = if skip_errors {
        "yes".yellow()
    } else {
        "no".dimmed()
    };
    println!(" Skip errors: {}", skip_label);
    let started = Instant::now();
    // TODO: Implement actual import logic with ruvector-neo4j
    match format {
        "csv" => println!("{}", format_info("Parsing CSV file...")),
        "json" => println!("{}", format_info("Parsing JSON file...")),
        "cypher" => println!("{}", format_info("Executing Cypher statements...")),
        other => return Err(anyhow::anyhow!("Unsupported import format: {}", other)),
    }
    let elapsed = started.elapsed();
    println!(
        "{}",
        format_success(&format!(
            "Import completed in {:.2}s",
            elapsed.as_secs_f64()
        ))
    );
    Ok(())
}
/// Export graph data to a file.
///
/// Placeholder implementation: validates `format` and prints status; real
/// export generation awaits ruvector-neo4j integration.
pub fn export_graph(
    db_path: &str,
    output_file: &str,
    format: &str,
    graph_name: &str,
    config: &Config,
) -> Result<()> {
    println!(
        "{}",
        format_success(&format!("Exporting graph to: {}", output_file))
    );
    println!(" Format: {}", format.cyan());
    println!(" Graph: {}", graph_name.cyan());
    let started = Instant::now();
    // TODO: Implement actual export logic with ruvector-neo4j
    match format {
        "json" => println!("{}", format_info("Generating JSON export...")),
        "csv" => println!("{}", format_info("Generating CSV export...")),
        "cypher" => println!("{}", format_info("Generating Cypher statements...")),
        "graphml" => println!("{}", format_info("Generating GraphML export...")),
        other => return Err(anyhow::anyhow!("Unsupported export format: {}", other)),
    }
    let elapsed = started.elapsed();
    println!(
        "{}",
        format_success(&format!(
            "Export completed in {:.2}s",
            elapsed.as_secs_f64()
        ))
    );
    Ok(())
}
/// Show graph database information.
///
/// Placeholder implementation: prints fixed counts until ruvector-neo4j
/// statistics are wired in; `detailed` adds storage/configuration sections.
pub fn show_graph_info(db_path: &str, detailed: bool, config: &Config) -> Result<()> {
    println!("\n{}", "Graph Database Statistics".bold().green());
    // TODO: Integrate with ruvector-neo4j to get actual statistics
    for (label, value) in [
        (" Database: ", db_path),
        (" Graphs: ", "1"),
        (" Total nodes: ", "0"),
        (" Total relationships: ", "0"),
        (" Node labels: ", "0"),
        (" Relationship types: ", "0"),
    ] {
        println!("{}{}", label, value.cyan());
    }
    if detailed {
        println!("\n{}", "Storage Information:".bold().cyan());
        println!(" Store size: {}", "0 bytes".cyan());
        println!(" Index size: {}", "0 bytes".cyan());
        println!("\n{}", "Configuration:".bold().cyan());
        println!(" Cache size: {}", "N/A".cyan());
        println!(" Page size: {}", "N/A".cyan());
    }
    Ok(())
}
/// Run graph benchmarks
///
/// Placeholder benchmark driver: validates `bench_type`, prints progress
/// banners, and reports throughput/latency (currently measured over no real
/// query work until ruvector-neo4j integration lands).
pub fn run_graph_benchmark(
    db_path: &str,
    num_queries: usize,
    bench_type: &str,
    config: &Config,
) -> Result<()> {
    println!("{}", "Running graph benchmark...".bold().green());
    println!(" Benchmark type: {}", bench_type.cyan());
    println!(" Queries: {}", num_queries.to_string().cyan());
    let start = Instant::now();
    // TODO: Implement actual benchmarks with ruvector-neo4j
    match bench_type {
        "traverse" => {
            println!("{}", format_info("Benchmarking graph traversal..."));
        }
        "pattern" => {
            println!("{}", format_info("Benchmarking pattern matching..."));
        }
        "aggregate" => {
            println!("{}", format_info("Benchmarking aggregations..."));
        }
        _ => return Err(anyhow::anyhow!("Unknown benchmark type: {}", bench_type)),
    }
    let elapsed = start.elapsed();
    let qps = num_queries as f64 / elapsed.as_secs_f64();
    let avg_latency = elapsed.as_secs_f64() * 1000.0 / num_queries as f64;
    println!("\n{}", "Benchmark Results:".bold().green());
    println!(" Total time: {:.2}s", elapsed.as_secs_f64());
    // BUG FIX: `{:.0}`/`{:.2}` were previously applied to
    // `x.to_string().cyan()` (a ColoredString), so the numeric rounding was
    // never performed. Format the number first, then colorize.
    println!(" Queries per second: {}", format!("{:.0}", qps).cyan());
    println!(" Average latency: {}ms", format!("{:.2}", avg_latency).cyan());
    Ok(())
}
/// Start the HTTP/gRPC graph server (placeholder).
///
/// Prints the configured endpoints; the actual server loop awaits
/// ruvector-neo4j integration, so this returns immediately.
pub fn serve_graph(
    db_path: &str,
    host: &str,
    http_port: u16,
    grpc_port: u16,
    enable_graphql: bool,
    config: &Config,
) -> Result<()> {
    println!("{}", "Starting RuVector Graph Server...".bold().green());
    println!(" Database: {}", db_path.cyan());
    // Shared renderer for the two plain host:port endpoints.
    let endpoint = |label: &str, port: u16| {
        println!(
            " {} endpoint: {}:{}",
            label,
            host.cyan(),
            port.to_string().cyan()
        );
    };
    endpoint("HTTP", http_port);
    endpoint("gRPC", grpc_port);
    if enable_graphql {
        println!(
            " GraphQL endpoint: {}:{}/graphql",
            host.cyan(),
            http_port.to_string().cyan()
        );
    }
    println!("\n{}", format_info("Server configuration loaded"));
    // TODO: Implement actual server with ruvector-neo4j
    println!("{}", format_success("Server ready! Press Ctrl+C to stop."));
    println!(
        "\n{}",
        format_info("Server implementation pending - integrate with ruvector-neo4j")
    );
    Ok(())
}
// Helper functions for formatting graph results
/// Render query results as a table-style string (stub: only handles counts).
fn format_graph_results_table(results: &[serde_json::Value], query: &str) -> String {
    if results.is_empty() {
        // Echo the query back so the user can see what produced no rows.
        format!(
            "{}\nQuery: {}\n",
            "No results found".dimmed(),
            query.dimmed()
        )
    } else {
        // TODO: Format results as table
        format!("{} results\n", results.len().to_string().cyan())
    }
}
/// Render query results as pretty-printed JSON.
fn format_graph_results_json(results: &[serde_json::Value]) -> Result<String> {
    match serde_json::to_string_pretty(&results) {
        Ok(json) => Ok(json),
        Err(e) => Err(anyhow::anyhow!("Failed to serialize results: {}", e)),
    }
}
/// Render query results as CSV.
///
/// Stub: always returns an empty string until CSV formatting is implemented;
/// `results` is currently ignored.
fn format_graph_results_csv(results: &[serde_json::Value]) -> Result<String> {
    // TODO: Implement CSV formatting
    Ok(String::new())
}
/// Print the interactive shell's command reference and Cypher examples.
fn print_shell_help() {
    println!("\n{}", "RuVector Graph Shell Commands".bold().cyan());
    println!(" {} - Exit the shell", ":exit, :quit, :q".yellow());
    println!(
        " {} - Show this help message",
        ":help, :h".yellow()
    );
    println!(" {} - Clear query buffer", ":clear".yellow());
    println!("\n{}", "Cypher Examples:".bold().cyan());
    // BUG FIX: these strings are println! *arguments*, not format strings,
    // so the previous `{{`/`}}` escapes printed literal doubled braces.
    // Plain braces render the intended Cypher syntax.
    println!(" {}", "CREATE (n:Person {name: 'Alice'})".dimmed());
    println!(" {}", "MATCH (n:Person) RETURN n".dimmed());
    println!(" {}", "MATCH (a)-[r:KNOWS]->(b) RETURN a, r, b".dimmed());
    println!();
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,415 @@
//! PostgreSQL storage backend for hooks intelligence data
//!
//! This module provides PostgreSQL-based storage for the hooks system,
//! using the ruvector extension for vector operations.
//!
//! Enable with the `postgres` feature flag.
#[cfg(feature = "postgres")]
use deadpool_postgres::{Config, Pool, Runtime};
#[cfg(feature = "postgres")]
use tokio_postgres::NoTls;
use std::env;
/// PostgreSQL storage configuration
///
/// Connection parameters for the hooks PostgreSQL backend, built either from
/// a connection URL or from individual environment variables.
#[derive(Debug, Clone)]
pub struct PostgresConfig {
    pub host: String,
    pub port: u16,
    pub user: String,
    pub password: Option<String>,
    pub dbname: String,
}
impl PostgresConfig {
    /// Create config from environment variables.
    ///
    /// Checks `RUVECTOR_POSTGRES_URL` then `DATABASE_URL` first; if either is
    /// set, its parse result is returned directly (even when parsing fails).
    /// Otherwise the individual `RUVECTOR_PG_*` variables are used, with
    /// `RUVECTOR_PG_USER` being the only required one.
    pub fn from_env() -> Option<Self> {
        let url = env::var("RUVECTOR_POSTGRES_URL").or_else(|_| env::var("DATABASE_URL"));
        if let Ok(url) = url {
            return Self::from_url(&url);
        }
        Some(Self {
            host: env::var("RUVECTOR_PG_HOST").unwrap_or_else(|_| "localhost".to_string()),
            port: env::var("RUVECTOR_PG_PORT")
                .ok()
                .and_then(|p| p.parse().ok())
                .unwrap_or(5432),
            user: env::var("RUVECTOR_PG_USER").ok()?,
            password: env::var("RUVECTOR_PG_PASSWORD").ok(),
            dbname: env::var("RUVECTOR_PG_DATABASE").unwrap_or_else(|_| "ruvector".to_string()),
        })
    }
    /// Parse a `postgres://user[:password]@host[:port]/dbname[?params]` URL.
    ///
    /// Returns `None` for any URL that does not match this shape, including
    /// URLs without a `user@` section.
    pub fn from_url(url: &str) -> Option<Self> {
        let stripped = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))?;
        // Credentials precede the first '@'; the password part is optional.
        let (auth, remainder) = stripped.split_once('@')?;
        let (user, password) = match auth.split_once(':') {
            Some((u, p)) => (u.to_string(), Some(p.to_string())),
            None => (auth.to_string(), None),
        };
        // host[:port] precedes the database name; query params are dropped.
        let (host_port, db_part) = remainder.split_once('/')?;
        let dbname = db_part.split('?').next()?.to_string();
        let (host, port) = match host_port.split_once(':') {
            Some((h, p)) => (h.to_string(), p.parse().ok()?),
            None => (host_port.to_string(), 5432),
        };
        Some(Self {
            host,
            port,
            user,
            password,
            dbname,
        })
    }
}
/// PostgreSQL storage backend for hooks
///
/// Thin wrapper around a `deadpool_postgres` connection pool; every
/// operation delegates to a `ruvector_hooks_*` SQL function expected to be
/// installed server-side by the ruvector extension.
#[cfg(feature = "postgres")]
pub struct PostgresStorage {
    // Connection pool shared by all storage operations.
    pool: Pool,
}
#[cfg(feature = "postgres")]
impl PostgresStorage {
    /// Create a new PostgreSQL storage backend
    ///
    /// Builds a deadpool connection pool from `config`. Pool creation does
    /// not dial the server; the first `pool.get()` in a method does.
    pub async fn new(config: PostgresConfig) -> Result<Self, Box<dyn std::error::Error>> {
        let mut cfg = Config::new();
        cfg.host = Some(config.host);
        cfg.port = Some(config.port);
        cfg.user = Some(config.user);
        cfg.password = config.password;
        cfg.dbname = Some(config.dbname);
        // Connections are plaintext (NoTls).
        let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;
        Ok(Self { pool })
    }
    /// Update Q-value for state-action pair
    ///
    /// Delegates to `ruvector_hooks_update_q`; the update rule (learning
    /// rate, discounting, etc.) lives server-side and is not visible here.
    pub async fn update_q(
        &self,
        state: &str,
        action: &str,
        reward: f32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute(
                "SELECT ruvector_hooks_update_q($1, $2, $3)",
                &[&state, &action, &reward],
            )
            .await?;
        Ok(())
    }
    /// Get best action for state
    ///
    /// Returns `(action, q_value, confidence)` for the top-ranked action
    /// among `actions`, or `None` when the SQL function yields no row.
    pub async fn best_action(
        &self,
        state: &str,
        actions: &[String],
    ) -> Result<Option<(String, f32, f32)>, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let row = client
            .query_opt(
                "SELECT action, q_value, confidence FROM ruvector_hooks_best_action($1, $2)",
                &[&state, &actions],
            )
            .await?;
        // Positional gets must match the SELECT column order above.
        Ok(row.map(|r| (r.get(0), r.get(1), r.get(2))))
    }
    /// Store content in semantic memory
    ///
    /// `embedding` is optional; `metadata` is serialized to a JSON string
    /// and cast to jsonb server-side. Returns the new row's id.
    pub async fn remember(
        &self,
        memory_type: &str,
        content: &str,
        embedding: Option<&[f32]>,
        metadata: &serde_json::Value,
    ) -> Result<i32, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let metadata_str = serde_json::to_string(metadata)?;
        let row = client
            .query_one(
                "SELECT ruvector_hooks_remember($1, $2, $3, $4::jsonb)",
                &[&memory_type, &content, &embedding, &metadata_str],
            )
            .await?;
        Ok(row.get(0))
    }
    /// Search memory semantically
    ///
    /// Runs a similarity search via `ruvector_hooks_recall` and maps each row
    /// to a [`MemoryResult`].
    pub async fn recall(
        &self,
        query_embedding: &[f32],
        limit: i32,
    ) -> Result<Vec<MemoryResult>, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let rows = client
            .query(
                "SELECT id, memory_type, content, metadata::text, similarity
                 FROM ruvector_hooks_recall($1, $2)",
                &[&query_embedding, &limit],
            )
            .await?;
        Ok(rows
            .iter()
            .map(|r| {
                let metadata_str: String = r.get(3);
                MemoryResult {
                    id: r.get(0),
                    memory_type: r.get(1),
                    content: r.get(2),
                    // Unparseable metadata degrades to JSON null rather than
                    // failing the entire recall.
                    metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
                    similarity: r.get(4),
                }
            })
            .collect())
    }
    /// Record file sequence
    ///
    /// Logs a `from_file -> to_file` transition for next-file suggestions.
    pub async fn record_sequence(
        &self,
        from_file: &str,
        to_file: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute(
                "SELECT ruvector_hooks_record_sequence($1, $2)",
                &[&from_file, &to_file],
            )
            .await?;
        Ok(())
    }
    /// Get suggested next files
    ///
    /// Returns up to `limit` `(to_file, count)` pairs for files that have
    /// previously followed `file`.
    pub async fn suggest_next(
        &self,
        file: &str,
        limit: i32,
    ) -> Result<Vec<(String, i32)>, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let rows = client
            .query(
                "SELECT to_file, count FROM ruvector_hooks_suggest_next($1, $2)",
                &[&file, &limit],
            )
            .await?;
        Ok(rows.iter().map(|r| (r.get(0), r.get(1))).collect())
    }
    /// Record error pattern
    ///
    /// Stores a `(code, error_type, message)` triple for later pattern
    /// analysis server-side.
    pub async fn record_error(
        &self,
        code: &str,
        error_type: &str,
        message: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute(
                "SELECT ruvector_hooks_record_error($1, $2, $3)",
                &[&code, &error_type, &message],
            )
            .await?;
        Ok(())
    }
    /// Register agent in swarm
    pub async fn swarm_register(
        &self,
        agent_id: &str,
        agent_type: &str,
        capabilities: &[String],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute(
                "SELECT ruvector_hooks_swarm_register($1, $2, $3)",
                &[&agent_id, &agent_type, &capabilities],
            )
            .await?;
        Ok(())
    }
    /// Record coordination between agents
    ///
    /// `weight` semantics (e.g. strength vs. cost of the edge) are defined
    /// by the SQL function — confirm against the extension schema.
    pub async fn swarm_coordinate(
        &self,
        source: &str,
        target: &str,
        weight: f32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute(
                "SELECT ruvector_hooks_swarm_coordinate($1, $2, $3)",
                &[&source, &target, &weight],
            )
            .await?;
        Ok(())
    }
    /// Get swarm statistics
    ///
    /// NOTE(review): positional `row.get(n)` assumes the function returns
    /// columns in exactly this order — confirm against the extension schema.
    pub async fn swarm_stats(&self) -> Result<SwarmStats, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let row = client
            .query_one("SELECT * FROM ruvector_hooks_swarm_stats()", &[])
            .await?;
        Ok(SwarmStats {
            total_agents: row.get(0),
            active_agents: row.get(1),
            total_edges: row.get(2),
            avg_success_rate: row.get(3),
        })
    }
    /// Get overall statistics
    ///
    /// NOTE(review): like `swarm_stats`, this relies on the SQL function's
    /// column order matching the struct field order below.
    pub async fn get_stats(&self) -> Result<IntelligenceStats, Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        let row = client
            .query_one("SELECT * FROM ruvector_hooks_get_stats()", &[])
            .await?;
        Ok(IntelligenceStats {
            total_patterns: row.get(0),
            total_memories: row.get(1),
            total_trajectories: row.get(2),
            total_errors: row.get(3),
            session_count: row.get(4),
        })
    }
    /// Start a new session
    pub async fn session_start(&self) -> Result<(), Box<dyn std::error::Error>> {
        let client = self.pool.get().await?;
        client
            .execute("SELECT ruvector_hooks_session_start()", &[])
            .await?;
        Ok(())
    }
}
/// Memory search result
///
/// One row from a semantic memory search: the stored content plus its
/// similarity score against the query.
#[derive(Debug)]
pub struct MemoryResult {
    // Database row identifier.
    pub id: i32,
    // Category/kind of the stored memory (free-form string from the DB).
    pub memory_type: String,
    // The stored memory text itself.
    pub content: String,
    // Arbitrary JSON metadata attached when the memory was recorded.
    pub metadata: serde_json::Value,
    // Similarity of this memory to the query — NOTE(review): direction
    // (higher = closer?) is assumed from the name; confirm against the SQL
    // function that produces it.
    pub similarity: f32,
}
/// Swarm statistics
///
/// Aggregate counters for the agent swarm, filled from the four columns of
/// `ruvector_hooks_swarm_stats()`.
#[derive(Debug)]
pub struct SwarmStats {
    pub total_agents: i64,
    pub active_agents: i64,
    pub total_edges: i64,
    pub avg_success_rate: f32,
}
/// Intelligence statistics
///
/// Overall counters for the intelligence store, filled from the five columns
/// of `ruvector_hooks_get_stats()`.
#[derive(Debug)]
pub struct IntelligenceStats {
    pub total_patterns: i64,
    pub total_memories: i64,
    pub total_trajectories: i64,
    pub total_errors: i64,
    pub session_count: i64,
}
/// Check if PostgreSQL is available
///
/// True when the environment provides enough configuration to build a
/// `PostgresConfig`; this does not attempt an actual connection.
pub fn is_postgres_available() -> bool {
    matches!(PostgresConfig::from_env(), Some(_))
}
/// Storage backend selector
///
/// Chooses between the PostgreSQL-backed store (only with the `postgres`
/// cargo feature) and the JSON-file fallback.
pub enum StorageBackend {
    // Available only when compiled with the `postgres` feature.
    #[cfg(feature = "postgres")]
    Postgres(PostgresStorage),
    // File-based fallback that always works.
    Json(super::Intelligence),
}
impl StorageBackend {
    /// Create storage backend from environment
    ///
    /// Prefers PostgreSQL when `PostgresConfig::from_env()` yields a config
    /// and the connection succeeds. A failed connection is reported on
    /// stderr but is NOT an error — the JSON fallback is returned instead.
    #[cfg(feature = "postgres")]
    pub async fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
        if let Some(config) = PostgresConfig::from_env() {
            match PostgresStorage::new(config).await {
                Ok(pg) => return Ok(Self::Postgres(pg)),
                Err(e) => {
                    eprintln!(
                        "Warning: PostgreSQL unavailable ({}), using JSON fallback",
                        e
                    );
                }
            }
        }
        Ok(Self::Json(super::Intelligence::new(
            super::get_intelligence_path(),
        )))
    }
    /// Non-postgres builds always use the JSON store.
    ///
    /// NOTE(review): this variant is neither `async` nor fallible, so the
    /// two cfg branches have different call shapes; confirm call sites are
    /// feature-aware.
    #[cfg(not(feature = "postgres"))]
    pub fn from_env() -> Self {
        Self::Json(super::Intelligence::new(super::get_intelligence_path()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // URL with every component present: credentials, host, port and dbname.
    #[test]
    fn test_config_from_url() {
        let config =
            PostgresConfig::from_url("postgres://user:pass@localhost:5432/ruvector").unwrap();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.password, Some("pass".to_string()));
        assert_eq!(config.dbname, "ruvector");
    }
    // Omitting the password must yield `None`, not an empty string.
    #[test]
    fn test_config_from_url_no_password() {
        let config = PostgresConfig::from_url("postgres://user@localhost/ruvector").unwrap();
        assert_eq!(config.user, "user");
        assert_eq!(config.password, None);
    }
    // Query parameters (e.g. sslmode) must not leak into the database name.
    #[test]
    fn test_config_from_url_with_query() {
        let config = PostgresConfig::from_url(
            "postgres://user:pass@localhost:5432/ruvector?sslmode=require",
        )
        .unwrap();
        assert_eq!(config.dbname, "ruvector");
    }
}

View File

@@ -0,0 +1,15 @@
//! CLI module for Ruvector
pub mod commands;
pub mod format;
pub mod graph;
pub mod hooks;
#[cfg(feature = "postgres")]
pub mod hooks_postgres;
pub mod progress;
pub use commands::*;
pub use format::*;
pub use graph::*;
pub use hooks::*;
pub use progress::ProgressTracker;

View File

@@ -0,0 +1,56 @@
//! Progress tracking for CLI operations
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::time::Duration;
/// Progress tracker for long-running operations
///
/// Thin wrapper around an indicatif [`MultiProgress`] so several bars and
/// spinners can render together without clobbering each other's output.
pub struct ProgressTracker {
    // Shared renderer that owns every bar/spinner created by this tracker.
    multi: MultiProgress,
}
impl ProgressTracker {
    /// Create a new progress tracker
    pub fn new() -> Self {
        let multi = MultiProgress::new();
        Self { multi }
    }
    /// Create a progress bar for an operation
    ///
    /// The bar shows elapsed time, position/length and an ETA, and ticks on
    /// a 100 ms timer so it animates even between explicit updates.
    pub fn create_bar(&self, total: u64, message: &str) -> ProgressBar {
        let bar = self.multi.add(ProgressBar::new(total));
        let style = ProgressStyle::default_bar()
            .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})")
            .unwrap()
            .progress_chars("#>-");
        bar.set_style(style);
        bar.set_message(message.to_string());
        bar.enable_steady_tick(Duration::from_millis(100));
        bar
    }
    /// Create a spinner for indeterminate operations
    pub fn create_spinner(&self, message: &str) -> ProgressBar {
        let spinner = self.multi.add(ProgressBar::new_spinner());
        let style = ProgressStyle::default_spinner()
            .template("{spinner:.green} {msg}")
            .unwrap();
        spinner.set_style(style);
        spinner.set_message(message.to_string());
        spinner.enable_steady_tick(Duration::from_millis(100));
        spinner
    }
    /// Finish all progress bars
    ///
    /// Intentionally a no-op: indicatif bars finish themselves when dropped.
    pub fn finish_all(&self) {}
}
impl Default for ProgressTracker {
    // Defers to `new()` so `default()` and `new()` can never diverge.
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,280 @@
//! Configuration management for Ruvector CLI
use anyhow::{Context, Result};
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, QuantizationConfig};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Ruvector CLI configuration
///
/// Top-level config; each section is optional in the TOML file thanks to
/// `#[serde(default)]`, so a partial or empty file still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Database options
    #[serde(default)]
    pub database: DatabaseConfig,
    /// CLI options
    #[serde(default)]
    pub cli: CliConfig,
    /// MCP server options
    #[serde(default)]
    pub mcp: McpConfig,
}
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Default storage path
    #[serde(default = "default_storage_path")]
    pub storage_path: String,
    /// Default dimensions
    #[serde(default = "default_dimensions")]
    pub dimensions: usize,
    /// Distance metric
    #[serde(default = "default_distance_metric")]
    pub distance_metric: DistanceMetric,
    /// HNSW configuration
    // `None` means the index configuration is left to ruvector-core.
    #[serde(default)]
    pub hnsw: Option<HnswConfig>,
    /// Quantization configuration
    #[serde(default)]
    pub quantization: Option<QuantizationConfig>,
}
/// CLI configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CliConfig {
    /// Show progress bars
    #[serde(default = "default_true")]
    pub progress: bool,
    /// Use colors in output
    #[serde(default = "default_true")]
    pub colors: bool,
    /// Default batch size for operations
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
/// MCP server configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Server host for SSE transport
    #[serde(default = "default_host")]
    pub host: String,
    /// Server port for SSE transport
    #[serde(default = "default_port")]
    pub port: u16,
    /// Enable CORS
    #[serde(default = "default_true")]
    pub cors: bool,
    /// Allowed data directory for MCP file operations (path confinement)
    /// All db_path and backup_path values must resolve within this directory.
    /// Defaults to the current working directory.
    #[serde(default = "default_data_dir")]
    pub data_dir: String,
}
// Default value functions — shared by the serde `#[serde(default = ...)]`
// attributes and the hand-written Default impls below.
fn default_storage_path() -> String {
    String::from("./ruvector.db")
}
fn default_dimensions() -> usize {
    384
}
// Cosine is the default metric (kept in sync with `DatabaseConfig::default`).
fn default_distance_metric() -> DistanceMetric {
    DistanceMetric::Cosine
}
// Shared `true` default for boolean config flags.
fn default_true() -> bool {
    true
}
// Default number of vectors processed per batch operation.
fn default_batch_size() -> usize {
    1000
}
// Defaults the MCP data-dir confinement root to the current working
// directory; falls back to "." when cwd cannot be determined (e.g. it was
// deleted out from under the process).
fn default_data_dir() -> String {
    match std::env::current_dir() {
        Ok(dir) => dir.to_string_lossy().into_owned(),
        Err(_) => ".".to_string(),
    }
}
// Loopback-only by default so the MCP server is not exposed externally.
fn default_host() -> String {
    "127.0.0.1".to_owned()
}
// Default TCP port for the MCP SSE transport.
fn default_port() -> u16 {
    3000
}
impl Default for Config {
    // Builds each section from its own Default impl so per-section defaults
    // stay defined in exactly one place.
    fn default() -> Self {
        Self {
            database: DatabaseConfig::default(),
            cli: CliConfig::default(),
            mcp: McpConfig::default(),
        }
    }
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
storage_path: default_storage_path(),
dimensions: default_dimensions(),
distance_metric: DistanceMetric::Cosine,
hnsw: Some(HnswConfig::default()),
quantization: Some(QuantizationConfig::Scalar),
}
}
}
impl Default for CliConfig {
fn default() -> Self {
Self {
progress: true,
colors: true,
batch_size: default_batch_size(),
}
}
}
impl Default for McpConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
cors: true,
data_dir: default_data_dir(),
}
}
}
impl Config {
    /// Load configuration from file
    ///
    /// # Errors
    /// Fails when the file cannot be read or is not valid TOML.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content =
            std::fs::read_to_string(path.as_ref()).context("Failed to read config file")?;
        let config: Config = toml::from_str(&content).context("Failed to parse config file")?;
        Ok(config)
    }
    /// Load configuration with precedence: CLI args > env vars > config file > defaults
    ///
    /// NOTE(review): an explicitly supplied `config_path` that fails to load
    /// is silently replaced by defaults (`unwrap_or_default`); consider
    /// surfacing that error to the user instead.
    pub fn load(config_path: Option<PathBuf>) -> Result<Self> {
        let mut config = if let Some(path) = config_path {
            Self::from_file(&path).unwrap_or_default()
        } else {
            // Try default locations
            Self::try_default_locations().unwrap_or_default()
        };
        // Override with environment variables
        config.apply_env_vars()?;
        Ok(config)
    }
    /// Try loading from default locations
    ///
    /// Checks project-local files first, then per-user, then system-wide;
    /// the first file that parses wins.
    fn try_default_locations() -> Option<Self> {
        // A fixed-size array is enough here — no need to allocate a Vec
        // (clippy::useless_vec) just to iterate four literals.
        let paths = [
            "ruvector.toml",
            ".ruvector.toml",
            "~/.config/ruvector/config.toml",
            "/etc/ruvector/config.toml",
        ];
        for path in paths {
            let expanded = shellexpand::tilde(path).to_string();
            if let Ok(config) = Self::from_file(&expanded) {
                return Some(config);
            }
        }
        None
    }
    /// Apply environment variable overrides
    ///
    /// # Errors
    /// Fails when a numeric variable does not parse or the distance metric
    /// name is unknown.
    fn apply_env_vars(&mut self) -> Result<()> {
        if let Ok(path) = std::env::var("RUVECTOR_STORAGE_PATH") {
            self.database.storage_path = path;
        }
        if let Ok(dims) = std::env::var("RUVECTOR_DIMENSIONS") {
            self.database.dimensions = dims.parse().context("Invalid RUVECTOR_DIMENSIONS")?;
        }
        if let Ok(metric) = std::env::var("RUVECTOR_DISTANCE_METRIC") {
            self.database.distance_metric = match metric.to_lowercase().as_str() {
                "euclidean" => DistanceMetric::Euclidean,
                "cosine" => DistanceMetric::Cosine,
                "dotproduct" => DistanceMetric::DotProduct,
                "manhattan" => DistanceMetric::Manhattan,
                _ => return Err(anyhow::anyhow!("Invalid distance metric: {}", metric)),
            };
        }
        if let Ok(host) = std::env::var("RUVECTOR_MCP_HOST") {
            self.mcp.host = host;
        }
        if let Ok(port) = std::env::var("RUVECTOR_MCP_PORT") {
            self.mcp.port = port.parse().context("Invalid RUVECTOR_MCP_PORT")?;
        }
        if let Ok(data_dir) = std::env::var("RUVECTOR_MCP_DATA_DIR") {
            self.mcp.data_dir = data_dir;
        }
        Ok(())
    }
    /// Convert to DbOptions
    pub fn to_db_options(&self) -> DbOptions {
        DbOptions {
            dimensions: self.database.dimensions,
            distance_metric: self.database.distance_metric,
            storage_path: self.database.storage_path.clone(),
            hnsw_config: self.database.hnsw.clone(),
            quantization: self.database.quantization.clone(),
        }
    }
    /// Save configuration to file
    ///
    /// # Errors
    /// Fails when serialization or the filesystem write fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
        std::fs::write(path, content).context("Failed to write config file")?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults must line up with the documented default_* functions.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.database.dimensions, 384);
        assert_eq!(config.cli.batch_size, 1000);
        assert_eq!(config.mcp.port, 3000);
    }
    // A TOML round-trip must preserve values.
    #[test]
    fn test_config_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(config.database.dimensions, parsed.database.dimensions);
    }
}

View File

@@ -0,0 +1,416 @@
//! Ruvector CLI - High-performance vector database command-line interface
use anyhow::Result;
use clap::{Parser, Subcommand};
use colored::*;
use std::path::PathBuf;
mod cli;
mod config;
use crate::cli::commands::*;
use crate::config::Config;
// NOTE: clap renders the `///` doc comments below as --help text; keep them
// user-facing and edit with care.
#[derive(Parser)]
#[command(name = "ruvector")]
#[command(about = "High-performance Rust vector database CLI", long_about = None)]
#[command(version)]
struct Cli {
    /// Configuration file path
    #[arg(short, long, global = true)]
    config: Option<PathBuf>,
    /// Enable debug mode
    #[arg(short, long, global = true)]
    debug: bool,
    /// Disable colors
    #[arg(long, global = true)]
    no_color: bool,
    #[command(subcommand)]
    command: Commands,
}
// Top-level subcommands. `-b` is used for the db path everywhere because
// `-d` is taken by the global `--debug` flag.
#[derive(Subcommand)]
enum Commands {
    /// Create a new vector database
    Create {
        /// Database file path
        #[arg(short, long, default_value = "./ruvector.db")]
        path: String,
        /// Vector dimensions
        #[arg(short = 'D', long)]
        dimensions: usize,
    },
    /// Insert vectors from a file
    Insert {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
        /// Input file path
        #[arg(short, long)]
        input: String,
        /// Input format (json, csv, npy)
        #[arg(short, long, default_value = "json")]
        format: String,
        /// Hide progress bar
        #[arg(long)]
        no_progress: bool,
    },
    /// Search for similar vectors
    Search {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
        /// Query vector (comma-separated floats or JSON array)
        #[arg(short, long)]
        query: String,
        /// Number of results
        #[arg(short = 'k', long, default_value = "10")]
        top_k: usize,
        /// Show full vectors in results
        #[arg(long)]
        show_vectors: bool,
    },
    /// Show database information
    Info {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
    },
    /// Run a quick performance benchmark
    Benchmark {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
        /// Number of queries to run
        #[arg(short = 'n', long, default_value = "1000")]
        queries: usize,
    },
    /// Export database to file
    Export {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
        /// Output file path
        #[arg(short, long)]
        output: String,
        /// Output format (json, csv)
        #[arg(short, long, default_value = "json")]
        format: String,
    },
    /// Import from other vector databases
    Import {
        /// Database file path
        #[arg(short = 'b', long, default_value = "./ruvector.db")]
        db: String,
        /// Source database type (faiss, pinecone, weaviate)
        #[arg(short, long)]
        source: String,
        /// Source file or connection path
        #[arg(short = 'p', long)]
        source_path: String,
    },
    // Nested subcommand groups live in the cli submodules.
    /// Graph database operations (Neo4j-compatible)
    Graph {
        #[command(subcommand)]
        action: cli::graph::GraphCommands,
    },
    /// Self-learning intelligence hooks for Claude Code
    Hooks {
        #[command(subcommand)]
        action: cli::hooks::HooksCommands,
    },
}
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    // Initialize logging
    if cli.debug {
        tracing_subscriber::fmt()
            .with_env_filter("ruvector=debug")
            .init();
    }
    // Disable colors if requested
    if cli.no_color {
        colored::control::set_override(false);
    }
    // Load configuration
    let config = Config::load(cli.config)?;
    // Execute command. Every arm returns Result<()>; failures are reported
    // uniformly after the match.
    let result = match cli.command {
        Commands::Create { path, dimensions } => create_database(&path, dimensions, &config),
        Commands::Insert {
            db,
            input,
            format,
            no_progress,
        } => insert_vectors(&db, &input, &format, &config, !no_progress),
        Commands::Search {
            db,
            query,
            top_k,
            show_vectors,
        } => {
            let query_vec = parse_query_vector(&query)?;
            search_vectors(&db, query_vec, top_k, &config, show_vectors)
        }
        Commands::Info { db } => show_info(&db, &config),
        Commands::Benchmark { db, queries } => run_benchmark(&db, &config, queries),
        Commands::Export { db, output, format } => export_database(&db, &output, &format, &config),
        Commands::Import {
            db,
            source,
            source_path,
        } => import_from_external(&db, &source, &source_path, &config),
        // Graph subcommands delegate straight into cli::graph.
        Commands::Graph { action } => {
            use cli::graph::GraphCommands;
            match action {
                GraphCommands::Create {
                    path,
                    name,
                    indexed,
                } => cli::graph::create_graph(&path, &name, indexed, &config),
                GraphCommands::Query {
                    db,
                    cypher,
                    format,
                    explain,
                } => cli::graph::execute_query(&db, &cypher, &format, explain, &config),
                GraphCommands::Shell { db, multiline } => {
                    cli::graph::run_shell(&db, multiline, &config)
                }
                GraphCommands::Import {
                    db,
                    input,
                    format,
                    graph,
                    skip_errors,
                } => cli::graph::import_graph(&db, &input, &format, &graph, skip_errors, &config),
                GraphCommands::Export {
                    db,
                    output,
                    format,
                    graph,
                } => cli::graph::export_graph(&db, &output, &format, &graph, &config),
                GraphCommands::Info { db, detailed } => {
                    cli::graph::show_graph_info(&db, detailed, &config)
                }
                GraphCommands::Benchmark {
                    db,
                    queries,
                    bench_type,
                } => cli::graph::run_graph_benchmark(&db, queries, &bench_type, &config),
                GraphCommands::Serve {
                    db,
                    host,
                    http_port,
                    grpc_port,
                    graphql,
                } => cli::graph::serve_graph(&db, &host, http_port, grpc_port, graphql, &config),
            }
        }
        // Hooks subcommands delegate straight into cli::hooks; multi-word
        // free arguments arrive as Vec<String> and are re-joined here.
        Commands::Hooks { action } => {
            use cli::hooks::HooksCommands;
            match action {
                HooksCommands::Init { force, postgres } => {
                    cli::hooks::init_hooks(force, postgres, &config)
                }
                HooksCommands::Install { settings_dir } => {
                    cli::hooks::install_hooks(&settings_dir, &config)
                }
                HooksCommands::Stats => cli::hooks::show_stats(&config),
                HooksCommands::Remember {
                    memory_type,
                    content,
                } => cli::hooks::remember_content(&memory_type, &content.join(" "), &config),
                HooksCommands::Recall { query, top_k } => {
                    cli::hooks::recall_content(&query.join(" "), top_k, &config)
                }
                HooksCommands::Learn {
                    state,
                    action,
                    reward,
                } => cli::hooks::learn_trajectory(&state, &action, reward, &config),
                HooksCommands::Suggest { state, actions } => {
                    cli::hooks::suggest_action(&state, &actions, &config)
                }
                HooksCommands::Route {
                    task,
                    file,
                    crate_name,
                    operation,
                } => cli::hooks::route_task(
                    &task.join(" "),
                    file.as_deref(),
                    crate_name.as_deref(),
                    &operation,
                    &config,
                ),
                HooksCommands::PreEdit { file } => cli::hooks::pre_edit_hook(&file, &config),
                HooksCommands::PostEdit { file, success } => {
                    cli::hooks::post_edit_hook(&file, success, &config)
                }
                HooksCommands::PreCommand { command } => {
                    cli::hooks::pre_command_hook(&command.join(" "), &config)
                }
                HooksCommands::PostCommand {
                    command,
                    success,
                    stderr,
                } => cli::hooks::post_command_hook(
                    &command.join(" "),
                    success,
                    stderr.as_deref(),
                    &config,
                ),
                HooksCommands::SessionStart { session_id, resume } => {
                    cli::hooks::session_start_hook(session_id.as_deref(), resume, &config)
                }
                HooksCommands::SessionEnd { export_metrics } => {
                    cli::hooks::session_end_hook(export_metrics, &config)
                }
                HooksCommands::PreCompact { length, auto } => {
                    cli::hooks::pre_compact_hook(length, auto, &config)
                }
                HooksCommands::SuggestContext => cli::hooks::suggest_context_cmd(&config),
                HooksCommands::TrackNotification { notification_type } => {
                    cli::hooks::track_notification_cmd(notification_type.as_deref(), &config)
                }
                // Claude Code v2.0.55+ features
                HooksCommands::LspDiagnostic {
                    file,
                    severity,
                    message,
                } => cli::hooks::lsp_diagnostic_cmd(
                    file.as_deref(),
                    severity.as_deref(),
                    message.as_deref(),
                    &config,
                ),
                HooksCommands::SuggestUltrathink { task, file } => {
                    cli::hooks::suggest_ultrathink_cmd(&task.join(" "), file.as_deref(), &config)
                }
                HooksCommands::AsyncAgent {
                    action,
                    agent_id,
                    task,
                } => cli::hooks::async_agent_cmd(
                    &action,
                    agent_id.as_deref(),
                    task.as_deref(),
                    &config,
                ),
                HooksCommands::RecordError { command, stderr } => {
                    cli::hooks::record_error_cmd(&command, &stderr, &config)
                }
                HooksCommands::SuggestFix { error_code } => {
                    cli::hooks::suggest_fix_cmd(&error_code, &config)
                }
                HooksCommands::SuggestNext { file, count } => {
                    cli::hooks::suggest_next_cmd(&file, count, &config)
                }
                HooksCommands::ShouldTest { file } => cli::hooks::should_test_cmd(&file, &config),
                HooksCommands::SwarmRegister {
                    agent_id,
                    agent_type,
                    capabilities,
                } => cli::hooks::swarm_register_cmd(
                    &agent_id,
                    &agent_type,
                    capabilities.as_deref(),
                    &config,
                ),
                HooksCommands::SwarmCoordinate {
                    source,
                    target,
                    weight,
                } => cli::hooks::swarm_coordinate_cmd(&source, &target, weight, &config),
                HooksCommands::SwarmOptimize { tasks } => {
                    cli::hooks::swarm_optimize_cmd(&tasks, &config)
                }
                HooksCommands::SwarmRecommend { task_type } => {
                    cli::hooks::swarm_recommend_cmd(&task_type, &config)
                }
                HooksCommands::SwarmHeal { agent_id } => {
                    cli::hooks::swarm_heal_cmd(&agent_id, &config)
                }
                HooksCommands::SwarmStats => cli::hooks::swarm_stats_cmd(&config),
                HooksCommands::Completions { shell } => cli::hooks::generate_completions(shell),
                HooksCommands::Compress => cli::hooks::compress_storage(&config),
                HooksCommands::CacheStats => cli::hooks::cache_stats(&config),
            }
        }
    };
    // Handle errors: pretty one-liner by default, full Debug dump with --debug.
    if let Err(e) = result {
        eprintln!("{}", cli::format::format_error(&e.to_string()));
        if cli.debug {
            eprintln!("\n{:#?}", e);
        } else {
            eprintln!("\n{}", "Run with --debug for more details".dimmed());
        }
        std::process::exit(1);
    }
    Ok(())
}
/// Parse query vector from string
///
/// Accepts either a JSON array (`[1.0, 2.0]`) or a comma-separated list
/// (`1.0, 2.0`); the JSON form is detected by a leading `[`.
fn parse_query_vector(s: &str) -> Result<Vec<f32>> {
    let trimmed = s.trim();
    if trimmed.starts_with('[') {
        // JSON array form (whitespace around the array is already ignored
        // by the JSON parser).
        serde_json::from_str(trimmed)
            .map_err(|e| anyhow::anyhow!("Failed to parse query vector as JSON: {}", e))
    } else {
        // Comma-separated form; each element is trimmed individually.
        trimmed
            .split(',')
            .map(|part| part.trim().parse::<f32>())
            .collect::<std::result::Result<Vec<f32>, _>>()
            .map_err(|e| anyhow::anyhow!("Failed to parse query vector: {}", e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // JSON-array input form.
    #[test]
    fn test_parse_query_vector_json() {
        let vec = parse_query_vector("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(vec, vec![1.0, 2.0, 3.0]);
    }
    // Comma-separated input form with surrounding spaces.
    #[test]
    fn test_parse_query_vector_csv() {
        let vec = parse_query_vector("1.0, 2.0, 3.0").unwrap();
        assert_eq!(vec, vec![1.0, 2.0, 3.0]);
    }
}

View File

@@ -0,0 +1,463 @@
//! GNN Layer Caching for Performance Optimization
//!
//! This module provides persistent caching for GNN layers and query results,
//! eliminating the ~2.5s overhead per operation from process initialization,
//! database loading, and index deserialization.
//!
//! ## Performance Impact
//!
//! | Operation | Before | After | Improvement |
//! |-----------|--------|-------|-------------|
//! | Layer init | ~2.5s | ~5-10ms | 250-500x |
//! | Query | ~2.5s | ~5-10ms | 250-500x |
//! | Batch query | ~2.5s * N | ~5-10ms | Amortized |
use lru::LruCache;
use ruvector_gnn::layer::RuvectorLayer;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Cache entry with metadata for monitoring
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
    // The cached payload itself.
    pub value: T,
    // When the entry was first inserted.
    pub created_at: Instant,
    // Updated on every `access()`; drives oldest-entry eviction.
    pub last_accessed: Instant,
    // Times `access()` has been called; creation counts as the first access.
    pub access_count: u64,
}
impl<T: Clone> CacheEntry<T> {
    /// Wrap `value` in a fresh entry.
    pub fn new(value: T) -> Self {
        let created = Instant::now();
        Self {
            value,
            created_at: created,
            last_accessed: created,
            access_count: 1,
        }
    }
    /// Borrow the cached value, bumping the access timestamp and counter.
    pub fn access(&mut self) -> &T {
        self.access_count += 1;
        self.last_accessed = Instant::now();
        &self.value
    }
}
/// Configuration for the GNN cache
#[derive(Debug, Clone)]
pub struct GnnCacheConfig {
    /// Maximum number of GNN layers to cache
    pub max_layers: usize,
    /// Maximum number of query results to cache
    pub max_query_results: usize,
    /// TTL for cached query results (in seconds)
    pub query_result_ttl_secs: u64,
    /// Whether to preload common layer configurations
    pub preload_common: bool,
}
impl Default for GnnCacheConfig {
    // Defaults sized for a single long-lived MCP server process.
    fn default() -> Self {
        Self {
            max_layers: 32,
            max_query_results: 1000,
            query_result_ttl_secs: 300, // 5 minutes
            preload_common: true,
        }
    }
}
/// Query result cache key
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct QueryCacheKey {
    /// Layer configuration hash
    pub layer_hash: String,
    /// Query vector hash (FNV-1a over the bit patterns of ALL elements)
    pub query_hash: u64,
    /// Number of results requested
    pub k: usize,
}
impl QueryCacheKey {
    /// Build a cache key for `query` against the layer identified by `layer_id`.
    ///
    /// The previous hash summed (`wrapping_add`) the bit patterns of only the
    /// first 8 floats: permutations of those floats — and any two queries
    /// differing only past index 7 — produced identical keys, so the cache
    /// could hand back the result computed for a DIFFERENT query. FNV-1a is
    /// order-sensitive and covers the whole vector, making accidental
    /// collisions vanishingly unlikely.
    pub fn new(layer_id: &str, query: &[f32], k: usize) -> Self {
        const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        let query_hash = query.iter().fold(FNV_OFFSET, |acc, &v| {
            (acc ^ u64::from(v.to_bits())).wrapping_mul(FNV_PRIME)
        });
        Self {
            layer_hash: layer_id.to_string(),
            query_hash,
            k,
        }
    }
}
/// Cached query result
#[derive(Debug, Clone)]
pub struct CachedQueryResult {
    // The cached forward-pass output.
    pub result: Vec<f32>,
    // Insertion time; compared against the configured TTL on every read.
    pub cached_at: Instant,
}
/// GNN Layer cache with LRU eviction and TTL support
///
/// All fields are behind `Arc<RwLock<...>>` so the cache can be shared and
/// mutated across concurrent async tasks.
pub struct GnnCache {
    /// Cached GNN layers by configuration hash
    layers: Arc<RwLock<HashMap<String, CacheEntry<RuvectorLayer>>>>,
    /// LRU cache for query results
    query_results: Arc<RwLock<LruCache<QueryCacheKey, CachedQueryResult>>>,
    /// Configuration
    config: GnnCacheConfig,
    /// Cache statistics
    stats: Arc<RwLock<CacheStats>>,
}
/// Cache statistics for monitoring
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    pub layer_hits: u64,
    pub layer_misses: u64,
    pub query_hits: u64,
    pub query_misses: u64,
    pub evictions: u64,
    pub total_queries: u64,
}
impl CacheStats {
    /// Fraction of layer lookups served from cache (0.0 when none recorded).
    pub fn layer_hit_rate(&self) -> f64 {
        match self.layer_hits + self.layer_misses {
            0 => 0.0,
            total => self.layer_hits as f64 / total as f64,
        }
    }
    /// Fraction of query lookups served from cache (0.0 when none recorded).
    pub fn query_hit_rate(&self) -> f64 {
        match self.query_hits + self.query_misses {
            0 => 0.0,
            total => self.query_hits as f64 / total as f64,
        }
    }
}
impl GnnCache {
    /// Create a new GNN cache with the given configuration
    pub fn new(config: GnnCacheConfig) -> Self {
        // LruCache requires a non-zero capacity; fall back to 1000 if the
        // configured value is 0.
        let query_cache_size =
            NonZeroUsize::new(config.max_query_results).unwrap_or(NonZeroUsize::new(1000).unwrap());
        Self {
            layers: Arc::new(RwLock::new(HashMap::new())),
            query_results: Arc::new(RwLock::new(LruCache::new(query_cache_size))),
            config,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }
    /// Get or create a GNN layer with the specified configuration
    ///
    /// The cache key encodes all four parameters (dropout scaled by 1000).
    /// NOTE(review): two tasks missing concurrently will both build the same
    /// layer; the later insert wins — harmless apart from duplicated work.
    pub async fn get_or_create_layer(
        &self,
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> RuvectorLayer {
        let key = format!(
            "{}_{}_{}_{}",
            input_dim,
            hidden_dim,
            heads,
            (dropout * 1000.0) as u32
        );
        // Check cache first
        {
            let mut layers = self.layers.write().await;
            if let Some(entry) = layers.get_mut(&key) {
                let mut stats = self.stats.write().await;
                stats.layer_hits += 1;
                return entry.access().clone();
            }
        }
        // Create new layer
        let layer = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout)
            .expect("GNN layer cache: invalid layer configuration");
        // Cache it
        {
            // Lock order: `layers` before `stats` — same as the hit path above.
            let mut layers = self.layers.write().await;
            let mut stats = self.stats.write().await;
            stats.layer_misses += 1;
            // Evict if necessary
            if layers.len() >= self.config.max_layers {
                // Simple eviction: remove oldest entry
                if let Some(oldest_key) = layers
                    .iter()
                    .min_by_key(|(_, v)| v.last_accessed)
                    .map(|(k, _)| k.clone())
                {
                    layers.remove(&oldest_key);
                    stats.evictions += 1;
                }
            }
            layers.insert(key, CacheEntry::new(layer.clone()));
        }
        layer
    }
    /// Get cached query result if available and not expired
    ///
    /// Takes the write lock even for reads because `LruCache::get` mutates
    /// recency order; expired entries are dropped on access.
    pub async fn get_query_result(&self, key: &QueryCacheKey) -> Option<Vec<f32>> {
        let mut results = self.query_results.write().await;
        if let Some(cached) = results.get(key) {
            let ttl = Duration::from_secs(self.config.query_result_ttl_secs);
            if cached.cached_at.elapsed() < ttl {
                let mut stats = self.stats.write().await;
                stats.query_hits += 1;
                stats.total_queries += 1;
                return Some(cached.result.clone());
            }
            // Expired, remove it
            results.pop(key);
        }
        let mut stats = self.stats.write().await;
        stats.query_misses += 1;
        stats.total_queries += 1;
        None
    }
    /// Cache a query result
    pub async fn cache_query_result(&self, key: QueryCacheKey, result: Vec<f32>) {
        let mut results = self.query_results.write().await;
        results.put(
            key,
            CachedQueryResult {
                result,
                cached_at: Instant::now(),
            },
        );
    }
    /// Get current cache statistics
    pub async fn stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }
    /// Clear all caches
    pub async fn clear(&self) {
        self.layers.write().await.clear();
        self.query_results.write().await.clear();
    }
    /// Preload common layer configurations for faster first access
    pub async fn preload_common_layers(&self) {
        // Common configurations used in practice
        let common_configs = [
            (128, 256, 4, 0.1),   // Small model
            (256, 512, 8, 0.1),   // Medium model
            (384, 768, 8, 0.1),   // Base model (BERT-like)
            (768, 1024, 16, 0.1), // Large model
        ];
        for (input, hidden, heads, dropout) in common_configs {
            let _ = self
                .get_or_create_layer(input, hidden, heads, dropout)
                .await;
        }
    }
    /// Get number of cached layers
    pub async fn layer_count(&self) -> usize {
        self.layers.read().await.len()
    }
    /// Get number of cached query results
    pub async fn query_result_count(&self) -> usize {
        self.query_results.read().await.len()
    }
}
/// Batch operation for multiple GNN forward passes
#[derive(Debug, Clone)]
pub struct BatchGnnRequest {
    // Layer parameters shared by every operation in the batch.
    pub layer_config: LayerConfig,
    // Individual forward passes to run.
    pub operations: Vec<GnnOperation>,
}
/// Parameters identifying a GNN layer configuration.
#[derive(Debug, Clone)]
pub struct LayerConfig {
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub heads: usize,
    pub dropout: f32,
}
/// One forward-pass input: a node, its neighbors and the edge weights.
#[derive(Debug, Clone)]
pub struct GnnOperation {
    pub node_embedding: Vec<f32>,
    pub neighbor_embeddings: Vec<Vec<f32>>,
    pub edge_weights: Vec<f32>,
}
/// Outcome of a batch run, including cache-hit accounting and wall time.
#[derive(Debug, Clone)]
pub struct BatchGnnResult {
    pub results: Vec<Vec<f32>>,
    pub cached_count: usize,
    pub computed_count: usize,
    pub total_time_ms: f64,
}
impl GnnCache {
    /// Execute batch GNN operations with caching
    ///
    /// Runs every operation through the (cached) layer, serving individual
    /// forward passes from the query-result cache when possible.
    ///
    /// The cache id now includes the dropout component, encoded exactly like
    /// the layer-cache key in `get_or_create_layer`. Previously it omitted
    /// dropout, so two configurations differing only in dropout shared
    /// query-cache entries and could return results computed by the wrong
    /// layer.
    pub async fn batch_forward(&self, request: BatchGnnRequest) -> BatchGnnResult {
        let start = Instant::now();
        // Get or create the layer
        let layer = self
            .get_or_create_layer(
                request.layer_config.input_dim,
                request.layer_config.hidden_dim,
                request.layer_config.heads,
                request.layer_config.dropout,
            )
            .await;
        // Must stay in sync with the key format used by `get_or_create_layer`.
        let layer_id = format!(
            "{}_{}_{}_{}",
            request.layer_config.input_dim,
            request.layer_config.hidden_dim,
            request.layer_config.heads,
            (request.layer_config.dropout * 1000.0) as u32
        );
        let mut results = Vec::with_capacity(request.operations.len());
        let mut cached_count = 0;
        let mut computed_count = 0;
        for op in &request.operations {
            // Check cache
            let cache_key = QueryCacheKey::new(&layer_id, &op.node_embedding, 1);
            if let Some(cached) = self.get_query_result(&cache_key).await {
                results.push(cached);
                cached_count += 1;
            } else {
                // Compute forward pass
                let result = layer.forward(
                    &op.node_embedding,
                    &op.neighbor_embeddings,
                    &op.edge_weights,
                );
                // Cache the result
                self.cache_query_result(cache_key, result.clone()).await;
                results.push(result);
                computed_count += 1;
            }
        }
        BatchGnnResult {
            results,
            cached_count,
            computed_count,
            total_time_ms: start.elapsed().as_secs_f64() * 1000.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // A second request with an identical config must be served from cache.
    #[tokio::test]
    async fn test_layer_caching() {
        let cache = GnnCache::new(GnnCacheConfig::default());
        // First access - miss
        let layer1 = cache.get_or_create_layer(128, 256, 4, 0.1).await;
        let stats = cache.stats().await;
        assert_eq!(stats.layer_misses, 1);
        assert_eq!(stats.layer_hits, 0);
        // Second access - hit
        let _layer2 = cache.get_or_create_layer(128, 256, 4, 0.1).await;
        let stats = cache.stats().await;
        assert_eq!(stats.layer_misses, 1);
        assert_eq!(stats.layer_hits, 1);
    }
    // Round-trip: miss, insert, then hit with the same key.
    #[tokio::test]
    async fn test_query_result_caching() {
        let cache = GnnCache::new(GnnCacheConfig::default());
        let key = QueryCacheKey::new("test", &[1.0, 2.0, 3.0], 10);
        let result = vec![0.1, 0.2, 0.3];
        // Cache miss
        assert!(cache.get_query_result(&key).await.is_none());
        // Cache the result
        cache.cache_query_result(key.clone(), result.clone()).await;
        // Cache hit
        let cached = cache.get_query_result(&key).await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), result);
    }
    // On a fresh cache both operations must be computed, none cached.
    #[tokio::test]
    async fn test_batch_forward() {
        let cache = GnnCache::new(GnnCacheConfig::default());
        let request = BatchGnnRequest {
            layer_config: LayerConfig {
                input_dim: 4,
                hidden_dim: 8,
                heads: 2,
                dropout: 0.1,
            },
            operations: vec![
                GnnOperation {
                    node_embedding: vec![1.0, 2.0, 3.0, 4.0],
                    neighbor_embeddings: vec![vec![0.5, 1.0, 1.5, 2.0]],
                    edge_weights: vec![1.0],
                },
                GnnOperation {
                    node_embedding: vec![2.0, 3.0, 4.0, 5.0],
                    neighbor_embeddings: vec![vec![1.0, 1.5, 2.0, 2.5]],
                    edge_weights: vec![1.0],
                },
            ],
        };
        let result = cache.batch_forward(request).await;
        assert_eq!(result.results.len(), 2);
        assert_eq!(result.computed_count, 2);
        assert_eq!(result.cached_count, 0);
    }
    // Preloading builds exactly the four common configurations.
    #[tokio::test]
    async fn test_preload_common_layers() {
        let cache = GnnCache::new(GnnCacheConfig {
            preload_common: true,
            ..Default::default()
        });
        cache.preload_common_layers().await;
        // Should have 4 preloaded layers
        assert_eq!(cache.layer_count().await, 4);
    }
}

View File

@@ -0,0 +1,927 @@
//! MCP request handlers
use super::gnn_cache::{BatchGnnRequest, GnnCache, GnnCacheConfig, GnnOperation, LayerConfig};
use super::protocol::*;
use crate::config::Config;
use anyhow::{Context, Result};
use ruvector_core::{
types::{DbOptions, DistanceMetric, SearchQuery, VectorEntry},
VectorDB,
};
use ruvector_gnn::{compress::TensorCompress, search::differentiable_search};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
/// MCP handler state with GNN caching for performance optimization
pub struct McpHandler {
    // Loaded CLI/server configuration (supplies `mcp.data_dir`).
    config: Config,
    // Open databases keyed by path, shared across concurrent requests.
    databases: Arc<RwLock<HashMap<String, Arc<VectorDB>>>>,
    /// GNN layer cache for eliminating ~2.5s initialization overhead
    gnn_cache: Arc<GnnCache>,
    /// Tensor compressor for GNN operations
    tensor_compress: Arc<TensorCompress>,
    /// Allowed base directory for all file operations (path confinement)
    // Canonicalized once in `new()`; all user paths must resolve inside it.
    allowed_data_dir: PathBuf,
}
impl McpHandler {
/// Build a handler from `config`, with an empty DB cache and a fresh GNN cache.
pub fn new(config: Config) -> Self {
    let gnn_cache = Arc::new(GnnCache::new(GnnCacheConfig::default()));
    let allowed_data_dir = PathBuf::from(&config.mcp.data_dir);
    // Canonicalize at startup so all later comparisons are absolute
    // NOTE(review): if `data_dir` does not exist, this silently falls back to
    // the current working directory (or `/`), which changes the confinement
    // root used by `validate_path` — confirm this fallback is intended.
    let allowed_data_dir = std::fs::canonicalize(&allowed_data_dir)
        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/")));
    Self {
        config,
        databases: Arc::new(RwLock::new(HashMap::new())),
        gnn_cache,
        tensor_compress: Arc::new(TensorCompress::new()),
        allowed_data_dir,
    }
}
/// Initialize with preloaded GNN layers for optimal performance
pub async fn with_preload(config: Config) -> Self {
    let handler = Self::new(config);
    handler.gnn_cache.preload_common_layers().await;
    handler
}
/// Validate that a user-supplied path resolves within the allowed data directory.
///
/// Prevents CWE-22 path traversal by:
/// 1. Resolving the path relative to `allowed_data_dir` (not cwd)
/// 2. Canonicalizing to eliminate `..`, symlinks, and other tricks
/// 3. Checking that the canonical path starts with the allowed directory
fn validate_path(&self, user_path: &str) -> Result<PathBuf> {
    // Reject obviously malicious absolute paths outside data dir
    let path = Path::new(user_path);
    // If relative, resolve against allowed_data_dir
    let resolved = if path.is_absolute() {
        PathBuf::from(user_path)
    } else {
        self.allowed_data_dir.join(user_path)
    };
    // For existing paths, canonicalize resolves symlinks and ..
    // For non-existing paths, canonicalize the parent and append the filename
    let canonical = if resolved.exists() {
        std::fs::canonicalize(&resolved)
            .with_context(|| format!("Failed to resolve path: {}", user_path))?
    } else {
        // Canonicalize the parent directory (must exist), then append filename
        let parent = resolved.parent().unwrap_or(Path::new("/"));
        let parent_canonical = if parent.exists() {
            std::fs::canonicalize(parent).with_context(|| {
                format!("Parent directory does not exist: {}", parent.display())
            })?
        } else {
            // Refuse paths whose parent directory does not exist: without a
            // real parent we cannot canonicalize, so we cannot safely prove
            // the target stays inside the data directory.
            anyhow::bail!(
                "Path '{}' references non-existent directory '{}'",
                user_path,
                parent.display()
            );
        };
        let filename = resolved
            .file_name()
            .ok_or_else(|| anyhow::anyhow!("Invalid path: no filename in '{}'", user_path))?;
        parent_canonical.join(filename)
    };
    // Security check: canonical path must be inside allowed_data_dir
    if !canonical.starts_with(&self.allowed_data_dir) {
        anyhow::bail!(
            "Access denied: path '{}' resolves to '{}' which is outside the allowed data directory '{}'",
            user_path,
            canonical.display(),
            self.allowed_data_dir.display()
        );
    }
    Ok(canonical)
}
/// Handle MCP request
///
/// Dispatches on the JSON-RPC `method` string; unknown methods receive a
/// METHOD_NOT_FOUND error response echoing the request id.
pub async fn handle_request(&self, request: McpRequest) -> McpResponse {
    match request.method.as_str() {
        "initialize" => self.handle_initialize(request.id).await,
        "tools/list" => self.handle_tools_list(request.id).await,
        "tools/call" => self.handle_tools_call(request.id, request.params).await,
        "resources/list" => self.handle_resources_list(request.id).await,
        "resources/read" => self.handle_resources_read(request.id, request.params).await,
        "prompts/list" => self.handle_prompts_list(request.id).await,
        "prompts/get" => self.handle_prompts_get(request.id, request.params).await,
        _ => McpResponse::error(
            request.id,
            McpError::new(error_codes::METHOD_NOT_FOUND, "Method not found"),
        ),
    }
}
/// Advertise protocol version, server capabilities, and server identity.
async fn handle_initialize(&self, id: Option<Value>) -> McpResponse {
    McpResponse::success(
        id,
        json!({
            "protocolVersion": "2024-11-05",
            "capabilities": {
                "tools": {},
                "resources": {},
                "prompts": {}
            },
            "serverInfo": {
                "name": "ruvector-mcp",
                "version": env!("CARGO_PKG_VERSION")
            }
        }),
    )
}
/// Enumerate every tool this server exposes, each with a JSON-Schema spec of
/// its input. Tool names here must stay in sync with the dispatch table in
/// `handle_tools_call`.
async fn handle_tools_list(&self, id: Option<Value>) -> McpResponse {
    let tools = vec![
        McpTool {
            name: "vector_db_create".to_string(),
            description: "Create a new vector database".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "Database file path"},
                    "dimensions": {"type": "integer", "description": "Vector dimensions"},
                    "distance_metric": {"type": "string", "enum": ["euclidean", "cosine", "dotproduct", "manhattan"]}
                },
                "required": ["path", "dimensions"]
            }),
        },
        McpTool {
            name: "vector_db_insert".to_string(),
            description: "Insert vectors into database".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "db_path": {"type": "string"},
                    "vectors": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "id": {"type": "string"},
                                "vector": {"type": "array", "items": {"type": "number"}},
                                "metadata": {"type": "object"}
                            }
                        }
                    }
                },
                "required": ["db_path", "vectors"]
            }),
        },
        McpTool {
            name: "vector_db_search".to_string(),
            description: "Search for similar vectors".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "db_path": {"type": "string"},
                    "query": {"type": "array", "items": {"type": "number"}},
                    "k": {"type": "integer", "default": 10},
                    "filter": {"type": "object"}
                },
                "required": ["db_path", "query"]
            }),
        },
        McpTool {
            name: "vector_db_stats".to_string(),
            description: "Get database statistics".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "db_path": {"type": "string"}
                },
                "required": ["db_path"]
            }),
        },
        McpTool {
            name: "vector_db_backup".to_string(),
            description: "Backup database to file".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "db_path": {"type": "string"},
                    "backup_path": {"type": "string"}
                },
                "required": ["db_path", "backup_path"]
            }),
        },
        // GNN Tools with persistent caching (~250-500x faster)
        McpTool {
            name: "gnn_layer_create".to_string(),
            description: "Create/cache a GNN layer (eliminates ~2.5s init overhead)"
                .to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "input_dim": {"type": "integer", "description": "Input embedding dimension"},
                    "hidden_dim": {"type": "integer", "description": "Hidden layer dimension"},
                    "heads": {"type": "integer", "description": "Number of attention heads"},
                    "dropout": {"type": "number", "default": 0.1, "description": "Dropout rate"}
                },
                "required": ["input_dim", "hidden_dim", "heads"]
            }),
        },
        McpTool {
            name: "gnn_forward".to_string(),
            description: "Forward pass through cached GNN layer (~5-10ms vs ~2.5s)".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "layer_id": {"type": "string", "description": "Layer config: input_hidden_heads"},
                    "node_embedding": {"type": "array", "items": {"type": "number"}},
                    "neighbor_embeddings": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
                    "edge_weights": {"type": "array", "items": {"type": "number"}}
                },
                "required": ["layer_id", "node_embedding", "neighbor_embeddings", "edge_weights"]
            }),
        },
        McpTool {
            name: "gnn_batch_forward".to_string(),
            description: "Batch GNN forward passes with result caching (amortized cost)"
                .to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "layer_config": {
                        "type": "object",
                        "properties": {
                            "input_dim": {"type": "integer"},
                            "hidden_dim": {"type": "integer"},
                            "heads": {"type": "integer"},
                            "dropout": {"type": "number", "default": 0.1}
                        },
                        "required": ["input_dim", "hidden_dim", "heads"]
                    },
                    "operations": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "node_embedding": {"type": "array", "items": {"type": "number"}},
                                "neighbor_embeddings": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
                                "edge_weights": {"type": "array", "items": {"type": "number"}}
                            }
                        }
                    }
                },
                "required": ["layer_config", "operations"]
            }),
        },
        McpTool {
            name: "gnn_cache_stats".to_string(),
            description: "Get GNN cache statistics (hit rates, counts)".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "include_details": {"type": "boolean", "default": false}
                }
            }),
        },
        McpTool {
            name: "gnn_compress".to_string(),
            description: "Compress embedding based on access frequency".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "embedding": {"type": "array", "items": {"type": "number"}},
                    "access_freq": {"type": "number", "description": "Access frequency 0.0-1.0"}
                },
                "required": ["embedding", "access_freq"]
            }),
        },
        McpTool {
            name: "gnn_decompress".to_string(),
            description: "Decompress a compressed tensor".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "compressed_json": {"type": "string", "description": "Compressed tensor JSON"}
                },
                "required": ["compressed_json"]
            }),
        },
        McpTool {
            name: "gnn_search".to_string(),
            description: "Differentiable search with soft attention".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "array", "items": {"type": "number"}},
                    "candidates": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
                    "k": {"type": "integer", "description": "Number of results"},
                    "temperature": {"type": "number", "default": 1.0}
                },
                "required": ["query", "candidates", "k"]
            }),
        },
    ];
    McpResponse::success(id, json!({ "tools": tools }))
}
/// Dispatch a `tools/call` request to the named tool implementation and wrap
/// its string output in an MCP text-content payload.
async fn handle_tools_call(&self, id: Option<Value>, params: Option<Value>) -> McpResponse {
    let params = match params {
        Some(p) => p,
        None => {
            return McpResponse::error(
                id,
                McpError::new(error_codes::INVALID_PARAMS, "Missing params"),
            )
        }
    };
    // A missing/non-string "name" falls through to the Unknown-tool arm below.
    let tool_name = params["name"].as_str().unwrap_or("");
    let arguments = &params["arguments"];
    let result = match tool_name {
        // Vector DB tools
        "vector_db_create" => self.tool_create_db(arguments).await,
        "vector_db_insert" => self.tool_insert(arguments).await,
        "vector_db_search" => self.tool_search(arguments).await,
        "vector_db_stats" => self.tool_stats(arguments).await,
        "vector_db_backup" => self.tool_backup(arguments).await,
        // GNN tools with caching
        "gnn_layer_create" => self.tool_gnn_layer_create(arguments).await,
        "gnn_forward" => self.tool_gnn_forward(arguments).await,
        "gnn_batch_forward" => self.tool_gnn_batch_forward(arguments).await,
        "gnn_cache_stats" => self.tool_gnn_cache_stats(arguments).await,
        "gnn_compress" => self.tool_gnn_compress(arguments).await,
        "gnn_decompress" => self.tool_gnn_decompress(arguments).await,
        "gnn_search" => self.tool_gnn_search(arguments).await,
        _ => Err(anyhow::anyhow!("Unknown tool: {}", tool_name)),
    };
    match result {
        Ok(value) => {
            McpResponse::success(id, json!({ "content": [{"type": "text", "text": value}] }))
        }
        // NOTE(review): every tool failure — including an unknown tool name —
        // surfaces as INTERNAL_ERROR; INVALID_PARAMS may fit the latter better.
        Err(e) => McpResponse::error(
            id,
            McpError::new(error_codes::INTERNAL_ERROR, e.to_string()),
        ),
    }
}
/// List available resources. Currently a single hard-coded placeholder entry.
async fn handle_resources_list(&self, id: Option<Value>) -> McpResponse {
    McpResponse::success(
        id,
        json!({
            "resources": [
                {
                    "uri": "database://local/default",
                    "name": "Default Database",
                    "description": "Default vector database",
                    "mimeType": "application/x-ruvector-db"
                }
            ]
        }),
    )
}
/// Read a resource. The request params are ignored; a static status payload
/// for the default database is always returned.
async fn handle_resources_read(
    &self,
    id: Option<Value>,
    _params: Option<Value>,
) -> McpResponse {
    McpResponse::success(
        id,
        json!({
            "contents": [{
                "uri": "database://local/default",
                "mimeType": "application/json",
                "text": "{\"status\": \"available\"}"
            }]
        }),
    )
}
/// List available prompts: one "semantic-search" template with a required
/// natural-language `query` argument.
async fn handle_prompts_list(&self, id: Option<Value>) -> McpResponse {
    McpResponse::success(
        id,
        json!({
            "prompts": [
                {
                    "name": "semantic-search",
                    "description": "Generate a semantic search query",
                    "arguments": [
                        {
                            "name": "query",
                            "description": "Natural language query",
                            "required": true
                        }
                    ]
                }
            ]
        }),
    )
}
/// Return the semantic-search prompt template. Params are ignored; the
/// `{{query}}` placeholder is left for the client to substitute.
async fn handle_prompts_get(&self, id: Option<Value>, _params: Option<Value>) -> McpResponse {
    McpResponse::success(
        id,
        json!({
            "description": "Semantic search template",
            "messages": [
                {
                    "role": "user",
                    "content": {
                        "type": "text",
                        "text": "Search for vectors related to: {{query}}"
                    }
                }
            ]
        }),
    )
}
// Tool implementations
/// `vector_db_create`: create a database at a validated path and cache the
/// open handle for subsequent calls.
async fn tool_create_db(&self, args: &Value) -> Result<String> {
    let params: CreateDbParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    // Validate path to prevent directory traversal (CWE-22)
    let validated_path = self.validate_path(&params.path)?;
    let mut db_options = self.config.to_db_options();
    db_options.storage_path = validated_path.to_string_lossy().to_string();
    db_options.dimensions = params.dimensions;
    if let Some(metric) = params.distance_metric {
        db_options.distance_metric = match metric.as_str() {
            "euclidean" => DistanceMetric::Euclidean,
            "cosine" => DistanceMetric::Cosine,
            "dotproduct" => DistanceMetric::DotProduct,
            "manhattan" => DistanceMetric::Manhattan,
            // Unrecognized metric strings silently fall back to cosine.
            _ => DistanceMetric::Cosine,
        };
    }
    let db = VectorDB::new(db_options)?;
    let path_str = validated_path.to_string_lossy().to_string();
    self.databases
        .write()
        .await
        .insert(path_str.clone(), Arc::new(db));
    Ok(format!("Database created at: {}", path_str))
}
/// `vector_db_insert`: batch-insert vectors into a (possibly cached) database.
async fn tool_insert(&self, args: &Value) -> Result<String> {
    let params: InsertParams = serde_json::from_value(args.clone())?;
    let db = self.get_or_open_db(&params.db_path).await?;
    let entries: Vec<VectorEntry> = params
        .vectors
        .into_iter()
        .map(|v| VectorEntry {
            id: v.id,
            vector: v.vector,
            // Metadata that fails to deserialize is silently dropped (ok()).
            metadata: v.metadata.and_then(|m| serde_json::from_value(m).ok()),
        })
        .collect();
    let ids = db.insert_batch(entries)?;
    Ok(format!("Inserted {} vectors", ids.len()))
}
/// `vector_db_search`: k-NN search; returns pretty-printed JSON results.
async fn tool_search(&self, args: &Value) -> Result<String> {
    let params: SearchParams = serde_json::from_value(args.clone())?;
    let db = self.get_or_open_db(&params.db_path).await?;
    let results = db.search(SearchQuery {
        vector: params.query,
        k: params.k,
        // Invalid filters are silently ignored rather than rejected.
        filter: params.filter.and_then(|f| serde_json::from_value(f).ok()),
        ef_search: None,
    })?;
    serde_json::to_string_pretty(&results).context("Failed to serialize results")
}
/// `vector_db_stats`: entry count plus configured options as a JSON string.
async fn tool_stats(&self, args: &Value) -> Result<String> {
    let params: StatsParams = serde_json::from_value(args.clone())?;
    let db = self.get_or_open_db(&params.db_path).await?;
    let count = db.len()?;
    let options = db.options();
    Ok(json!({
        "count": count,
        "dimensions": options.dimensions,
        "distance_metric": format!("{:?}", options.distance_metric),
        "hnsw_enabled": options.hnsw_config.is_some()
    })
    .to_string())
}
/// `vector_db_backup`: file-level copy of the DB to a validated target path.
async fn tool_backup(&self, args: &Value) -> Result<String> {
    let params: BackupParams = serde_json::from_value(args.clone())?;
    // Validate both paths to prevent directory traversal (CWE-22)
    let validated_db_path = self.validate_path(&params.db_path)?;
    let validated_backup_path = self.validate_path(&params.backup_path)?;
    std::fs::copy(&validated_db_path, &validated_backup_path)
        .context("Failed to backup database")?;
    Ok(format!("Backed up to: {}", validated_backup_path.display()))
}
/// Return the cached handle for `path`, opening and caching it on first use.
/// NOTE(review): two concurrent first calls may both open the DB (read lock is
/// dropped before the write lock is taken); the second insert wins. Confirm
/// reopening the same path twice is harmless in `VectorDB::new`.
async fn get_or_open_db(&self, path: &str) -> Result<Arc<VectorDB>> {
    // Validate path to prevent directory traversal (CWE-22)
    let validated_path = self.validate_path(path)?;
    let path_str = validated_path.to_string_lossy().to_string();
    let databases = self.databases.read().await;
    if let Some(db) = databases.get(&path_str) {
        return Ok(db.clone());
    }
    drop(databases);
    // Open new database
    let mut db_options = self.config.to_db_options();
    db_options.storage_path = path_str.clone();
    let db = Arc::new(VectorDB::new(db_options)?);
    self.databases.write().await.insert(path_str, db.clone());
    Ok(db)
}
// ==================== GNN Tool Implementations ====================
// These tools eliminate ~2.5s overhead per operation via persistent caching
/// Create or retrieve a cached GNN layer
async fn tool_gnn_layer_create(&self, args: &Value) -> Result<String> {
    let params: GnnLayerCreateParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let start = Instant::now();
    let _layer = self
        .gnn_cache
        .get_or_create_layer(
            params.input_dim,
            params.hidden_dim,
            params.heads,
            params.dropout,
        )
        .await;
    let elapsed = start.elapsed();
    // The layer id encodes the full config; dropout is stored in thousandths
    // (0.1 -> 100) so the id remains a plain underscore-separated identifier.
    let layer_id = format!(
        "{}_{}_{}_{}",
        params.input_dim,
        params.hidden_dim,
        params.heads,
        (params.dropout * 1000.0) as u32
    );
    Ok(json!({
        "layer_id": layer_id,
        "input_dim": params.input_dim,
        "hidden_dim": params.hidden_dim,
        "heads": params.heads,
        "dropout": params.dropout,
        "creation_time_ms": elapsed.as_secs_f64() * 1000.0,
        "cached": elapsed.as_millis() < 50 // <50ms indicates cache hit
    })
    .to_string())
}
/// Forward pass through a cached GNN layer
async fn tool_gnn_forward(&self, args: &Value) -> Result<String> {
    let params: GnnForwardParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let start = Instant::now();
    // Parse layer_id format: "input_hidden_heads_dropout"
    let parts: Vec<&str> = params.layer_id.split('_').collect();
    if parts.len() < 3 {
        return Err(anyhow::anyhow!(
            "Invalid layer_id format. Expected: input_hidden_heads[_dropout]"
        ));
    }
    let input_dim: usize = parts[0].parse()?;
    let hidden_dim: usize = parts[1].parse()?;
    let heads: usize = parts[2].parse()?;
    // Optional 4th segment is dropout in thousandths; malformed values fall
    // back to 100 (= 0.1), matching the encoding in tool_gnn_layer_create.
    let dropout: f32 = parts
        .get(3)
        .map(|s| s.parse::<u32>().unwrap_or(100) as f32 / 1000.0)
        .unwrap_or(0.1);
    let layer = self
        .gnn_cache
        .get_or_create_layer(input_dim, hidden_dim, heads, dropout)
        .await;
    // Convert f64 to f32
    let node_f32: Vec<f32> = params.node_embedding.iter().map(|&x| x as f32).collect();
    let neighbors_f32: Vec<Vec<f32>> = params
        .neighbor_embeddings
        .iter()
        .map(|v| v.iter().map(|&x| x as f32).collect())
        .collect();
    let weights_f32: Vec<f32> = params.edge_weights.iter().map(|&x| x as f32).collect();
    let result = layer.forward(&node_f32, &neighbors_f32, &weights_f32);
    let elapsed = start.elapsed();
    // Convert back to f64 for JSON
    let result_f64: Vec<f64> = result.iter().map(|&x| x as f64).collect();
    Ok(json!({
        "result": result_f64,
        "output_dim": result.len(),
        "latency_ms": elapsed.as_secs_f64() * 1000.0
    })
    .to_string())
}
/// Batch forward passes with caching.
///
/// Converts the JSON (f64) payload into the f32 representation used by the
/// GNN cache, runs the whole batch through `GnnCache::batch_forward`, and
/// reports per-op timing plus cache hit/miss counts.
async fn tool_gnn_batch_forward(&self, args: &Value) -> Result<String> {
    let params: GnnBatchForwardParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let request = BatchGnnRequest {
        layer_config: LayerConfig {
            input_dim: params.layer_config.input_dim,
            hidden_dim: params.layer_config.hidden_dim,
            heads: params.layer_config.heads,
            dropout: params.layer_config.dropout,
        },
        operations: params
            .operations
            .into_iter()
            .map(|op| GnnOperation {
                node_embedding: op.node_embedding.iter().map(|&x| x as f32).collect(),
                neighbor_embeddings: op
                    .neighbor_embeddings
                    .iter()
                    .map(|v| v.iter().map(|&x| x as f32).collect())
                    .collect(),
                edge_weights: op.edge_weights.iter().map(|&x| x as f32).collect(),
            })
            .collect(),
    };
    let batch_result = self.gnn_cache.batch_forward(request).await;
    // Convert results back to f64 for JSON output
    let results_f64: Vec<Vec<f64>> = batch_result
        .results
        .iter()
        .map(|r| r.iter().map(|&x| x as f64).collect())
        .collect();
    // Guard against an empty batch: dividing by zero here would yield
    // NaN/inf, which serde_json serializes as `null`. Report 0.0 instead.
    let total_ops = batch_result.cached_count + batch_result.computed_count;
    let avg_time_per_op_ms = if total_ops > 0 {
        batch_result.total_time_ms / total_ops as f64
    } else {
        0.0
    };
    Ok(json!({
        "results": results_f64,
        "cached_count": batch_result.cached_count,
        "computed_count": batch_result.computed_count,
        "total_time_ms": batch_result.total_time_ms,
        "avg_time_per_op_ms": avg_time_per_op_ms
    })
    .to_string())
}
/// Get GNN cache statistics
async fn tool_gnn_cache_stats(&self, args: &Value) -> Result<String> {
    // Missing or malformed params degrade gracefully to "no details".
    let params: GnnCacheStatsParams =
        serde_json::from_value(args.clone()).unwrap_or(GnnCacheStatsParams {
            include_details: false,
        });
    let stats = self.gnn_cache.stats().await;
    let layer_count = self.gnn_cache.layer_count().await;
    let query_count = self.gnn_cache.query_result_count().await;
    let mut result = json!({
        "layer_hits": stats.layer_hits,
        "layer_misses": stats.layer_misses,
        "layer_hit_rate": format!("{:.2}%", stats.layer_hit_rate() * 100.0),
        "query_hits": stats.query_hits,
        "query_misses": stats.query_misses,
        "query_hit_rate": format!("{:.2}%", stats.query_hit_rate() * 100.0),
        "total_queries": stats.total_queries,
        "evictions": stats.evictions,
        "cached_layers": layer_count,
        "cached_queries": query_count
    });
    if params.include_details {
        result["estimated_memory_saved_ms"] = json!((stats.layer_hits as f64) * 2500.0);
        // ~2.5s per hit
    }
    Ok(result.to_string())
}
/// Compress embedding based on access frequency
async fn tool_gnn_compress(&self, args: &Value) -> Result<String> {
    let params: GnnCompressParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let embedding_f32: Vec<f32> = params.embedding.iter().map(|&x| x as f32).collect();
    let compressed = self
        .tensor_compress
        .compress(&embedding_f32, params.access_freq as f32)
        .map_err(|e| anyhow::anyhow!("Compression error: {}", e))?;
    let compressed_json = serde_json::to_string(&compressed)?;
    // original_size counts f32 bytes (len * 4); compressed_size is the byte
    // length of the compressed tensor's JSON serialization.
    Ok(json!({
        "compressed_json": compressed_json,
        "original_size": params.embedding.len() * 4,
        "compressed_size": compressed_json.len(),
        "compression_ratio": (params.embedding.len() * 4) as f64 / compressed_json.len() as f64
    })
    .to_string())
}
/// Decompress a compressed tensor
async fn tool_gnn_decompress(&self, args: &Value) -> Result<String> {
    let params: GnnDecompressParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let compressed: ruvector_gnn::compress::CompressedTensor =
        serde_json::from_str(&params.compressed_json)
            .context("Invalid compressed tensor JSON")?;
    let decompressed = self
        .tensor_compress
        .decompress(&compressed)
        .map_err(|e| anyhow::anyhow!("Decompression error: {}", e))?;
    let decompressed_f64: Vec<f64> = decompressed.iter().map(|&x| x as f64).collect();
    Ok(json!({
        "embedding": decompressed_f64,
        "dimensions": decompressed.len()
    })
    .to_string())
}
/// Differentiable search with soft attention
async fn tool_gnn_search(&self, args: &Value) -> Result<String> {
    let params: GnnSearchParams =
        serde_json::from_value(args.clone()).context("Invalid parameters")?;
    let start = Instant::now();
    let query_f32: Vec<f32> = params.query.iter().map(|&x| x as f32).collect();
    let candidates_f32: Vec<Vec<f32>> = params
        .candidates
        .iter()
        .map(|v| v.iter().map(|&x| x as f32).collect())
        .collect();
    let (indices, weights) = differentiable_search(
        &query_f32,
        &candidates_f32,
        params.k,
        params.temperature as f32,
    );
    let elapsed = start.elapsed();
    Ok(json!({
        "indices": indices,
        "weights": weights.iter().map(|&w| w as f64).collect::<Vec<f64>>(),
        "k": params.k,
        "latency_ms": elapsed.as_secs_f64() * 1000.0
    })
    .to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a handler whose path-confinement root is `data_dir`.
    fn handler_with_data_dir(data_dir: &Path) -> McpHandler {
        let mut config = Config::default();
        config.mcp.data_dir = data_dir.to_string_lossy().to_string();
        McpHandler::new(config)
    }

    #[test]
    fn test_validate_path_allows_relative_within_data_dir() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        // Create a file to validate against
        std::fs::write(dir.path().join("test.db"), b"test").unwrap();
        let result = handler.validate_path("test.db");
        assert!(result.is_ok(), "Should allow relative path within data dir");
        assert!(result.unwrap().starts_with(dir.path()));
    }

    #[test]
    fn test_validate_path_blocks_absolute_outside_data_dir() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        let result = handler.validate_path("/etc/passwd");
        assert!(result.is_err(), "Should block /etc/passwd");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("outside the allowed data directory"),
            "Error should mention path confinement: {}",
            err
        );
    }

    #[test]
    fn test_validate_path_blocks_dot_dot_traversal() {
        let dir = tempdir().unwrap();
        // Create a subdir so ../.. resolves to something real
        let subdir = dir.path().join("sub");
        std::fs::create_dir_all(&subdir).unwrap();
        let handler = handler_with_data_dir(&subdir);
        let result = handler.validate_path("../../../etc/passwd");
        assert!(result.is_err(), "Should block ../ traversal: {:?}", result);
    }

    #[test]
    fn test_validate_path_blocks_dot_dot_in_middle() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        // Create the inner directory
        std::fs::create_dir_all(dir.path().join("a")).unwrap();
        let result = handler.validate_path("a/../../etc/passwd");
        assert!(result.is_err(), "Should block ../ in the middle of path");
    }

    #[test]
    fn test_validate_path_allows_subdirectory_within_data_dir() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        // Create subdirectory
        std::fs::create_dir_all(dir.path().join("backups")).unwrap();
        let result = handler.validate_path("backups/mydb.bak");
        assert!(
            result.is_ok(),
            "Should allow path in subdirectory: {:?}",
            result
        );
        assert!(result.unwrap().starts_with(dir.path()));
    }

    #[test]
    fn test_validate_path_allows_new_file_in_data_dir() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        let result = handler.validate_path("new_database.db");
        assert!(
            result.is_ok(),
            "Should allow new file in data dir: {:?}",
            result
        );
    }

    #[test]
    fn test_validate_path_blocks_absolute_path_to_etc() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        // Test all 3 POCs from the issue
        for path in &["/etc/passwd", "/etc/shadow", "/etc/hosts"] {
            let result = handler.validate_path(path);
            assert!(result.is_err(), "Should block {}", path);
        }
    }

    #[test]
    fn test_validate_path_blocks_home_ssh_keys() {
        let dir = tempdir().unwrap();
        let handler = handler_with_data_dir(dir.path());
        // `~` is not shell-expanded by the server, so this resolves to a
        // non-existent subdirectory of the data dir and must be rejected.
        // (The original test computed this result but never asserted it.)
        let result = handler.validate_path("~/.ssh/id_rsa");
        assert!(result.is_err(), "Should not resolve ~/.ssh/id_rsa");
        // Absolute path outside the data dir must also be rejected.
        let result2 = handler.validate_path("/root/.ssh/id_rsa");
        assert!(result2.is_err(), "Should block /root/.ssh/id_rsa");
    }
}

View File

@@ -0,0 +1,11 @@
//! Model Context Protocol (MCP) implementation for Ruvector
pub mod gnn_cache;
pub mod handlers;
pub mod protocol;
pub mod transport;
// Re-export everything at the crate's `mcp` level for convenience.
// NOTE(review): glob re-exports can collide if submodules ever declare items
// with the same name; switch to explicit re-exports if that happens.
pub use gnn_cache::*;
pub use handlers::*;
pub use protocol::*;
pub use transport::*;

View File

@@ -0,0 +1,238 @@
//! MCP protocol types and utilities
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// MCP request message (JSON-RPC 2.0 envelope).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRequest {
    // Protocol version tag; "2.0" per JSON-RPC.
    pub jsonrpc: String,
    // Request id echoed back in the response; absent for notifications.
    pub id: Option<Value>,
    pub method: String,
    pub params: Option<Value>,
}
/// MCP response message
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpResponse {
    pub jsonrpc: String,
    pub id: Option<Value>,
    // Exactly one of `result`/`error` is set by the constructors below;
    // the unset one is omitted from the serialized JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<McpError>,
}
/// MCP error
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpError {
    // Numeric JSON-RPC error code (see the `error_codes` module).
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
impl McpError {
pub fn new(code: i32, message: impl Into<String>) -> Self {
Self {
code,
message: message.into(),
data: None,
}
}
pub fn with_data(mut self, data: Value) -> Self {
self.data = Some(data);
self
}
}
/// Standard MCP error codes
/// (the reserved error codes defined by the JSON-RPC 2.0 specification).
pub mod error_codes {
    /// Invalid JSON was received.
    pub const PARSE_ERROR: i32 = -32700;
    /// The JSON sent is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameter(s).
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32603;
}
impl McpResponse {
pub fn success(id: Option<Value>, result: Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
pub fn error(id: Option<Value>, error: McpError) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(error),
}
}
}
/// MCP Tool definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    pub name: String,
    pub description: String,
    // JSON Schema describing the tool's expected arguments.
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
/// MCP Resource definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpResource {
    pub uri: String,
    pub name: String,
    pub description: String,
    #[serde(rename = "mimeType")]
    pub mime_type: String,
}
/// MCP Prompt definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPrompt {
    pub name: String,
    pub description: String,
    pub arguments: Option<Vec<PromptArgument>>,
}
/// A single named argument accepted by an MCP prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptArgument {
    pub name: String,
    pub description: String,
    pub required: bool,
}
/// Tool call parameters for vector_db_create
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateDbParams {
    pub path: String,
    pub dimensions: usize,
    // One of "euclidean" | "cosine" | "dotproduct" | "manhattan"; the handler
    // falls back to cosine for any other value.
    #[serde(default)]
    pub distance_metric: Option<String>,
}
/// Tool call parameters for vector_db_insert
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InsertParams {
    pub db_path: String,
    pub vectors: Vec<VectorInsert>,
}
/// One vector (with optional id and metadata) to insert into the database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorInsert {
    pub id: Option<String>,
    pub vector: Vec<f32>,
    pub metadata: Option<Value>,
}
/// Tool call parameters for vector_db_search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParams {
pub db_path: String,
pub query: Vec<f32>,
pub k: usize,
pub filter: Option<Value>,
}
/// Tool call parameters for vector_db_stats
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatsParams {
    pub db_path: String,
}
/// Tool call parameters for vector_db_backup
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackupParams {
    // Source database file; validated against the server data directory.
    pub db_path: String,
    // Destination file for the copy; also validated.
    pub backup_path: String,
}
// ==================== GNN Tool Parameters ====================
/// Tool call parameters for gnn_layer_create
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnLayerCreateParams {
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub heads: usize,
    // Dropout rate; defaults to 0.1 when omitted.
    #[serde(default = "default_dropout")]
    pub dropout: f32,
}
/// Serde default for dropout fields (0.1).
fn default_dropout() -> f32 {
    0.1
}
/// Tool call parameters for gnn_forward
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnForwardParams {
    // Cached-layer id in "input_hidden_heads[_dropout]" form.
    pub layer_id: String,
    pub node_embedding: Vec<f64>,
    pub neighbor_embeddings: Vec<Vec<f64>>,
    pub edge_weights: Vec<f64>,
}
/// Tool call parameters for gnn_batch_forward
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnBatchForwardParams {
    pub layer_config: GnnLayerConfigParams,
    pub operations: Vec<GnnOperationParams>,
}
/// Layer configuration used to locate/create the cached GNN layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnLayerConfigParams {
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub heads: usize,
    #[serde(default = "default_dropout")]
    pub dropout: f32,
}
/// A single forward-pass operation within a batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnOperationParams {
    pub node_embedding: Vec<f64>,
    pub neighbor_embeddings: Vec<Vec<f64>>,
    pub edge_weights: Vec<f64>,
}
/// Tool call parameters for gnn_cache_stats
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnCacheStatsParams {
    // When true, the handler adds derived detail fields to the stats output.
    #[serde(default)]
    pub include_details: bool,
}
/// Tool call parameters for gnn_compress
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnCompressParams {
    pub embedding: Vec<f64>,
    // Access frequency in [0.0, 1.0] (per the tool's JSON schema).
    pub access_freq: f64,
}
/// Tool call parameters for gnn_decompress
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnDecompressParams {
    pub compressed_json: String,
}
/// Tool call parameters for gnn_search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnSearchParams {
    pub query: Vec<f64>,
    pub candidates: Vec<Vec<f64>>,
    pub k: usize,
    // Softmax temperature; defaults to 1.0 when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
}
/// Serde default for the search temperature (1.0).
fn default_temperature() -> f64 {
    1.0
}

View File

@@ -0,0 +1,186 @@
//! MCP transport layers (STDIO and SSE)
use super::{handlers::McpHandler, protocol::*};
use anyhow::Result;
use axum::{
extract::State,
http::{header, StatusCode},
response::{sse::Event, IntoResponse, Sse},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde_json;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tower_http::cors::{AllowOrigin, CorsLayer};
/// STDIO transport for local MCP communication
/// (newline-delimited JSON-RPC over stdin/stdout).
pub struct StdioTransport {
    handler: Arc<McpHandler>,
}
impl StdioTransport {
    /// Wrap a shared handler in a STDIO transport.
    pub fn new(handler: Arc<McpHandler>) -> Self {
        Self { handler }
    }
    /// Run STDIO transport loop
    ///
    /// Reads one JSON-RPC request per line from stdin and writes one JSON
    /// response line to stdout until EOF. Malformed JSON produces a
    /// PARSE_ERROR response (with null id) instead of terminating the loop.
    pub async fn run(&self) -> Result<()> {
        let stdin = tokio::io::stdin();
        let mut stdout = tokio::io::stdout();
        let mut reader = BufReader::new(stdin);
        // Single line buffer reused across iterations.
        let mut line = String::new();
        tracing::info!("MCP STDIO transport started");
        loop {
            line.clear();
            let n = reader.read_line(&mut line).await?;
            if n == 0 {
                // EOF
                break;
            }
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            // Parse request
            let request: McpRequest = match serde_json::from_str(trimmed) {
                Ok(req) => req,
                Err(e) => {
                    let error_response = McpResponse::error(
                        None,
                        McpError::new(error_codes::PARSE_ERROR, e.to_string()),
                    );
                    let response_json = serde_json::to_string(&error_response)?;
                    stdout.write_all(response_json.as_bytes()).await?;
                    stdout.write_all(b"\n").await?;
                    stdout.flush().await?;
                    continue;
                }
            };
            // Handle request
            let response = self.handler.handle_request(request).await;
            // Send response (one JSON object per line, flushed immediately)
            let response_json = serde_json::to_string(&response)?;
            stdout.write_all(response_json.as_bytes()).await?;
            stdout.write_all(b"\n").await?;
            stdout.flush().await?;
        }
        tracing::info!("MCP STDIO transport stopped");
        Ok(())
    }
}
/// SSE (Server-Sent Events) transport for HTTP streaming
pub struct SseTransport {
    /// Shared handler passed to the HTTP routes as axum state.
    handler: Arc<McpHandler>,
    /// Host/interface the HTTP listener binds to.
    host: String,
    /// TCP port the HTTP listener binds to.
    port: u16,
}
impl SseTransport {
pub fn new(handler: Arc<McpHandler>, host: String, port: u16) -> Self {
Self {
handler,
host,
port,
}
}
/// Run SSE transport server
pub async fn run(&self) -> Result<()> {
// Use restrictive CORS: only allow localhost origins by default
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::predicate(|origin, _| {
if let Ok(origin_str) = origin.to_str() {
origin_str.starts_with("http://127.0.0.1")
|| origin_str.starts_with("http://localhost")
|| origin_str.starts_with("https://127.0.0.1")
|| origin_str.starts_with("https://localhost")
} else {
false
}
}))
.allow_methods([axum::http::Method::GET, axum::http::Method::POST])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
let app = Router::new()
.route("/", get(root))
.route("/mcp", post(mcp_handler))
.route("/mcp/sse", get(mcp_sse_handler))
.layer(cors)
.with_state(self.handler.clone());
let addr = format!("{}:{}", self.host, self.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("MCP SSE transport listening on http://{}", addr);
axum::serve(listener, app).await?;
Ok(())
}
}
// HTTP handlers
/// `GET /` — plain-text identification banner for the server.
async fn root() -> &'static str {
    "Ruvector MCP Server"
}
/// `POST /mcp` — dispatch one JSON-RPC request and return its response.
async fn mcp_handler(
    State(handler): State<Arc<McpHandler>>,
    Json(request): Json<McpRequest>,
) -> Json<McpResponse> {
    Json(handler.handle_request(request).await)
}
/// `GET /mcp/sse` — open a Server-Sent Events stream.
///
/// Emits an initial `connected` event, then periodic `ping` events so idle
/// connections are not dropped by intermediaries. The handler state is not
/// used here (requests flow through `POST /mcp`), so the binding is named
/// `_handler` to document that and silence the unused-variable warning; the
/// `State` extractor itself must stay so the route type-checks against the
/// router's state.
async fn mcp_sse_handler(
    State(_handler): State<Arc<McpHandler>>,
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
    let stream = async_stream::stream! {
        // Tell the client the stream is live.
        yield Ok(Event::default().data("connected"));
        // Application-level pings every 30 seconds.
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
        loop {
            interval.tick().await;
            yield Ok(Event::default().event("ping").data("keep-alive"));
        }
    };
    // Axum-level SSE keep-alive comments in addition to the ping events above.
    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(tokio::time::Duration::from_secs(30))
            .text("keep-alive"),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    /// Smoke test: a STDIO transport can be constructed from a default config
    /// (no I/O loop is started).
    #[tokio::test]
    async fn test_stdio_transport_creation() {
        let config = Config::default();
        let handler = Arc::new(McpHandler::new(config));
        let _transport = StdioTransport::new(handler);
    }
    /// Smoke test: an SSE transport can be constructed (no server is bound).
    #[tokio::test]
    async fn test_sse_transport_creation() {
        let config = Config::default();
        let handler = Arc::new(McpHandler::new(config));
        let _transport = SseTransport::new(handler, "127.0.0.1".to_string(), 3000);
    }
}

View File

@@ -0,0 +1,90 @@
//! MCP Server for Ruvector - Main entry point
use anyhow::Result;
use clap::Parser;
use std::path::PathBuf;
use std::sync::Arc;
use tracing_subscriber;
mod config;
mod mcp;
use config::Config;
use mcp::{
handlers::McpHandler,
transport::{SseTransport, StdioTransport},
};
// Command-line arguments for the MCP server binary.
// NOTE: the `///` comments on the fields below double as clap help text
// shown to the user, so they are part of the CLI surface — edit with care.
#[derive(Parser)]
#[command(name = "ruvector-mcp")]
#[command(about = "Ruvector MCP Server", long_about = None)]
#[command(version)]
struct Cli {
    /// Configuration file path
    #[arg(short, long)]
    config: Option<PathBuf>,
    /// Transport type (stdio or sse)
    #[arg(short, long, default_value = "stdio")]
    transport: String,
    /// Host for SSE transport
    // Falls back to the config file's host when absent (see main).
    #[arg(long)]
    host: Option<String>,
    /// Port for SSE transport
    // Falls back to the config file's port when absent (see main).
    #[arg(short, long)]
    port: Option<u16>,
    /// Enable debug logging
    #[arg(short, long)]
    debug: bool,
}
/// Entry point: parse CLI args, set up logging/config, and run the server
/// on the selected transport until it exits.
///
/// # Errors
/// Returns an error for an unknown `--transport` value, a config load
/// failure, or a transport runtime failure.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    // Initialize logging once; --debug only changes the filter level.
    let filter = if cli.debug { "ruvector=debug" } else { "ruvector=info" };
    tracing_subscriber::fmt().with_env_filter(filter).init();
    // Load configuration
    let config = Config::load(cli.config)?;
    // Shared request handler used by whichever transport is selected.
    let handler = Arc::new(McpHandler::new(config.clone()));
    // Run appropriate transport
    match cli.transport.as_str() {
        "stdio" => {
            tracing::info!("Starting MCP server with STDIO transport");
            StdioTransport::new(handler).run().await?;
        }
        "sse" => {
            // CLI flags override config-file values; clone lazily so a
            // supplied --host avoids the config clone entirely.
            let host = cli.host.unwrap_or_else(|| config.mcp.host.clone());
            let port = cli.port.unwrap_or(config.mcp.port);
            tracing::info!(
                "Starting MCP server with SSE transport on {}:{}",
                host,
                port
            );
            SseTransport::new(handler, host, port).run().await?;
        }
        _ => {
            return Err(anyhow::anyhow!("Invalid transport type: {}", cli.transport));
        }
    }
    Ok(())
}