Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
344
vendor/ruvector/crates/ruvector-cli/src/cli/commands.rs
vendored
Normal file
344
vendor/ruvector/crates/ruvector-cli/src/cli/commands.rs
vendored
Normal file
@@ -0,0 +1,344 @@
|
||||
//! CLI command implementations
|
||||
|
||||
use crate::cli::{
|
||||
export_csv, export_json, format_error, format_search_results, format_stats, format_success,
|
||||
ProgressTracker,
|
||||
};
|
||||
use crate::config::Config;
|
||||
use anyhow::{Context, Result};
|
||||
use colored::*;
|
||||
use ruvector_core::{
|
||||
types::{DbOptions, SearchQuery, VectorEntry},
|
||||
VectorDB,
|
||||
};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Create a new database
|
||||
pub fn create_database(path: &str, dimensions: usize, config: &Config) -> Result<()> {
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = path.to_string();
|
||||
db_options.dimensions = dimensions;
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Creating database at: {}", path))
|
||||
);
|
||||
println!(" Dimensions: {}", dimensions.to_string().cyan());
|
||||
println!(" Distance metric: {:?}", db_options.distance_metric);
|
||||
|
||||
let _db = VectorDB::new(db_options).context("Failed to create database")?;
|
||||
|
||||
println!("{}", format_success("Database created successfully!"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Insert vectors from a file
|
||||
pub fn insert_vectors(
|
||||
db_path: &str,
|
||||
input_file: &str,
|
||||
format: &str,
|
||||
config: &Config,
|
||||
show_progress: bool,
|
||||
) -> Result<()> {
|
||||
// Load database
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = db_path.to_string();
|
||||
|
||||
let db = VectorDB::new(db_options).context("Failed to open database")?;
|
||||
|
||||
// Parse input file
|
||||
let entries = match format {
|
||||
"json" => parse_json_file(input_file)?,
|
||||
"csv" => parse_csv_file(input_file)?,
|
||||
"npy" => parse_npy_file(input_file)?,
|
||||
_ => return Err(anyhow::anyhow!("Unsupported format: {}", format)),
|
||||
};
|
||||
|
||||
let total = entries.len();
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Loaded {} vectors from {}", total, input_file))
|
||||
);
|
||||
|
||||
// Insert with progress
|
||||
let start = Instant::now();
|
||||
let tracker = ProgressTracker::new();
|
||||
let pb = if show_progress {
|
||||
Some(tracker.create_bar(total as u64, "Inserting vectors..."))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let batch_size = config.cli.batch_size;
|
||||
let mut inserted = 0;
|
||||
|
||||
for chunk in entries.chunks(batch_size) {
|
||||
db.insert_batch(chunk.to_vec())
|
||||
.context("Failed to insert batch")?;
|
||||
inserted += chunk.len();
|
||||
|
||||
if let Some(ref pb) = pb {
|
||||
pb.set_position(inserted as u64);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pb) = pb {
|
||||
pb.finish_with_message("Insertion complete!");
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!(
|
||||
"Inserted {} vectors in {:.2}s ({:.0} vectors/sec)",
|
||||
total,
|
||||
elapsed.as_secs_f64(),
|
||||
total as f64 / elapsed.as_secs_f64()
|
||||
))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search for similar vectors
|
||||
pub fn search_vectors(
|
||||
db_path: &str,
|
||||
query_vector: Vec<f32>,
|
||||
k: usize,
|
||||
config: &Config,
|
||||
show_vectors: bool,
|
||||
) -> Result<()> {
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = db_path.to_string();
|
||||
|
||||
let db = VectorDB::new(db_options).context("Failed to open database")?;
|
||||
|
||||
let start = Instant::now();
|
||||
let results = db
|
||||
.search(SearchQuery {
|
||||
vector: query_vector,
|
||||
k,
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
})
|
||||
.context("Failed to search")?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
println!("{}", format_search_results(&results, show_vectors));
|
||||
println!(
|
||||
"\n{}",
|
||||
format!(
|
||||
"Search completed in {:.2}ms",
|
||||
elapsed.as_secs_f64() * 1000.0
|
||||
)
|
||||
.dimmed()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Show database information
|
||||
pub fn show_info(db_path: &str, config: &Config) -> Result<()> {
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = db_path.to_string();
|
||||
|
||||
let db = VectorDB::new(db_options).context("Failed to open database")?;
|
||||
|
||||
let count = db.len().context("Failed to get count")?;
|
||||
let dimensions = db.options().dimensions;
|
||||
let metric = format!("{:?}", db.options().distance_metric);
|
||||
|
||||
println!("{}", format_stats(count, dimensions, &metric));
|
||||
|
||||
if let Some(hnsw_config) = &db.options().hnsw_config {
|
||||
println!("{}", "HNSW Configuration:".bold().green());
|
||||
println!(" M: {}", hnsw_config.m.to_string().cyan());
|
||||
println!(
|
||||
" ef_construction: {}",
|
||||
hnsw_config.ef_construction.to_string().cyan()
|
||||
);
|
||||
println!(" ef_search: {}", hnsw_config.ef_search.to_string().cyan());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run a quick benchmark
|
||||
pub fn run_benchmark(db_path: &str, config: &Config, num_queries: usize) -> Result<()> {
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = db_path.to_string();
|
||||
|
||||
let db = VectorDB::new(db_options).context("Failed to open database")?;
|
||||
|
||||
let dimensions = db.options().dimensions;
|
||||
|
||||
println!("{}", "Running benchmark...".bold().green());
|
||||
println!(" Queries: {}", num_queries.to_string().cyan());
|
||||
println!(" Dimensions: {}", dimensions.to_string().cyan());
|
||||
|
||||
// Generate random query vectors
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let queries: Vec<Vec<f32>> = (0..num_queries)
|
||||
.map(|_| (0..dimensions).map(|_| rng.gen()).collect())
|
||||
.collect();
|
||||
|
||||
// Warm-up
|
||||
for query in queries.iter().take(10) {
|
||||
let _ = db.search(SearchQuery {
|
||||
vector: query.clone(),
|
||||
k: 10,
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
let start = Instant::now();
|
||||
for query in &queries {
|
||||
db.search(SearchQuery {
|
||||
vector: query.clone(),
|
||||
k: 10,
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
})
|
||||
.context("Search failed")?;
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
let qps = num_queries as f64 / elapsed.as_secs_f64();
|
||||
let avg_latency = elapsed.as_secs_f64() * 1000.0 / num_queries as f64;
|
||||
|
||||
println!("\n{}", "Benchmark Results:".bold().green());
|
||||
println!(" Total time: {:.2}s", elapsed.as_secs_f64());
|
||||
println!(" Queries per second: {:.0}", qps.to_string().cyan());
|
||||
println!(" Average latency: {:.2}ms", avg_latency.to_string().cyan());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export database to file
|
||||
pub fn export_database(
|
||||
db_path: &str,
|
||||
output_file: &str,
|
||||
format: &str,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let mut db_options = config.to_db_options();
|
||||
db_options.storage_path = db_path.to_string();
|
||||
|
||||
let db = VectorDB::new(db_options).context("Failed to open database")?;
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Exporting database to: {}", output_file))
|
||||
);
|
||||
|
||||
// Export is currently limited - would need to add all_ids() method to VectorDB
|
||||
// For now, return an error with a helpful message
|
||||
return Err(anyhow::anyhow!(
|
||||
"Export functionality requires VectorDB::all_ids() method. This will be implemented in a future update."
|
||||
));
|
||||
|
||||
// TODO: Implement when VectorDB exposes all_ids()
|
||||
// let ids = db.all_ids()?;
|
||||
// let tracker = ProgressTracker::new();
|
||||
// let pb = tracker.create_bar(ids.len() as u64, "Exporting vectors...");
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Import from other vector databases
|
||||
pub fn import_from_external(
|
||||
db_path: &str,
|
||||
source: &str,
|
||||
source_path: &str,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Importing from {} database", source))
|
||||
);
|
||||
|
||||
match source {
|
||||
"faiss" => {
|
||||
// TODO: Implement FAISS import
|
||||
return Err(anyhow::anyhow!("FAISS import not yet implemented"));
|
||||
}
|
||||
"pinecone" => {
|
||||
// TODO: Implement Pinecone import
|
||||
return Err(anyhow::anyhow!("Pinecone import not yet implemented"));
|
||||
}
|
||||
"weaviate" => {
|
||||
// TODO: Implement Weaviate import
|
||||
return Err(anyhow::anyhow!("Weaviate import not yet implemented"));
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unsupported source: {}", source)),
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
fn parse_json_file(path: &str) -> Result<Vec<VectorEntry>> {
|
||||
let content = std::fs::read_to_string(path).context("Failed to read JSON file")?;
|
||||
serde_json::from_str(&content).context("Failed to parse JSON")
|
||||
}
|
||||
|
||||
fn parse_csv_file(path: &str) -> Result<Vec<VectorEntry>> {
|
||||
let mut reader = csv::Reader::from_path(path).context("Failed to open CSV file")?;
|
||||
|
||||
let mut entries = Vec::new();
|
||||
|
||||
for result in reader.records() {
|
||||
let record = result.context("Failed to read CSV record")?;
|
||||
|
||||
let id = if record.get(0).map(|s| s.is_empty()).unwrap_or(true) {
|
||||
None
|
||||
} else {
|
||||
Some(record.get(0).unwrap().to_string())
|
||||
};
|
||||
|
||||
let vector: Vec<f32> =
|
||||
serde_json::from_str(record.get(1).context("Missing vector column")?)
|
||||
.context("Failed to parse vector")?;
|
||||
|
||||
let metadata = if let Some(meta_str) = record.get(2) {
|
||||
if !meta_str.is_empty() {
|
||||
Some(serde_json::from_str(meta_str).context("Failed to parse metadata")?)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
entries.push(VectorEntry {
|
||||
id,
|
||||
vector,
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
fn parse_npy_file(path: &str) -> Result<Vec<VectorEntry>> {
|
||||
use ndarray::Array2;
|
||||
use ndarray_npy::ReadNpyExt;
|
||||
|
||||
let file = std::fs::File::open(path).context("Failed to open NPY file")?;
|
||||
let array: Array2<f32> = Array2::read_npy(file).context("Failed to read NPY file")?;
|
||||
|
||||
let entries: Vec<VectorEntry> = array
|
||||
.outer_iter()
|
||||
.enumerate()
|
||||
.map(|(i, row)| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: row.to_vec(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
179
vendor/ruvector/crates/ruvector-cli/src/cli/format.rs
vendored
Normal file
179
vendor/ruvector/crates/ruvector-cli/src/cli/format.rs
vendored
Normal file
@@ -0,0 +1,179 @@
|
||||
//! Output formatting utilities
|
||||
|
||||
use colored::*;
|
||||
use ruvector_core::types::{SearchResult, VectorEntry};
|
||||
use serde_json;
|
||||
|
||||
/// Format search results for display
|
||||
pub fn format_search_results(results: &[SearchResult], show_vectors: bool) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
output.push_str(&format!("\n{}. {}\n", i + 1, result.id.bold()));
|
||||
output.push_str(&format!(" Score: {:.4}\n", result.score));
|
||||
|
||||
if let Some(metadata) = &result.metadata {
|
||||
if !metadata.is_empty() {
|
||||
output.push_str(&format!(
|
||||
" Metadata: {}\n",
|
||||
serde_json::to_string_pretty(metadata).unwrap_or_else(|_| "{}".to_string())
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if show_vectors {
|
||||
if let Some(vector) = &result.vector {
|
||||
let preview: Vec<f32> = vector.iter().take(5).copied().collect();
|
||||
output.push_str(&format!(" Vector (first 5): {:?}...\n", preview));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Format database statistics
|
||||
pub fn format_stats(count: usize, dimensions: usize, metric: &str) -> String {
|
||||
format!(
|
||||
"\n{}\n Vectors: {}\n Dimensions: {}\n Distance Metric: {}\n",
|
||||
"Database Statistics".bold().green(),
|
||||
count.to_string().cyan(),
|
||||
dimensions.to_string().cyan(),
|
||||
metric.cyan()
|
||||
)
|
||||
}
|
||||
|
||||
/// Format error message
|
||||
pub fn format_error(msg: &str) -> String {
|
||||
format!("{} {}", "Error:".red().bold(), msg)
|
||||
}
|
||||
|
||||
/// Format success message
|
||||
pub fn format_success(msg: &str) -> String {
|
||||
format!("{} {}", "✓".green().bold(), msg)
|
||||
}
|
||||
|
||||
/// Format warning message
|
||||
pub fn format_warning(msg: &str) -> String {
|
||||
format!("{} {}", "Warning:".yellow().bold(), msg)
|
||||
}
|
||||
|
||||
/// Format info message
|
||||
pub fn format_info(msg: &str) -> String {
|
||||
format!("{} {}", "ℹ".blue().bold(), msg)
|
||||
}
|
||||
|
||||
/// Export vector entries to JSON
|
||||
pub fn export_json(entries: &[VectorEntry]) -> anyhow::Result<String> {
|
||||
serde_json::to_string_pretty(entries)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize to JSON: {}", e))
|
||||
}
|
||||
|
||||
/// Export vector entries to CSV
|
||||
pub fn export_csv(entries: &[VectorEntry]) -> anyhow::Result<String> {
|
||||
let mut wtr = csv::Writer::from_writer(vec![]);
|
||||
|
||||
// Write header
|
||||
wtr.write_record(&["id", "vector", "metadata"])?;
|
||||
|
||||
// Write entries
|
||||
for entry in entries {
|
||||
wtr.write_record(&[
|
||||
entry.id.as_ref().map(|s| s.as_str()).unwrap_or(""),
|
||||
&serde_json::to_string(&entry.vector)?,
|
||||
&serde_json::to_string(&entry.metadata)?,
|
||||
])?;
|
||||
}
|
||||
|
||||
wtr.flush()?;
|
||||
String::from_utf8(wtr.into_inner()?)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to convert CSV to string: {}", e))
|
||||
}
|
||||
|
||||
// Graph-specific formatting functions
|
||||
|
||||
/// Format graph node for display
|
||||
pub fn format_graph_node(
|
||||
id: &str,
|
||||
labels: &[String],
|
||||
properties: &serde_json::Map<String, serde_json::Value>,
|
||||
) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
output.push_str(&format!("{} ({})\n", id.bold(), labels.join(":").cyan()));
|
||||
|
||||
if !properties.is_empty() {
|
||||
output.push_str(" Properties:\n");
|
||||
for (key, value) in properties {
|
||||
output.push_str(&format!(" {}: {}\n", key.yellow(), value));
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Format graph relationship for display
|
||||
pub fn format_graph_relationship(
|
||||
id: &str,
|
||||
rel_type: &str,
|
||||
start_node: &str,
|
||||
end_node: &str,
|
||||
properties: &serde_json::Map<String, serde_json::Value>,
|
||||
) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
output.push_str(&format!(
|
||||
"{} -[{}]-> {}\n",
|
||||
start_node.cyan(),
|
||||
rel_type.yellow(),
|
||||
end_node.cyan()
|
||||
));
|
||||
|
||||
if !properties.is_empty() {
|
||||
output.push_str(" Properties:\n");
|
||||
for (key, value) in properties {
|
||||
output.push_str(&format!(" {}: {}\n", key.yellow(), value));
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Format graph query results as table
|
||||
pub fn format_graph_table(headers: &[String], rows: &[Vec<String>]) -> String {
|
||||
use prettytable::{Cell, Row, Table};
|
||||
|
||||
let mut table = Table::new();
|
||||
|
||||
// Add headers
|
||||
let header_cells: Vec<Cell> = headers
|
||||
.iter()
|
||||
.map(|h| Cell::new(h).style_spec("Fyb"))
|
||||
.collect();
|
||||
table.add_row(Row::new(header_cells));
|
||||
|
||||
// Add rows
|
||||
for row in rows {
|
||||
let cells: Vec<Cell> = row.iter().map(|v| Cell::new(v)).collect();
|
||||
table.add_row(Row::new(cells));
|
||||
}
|
||||
|
||||
table.to_string()
|
||||
}
|
||||
|
||||
/// Format graph statistics
|
||||
pub fn format_graph_stats(
|
||||
node_count: usize,
|
||||
rel_count: usize,
|
||||
label_count: usize,
|
||||
rel_type_count: usize,
|
||||
) -> String {
|
||||
format!(
|
||||
"\n{}\n Nodes: {}\n Relationships: {}\n Labels: {}\n Relationship Types: {}\n",
|
||||
"Graph Statistics".bold().green(),
|
||||
node_count.to_string().cyan(),
|
||||
rel_count.to_string().cyan(),
|
||||
label_count.to_string().cyan(),
|
||||
rel_type_count.to_string().cyan()
|
||||
)
|
||||
}
|
||||
552
vendor/ruvector/crates/ruvector-cli/src/cli/graph.rs
vendored
Normal file
552
vendor/ruvector/crates/ruvector-cli/src/cli/graph.rs
vendored
Normal file
@@ -0,0 +1,552 @@
|
||||
//! Graph database command implementations
|
||||
|
||||
use crate::cli::{format_error, format_info, format_success, ProgressTracker};
|
||||
use crate::config::Config;
|
||||
use anyhow::{Context, Result};
|
||||
use colored::*;
|
||||
use std::io::{self, BufRead, Write};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Graph database subcommands
|
||||
#[derive(clap::Subcommand, Debug)]
|
||||
pub enum GraphCommands {
|
||||
/// Create a new graph database
|
||||
Create {
|
||||
/// Database file path
|
||||
#[arg(short, long, default_value = "./ruvector-graph.db")]
|
||||
path: String,
|
||||
|
||||
/// Graph name
|
||||
#[arg(short, long, default_value = "default")]
|
||||
name: String,
|
||||
|
||||
/// Enable property indexing
|
||||
#[arg(long)]
|
||||
indexed: bool,
|
||||
},
|
||||
|
||||
/// Execute a Cypher query
|
||||
Query {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Cypher query to execute
|
||||
#[arg(short = 'q', long)]
|
||||
cypher: String,
|
||||
|
||||
/// Output format (table, json, csv)
|
||||
#[arg(long, default_value = "table")]
|
||||
format: String,
|
||||
|
||||
/// Show execution plan
|
||||
#[arg(long)]
|
||||
explain: bool,
|
||||
},
|
||||
|
||||
/// Interactive Cypher shell (REPL)
|
||||
Shell {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Enable multiline mode
|
||||
#[arg(long)]
|
||||
multiline: bool,
|
||||
},
|
||||
|
||||
/// Import data from file
|
||||
Import {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Input file path
|
||||
#[arg(short = 'i', long)]
|
||||
input: String,
|
||||
|
||||
/// Input format (csv, json, cypher)
|
||||
#[arg(long, default_value = "json")]
|
||||
format: String,
|
||||
|
||||
/// Graph name
|
||||
#[arg(short = 'g', long, default_value = "default")]
|
||||
graph: String,
|
||||
|
||||
/// Skip errors and continue
|
||||
#[arg(long)]
|
||||
skip_errors: bool,
|
||||
},
|
||||
|
||||
/// Export graph data to file
|
||||
Export {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Output file path
|
||||
#[arg(short = 'o', long)]
|
||||
output: String,
|
||||
|
||||
/// Output format (json, csv, cypher, graphml)
|
||||
#[arg(long, default_value = "json")]
|
||||
format: String,
|
||||
|
||||
/// Graph name
|
||||
#[arg(short = 'g', long, default_value = "default")]
|
||||
graph: String,
|
||||
},
|
||||
|
||||
/// Show graph database information
|
||||
Info {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Show detailed statistics
|
||||
#[arg(long)]
|
||||
detailed: bool,
|
||||
},
|
||||
|
||||
/// Run graph benchmarks
|
||||
Benchmark {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Number of queries to run
|
||||
#[arg(short = 'n', long, default_value = "1000")]
|
||||
queries: usize,
|
||||
|
||||
/// Benchmark type (traverse, pattern, aggregate)
|
||||
#[arg(short = 't', long, default_value = "traverse")]
|
||||
bench_type: String,
|
||||
},
|
||||
|
||||
/// Start HTTP/gRPC server
|
||||
Serve {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector-graph.db")]
|
||||
db: String,
|
||||
|
||||
/// Server host
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
/// HTTP port
|
||||
#[arg(long, default_value = "8080")]
|
||||
http_port: u16,
|
||||
|
||||
/// gRPC port
|
||||
#[arg(long, default_value = "50051")]
|
||||
grpc_port: u16,
|
||||
|
||||
/// Enable GraphQL endpoint
|
||||
#[arg(long)]
|
||||
graphql: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// Create a new graph database
|
||||
pub fn create_graph(path: &str, name: &str, indexed: bool, config: &Config) -> Result<()> {
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Creating graph database at: {}", path))
|
||||
);
|
||||
println!(" Graph name: {}", name.cyan());
|
||||
println!(
|
||||
" Property indexing: {}",
|
||||
if indexed {
|
||||
"enabled".green()
|
||||
} else {
|
||||
"disabled".dimmed()
|
||||
}
|
||||
);
|
||||
|
||||
// TODO: Integrate with ruvector-neo4j when available
|
||||
// For now, create a placeholder implementation
|
||||
std::fs::create_dir_all(Path::new(path).parent().unwrap_or(Path::new(".")))?;
|
||||
|
||||
println!("{}", format_success("Graph database created successfully!"));
|
||||
println!(
|
||||
"{}",
|
||||
format_info("Use 'ruvector graph shell' to start interactive mode")
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute a Cypher query
|
||||
pub fn execute_query(
|
||||
db_path: &str,
|
||||
cypher: &str,
|
||||
format: &str,
|
||||
explain: bool,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
if explain {
|
||||
println!("{}", "Query Execution Plan:".bold().cyan());
|
||||
println!("{}", format_info("EXPLAIN mode - showing query plan"));
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO: Integrate with ruvector-neo4j Neo4jGraph implementation
|
||||
// Placeholder for actual query execution
|
||||
println!("{}", format_success("Executing Cypher query..."));
|
||||
println!(" Query: {}", cypher.dimmed());
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
match format {
|
||||
"table" => {
|
||||
println!("\n{}", format_graph_results_table(&[], cypher));
|
||||
}
|
||||
"json" => {
|
||||
println!("{}", format_graph_results_json(&[])?);
|
||||
}
|
||||
"csv" => {
|
||||
println!("{}", format_graph_results_csv(&[])?);
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unsupported output format: {}", format)),
|
||||
}
|
||||
|
||||
println!(
|
||||
"\n{}",
|
||||
format!("Query completed in {:.2}ms", elapsed.as_secs_f64() * 1000.0).dimmed()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Interactive Cypher shell (REPL)
|
||||
pub fn run_shell(db_path: &str, multiline: bool, config: &Config) -> Result<()> {
|
||||
println!("{}", "RuVector Graph Shell".bold().green());
|
||||
println!("Database: {}", db_path.cyan());
|
||||
println!(
|
||||
"Type {} to exit, {} for help\n",
|
||||
":exit".yellow(),
|
||||
":help".yellow()
|
||||
);
|
||||
|
||||
let stdin = io::stdin();
|
||||
let mut stdout = io::stdout();
|
||||
let mut query_buffer = String::new();
|
||||
|
||||
loop {
|
||||
// Print prompt
|
||||
if multiline && !query_buffer.is_empty() {
|
||||
print!("{}", " ... ".dimmed());
|
||||
} else {
|
||||
print!("{}", "cypher> ".green().bold());
|
||||
}
|
||||
stdout.flush()?;
|
||||
|
||||
// Read line
|
||||
let mut line = String::new();
|
||||
stdin.lock().read_line(&mut line)?;
|
||||
let line = line.trim();
|
||||
|
||||
// Handle special commands
|
||||
match line {
|
||||
":exit" | ":quit" | ":q" => {
|
||||
println!("{}", format_success("Goodbye!"));
|
||||
break;
|
||||
}
|
||||
":help" | ":h" => {
|
||||
print_shell_help();
|
||||
continue;
|
||||
}
|
||||
":clear" => {
|
||||
query_buffer.clear();
|
||||
println!("{}", format_info("Query buffer cleared"));
|
||||
continue;
|
||||
}
|
||||
"" => {
|
||||
if !multiline || query_buffer.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// In multiline mode, empty line executes query
|
||||
}
|
||||
_ => {
|
||||
query_buffer.push_str(line);
|
||||
query_buffer.push(' ');
|
||||
|
||||
if multiline && !line.ends_with(';') {
|
||||
continue; // Continue reading in multiline mode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
let query = query_buffer.trim().trim_end_matches(';');
|
||||
if !query.is_empty() {
|
||||
match execute_query(db_path, query, "table", false, config) {
|
||||
Ok(_) => {}
|
||||
Err(e) => println!("{}", format_error(&e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
query_buffer.clear();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Import graph data from file
|
||||
pub fn import_graph(
|
||||
db_path: &str,
|
||||
input_file: &str,
|
||||
format: &str,
|
||||
graph_name: &str,
|
||||
skip_errors: bool,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Importing graph data from: {}", input_file))
|
||||
);
|
||||
println!(" Format: {}", format.cyan());
|
||||
println!(" Graph: {}", graph_name.cyan());
|
||||
println!(
|
||||
" Skip errors: {}",
|
||||
if skip_errors {
|
||||
"yes".yellow()
|
||||
} else {
|
||||
"no".dimmed()
|
||||
}
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO: Implement actual import logic with ruvector-neo4j
|
||||
match format {
|
||||
"csv" => {
|
||||
println!("{}", format_info("Parsing CSV file..."));
|
||||
// Parse CSV and create nodes/relationships
|
||||
}
|
||||
"json" => {
|
||||
println!("{}", format_info("Parsing JSON file..."));
|
||||
// Parse JSON and create graph structure
|
||||
}
|
||||
"cypher" => {
|
||||
println!("{}", format_info("Executing Cypher statements..."));
|
||||
// Execute Cypher commands from file
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unsupported import format: {}", format)),
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!(
|
||||
"Import completed in {:.2}s",
|
||||
elapsed.as_secs_f64()
|
||||
))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export graph data to file
|
||||
pub fn export_graph(
|
||||
db_path: &str,
|
||||
output_file: &str,
|
||||
format: &str,
|
||||
graph_name: &str,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!("Exporting graph to: {}", output_file))
|
||||
);
|
||||
println!(" Format: {}", format.cyan());
|
||||
println!(" Graph: {}", graph_name.cyan());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO: Implement actual export logic with ruvector-neo4j
|
||||
match format {
|
||||
"json" => {
|
||||
println!("{}", format_info("Generating JSON export..."));
|
||||
// Export as JSON graph format
|
||||
}
|
||||
"csv" => {
|
||||
println!("{}", format_info("Generating CSV export..."));
|
||||
// Export nodes and edges as CSV files
|
||||
}
|
||||
"cypher" => {
|
||||
println!("{}", format_info("Generating Cypher statements..."));
|
||||
// Export as Cypher CREATE statements
|
||||
}
|
||||
"graphml" => {
|
||||
println!("{}", format_info("Generating GraphML export..."));
|
||||
// Export as GraphML XML format
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unsupported export format: {}", format)),
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
println!(
|
||||
"{}",
|
||||
format_success(&format!(
|
||||
"Export completed in {:.2}s",
|
||||
elapsed.as_secs_f64()
|
||||
))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Show graph database information
|
||||
pub fn show_graph_info(db_path: &str, detailed: bool, config: &Config) -> Result<()> {
|
||||
println!("\n{}", "Graph Database Statistics".bold().green());
|
||||
|
||||
// TODO: Integrate with ruvector-neo4j to get actual statistics
|
||||
println!(" Database: {}", db_path.cyan());
|
||||
println!(" Graphs: {}", "1".cyan());
|
||||
println!(" Total nodes: {}", "0".cyan());
|
||||
println!(" Total relationships: {}", "0".cyan());
|
||||
println!(" Node labels: {}", "0".cyan());
|
||||
println!(" Relationship types: {}", "0".cyan());
|
||||
|
||||
if detailed {
|
||||
println!("\n{}", "Storage Information:".bold().cyan());
|
||||
println!(" Store size: {}", "0 bytes".cyan());
|
||||
println!(" Index size: {}", "0 bytes".cyan());
|
||||
|
||||
println!("\n{}", "Configuration:".bold().cyan());
|
||||
println!(" Cache size: {}", "N/A".cyan());
|
||||
println!(" Page size: {}", "N/A".cyan());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run graph benchmarks
|
||||
pub fn run_graph_benchmark(
|
||||
db_path: &str,
|
||||
num_queries: usize,
|
||||
bench_type: &str,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
println!("{}", "Running graph benchmark...".bold().green());
|
||||
println!(" Benchmark type: {}", bench_type.cyan());
|
||||
println!(" Queries: {}", num_queries.to_string().cyan());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO: Implement actual benchmarks with ruvector-neo4j
|
||||
match bench_type {
|
||||
"traverse" => {
|
||||
println!("{}", format_info("Benchmarking graph traversal..."));
|
||||
// Run traversal queries
|
||||
}
|
||||
"pattern" => {
|
||||
println!("{}", format_info("Benchmarking pattern matching..."));
|
||||
// Run pattern matching queries
|
||||
}
|
||||
"aggregate" => {
|
||||
println!("{}", format_info("Benchmarking aggregations..."));
|
||||
// Run aggregation queries
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unknown benchmark type: {}", bench_type)),
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
let qps = num_queries as f64 / elapsed.as_secs_f64();
|
||||
let avg_latency = elapsed.as_secs_f64() * 1000.0 / num_queries as f64;
|
||||
|
||||
println!("\n{}", "Benchmark Results:".bold().green());
|
||||
println!(" Total time: {:.2}s", elapsed.as_secs_f64());
|
||||
println!(" Queries per second: {:.0}", qps.to_string().cyan());
|
||||
println!(" Average latency: {:.2}ms", avg_latency.to_string().cyan());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start HTTP/gRPC server
|
||||
pub fn serve_graph(
|
||||
db_path: &str,
|
||||
host: &str,
|
||||
http_port: u16,
|
||||
grpc_port: u16,
|
||||
enable_graphql: bool,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
println!("{}", "Starting RuVector Graph Server...".bold().green());
|
||||
println!(" Database: {}", db_path.cyan());
|
||||
println!(
|
||||
" HTTP endpoint: {}:{}",
|
||||
host.cyan(),
|
||||
http_port.to_string().cyan()
|
||||
);
|
||||
println!(
|
||||
" gRPC endpoint: {}:{}",
|
||||
host.cyan(),
|
||||
grpc_port.to_string().cyan()
|
||||
);
|
||||
|
||||
if enable_graphql {
|
||||
println!(
|
||||
" GraphQL endpoint: {}:{}/graphql",
|
||||
host.cyan(),
|
||||
http_port.to_string().cyan()
|
||||
);
|
||||
}
|
||||
|
||||
println!("\n{}", format_info("Server configuration loaded"));
|
||||
|
||||
// TODO: Implement actual server with ruvector-neo4j
|
||||
println!("{}", format_success("Server ready! Press Ctrl+C to stop."));
|
||||
|
||||
// Placeholder - would run actual server here
|
||||
println!(
|
||||
"\n{}",
|
||||
format_info("Server implementation pending - integrate with ruvector-neo4j")
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper functions for formatting graph results
|
||||
|
||||
fn format_graph_results_table(results: &[serde_json::Value], query: &str) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
if results.is_empty() {
|
||||
output.push_str(&format!("{}\n", "No results found".dimmed()));
|
||||
output.push_str(&format!("Query: {}\n", query.dimmed()));
|
||||
} else {
|
||||
output.push_str(&format!("{} results\n", results.len().to_string().cyan()));
|
||||
// TODO: Format results as table
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn format_graph_results_json(results: &[serde_json::Value]) -> Result<String> {
|
||||
serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize results: {}", e))
|
||||
}
|
||||
|
||||
fn format_graph_results_csv(results: &[serde_json::Value]) -> Result<String> {
|
||||
// TODO: Implement CSV formatting
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
fn print_shell_help() {
|
||||
println!("\n{}", "RuVector Graph Shell Commands".bold().cyan());
|
||||
println!(" {} - Exit the shell", ":exit, :quit, :q".yellow());
|
||||
println!(
|
||||
" {} - Show this help message",
|
||||
":help, :h".yellow()
|
||||
);
|
||||
println!(" {} - Clear query buffer", ":clear".yellow());
|
||||
println!("\n{}", "Cypher Examples:".bold().cyan());
|
||||
println!(" {}", "CREATE (n:Person {{name: 'Alice'}})".dimmed());
|
||||
println!(" {}", "MATCH (n:Person) RETURN n".dimmed());
|
||||
println!(" {}", "MATCH (a)-[r:KNOWS]->(b) RETURN a, r, b".dimmed());
|
||||
println!();
|
||||
}
|
||||
2507
vendor/ruvector/crates/ruvector-cli/src/cli/hooks.rs
vendored
Normal file
2507
vendor/ruvector/crates/ruvector-cli/src/cli/hooks.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
415
vendor/ruvector/crates/ruvector-cli/src/cli/hooks_postgres.rs
vendored
Normal file
415
vendor/ruvector/crates/ruvector-cli/src/cli/hooks_postgres.rs
vendored
Normal file
@@ -0,0 +1,415 @@
|
||||
//! PostgreSQL storage backend for hooks intelligence data
|
||||
//!
|
||||
//! This module provides PostgreSQL-based storage for the hooks system,
|
||||
//! using the ruvector extension for vector operations.
|
||||
//!
|
||||
//! Enable with the `postgres` feature flag.
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
#[cfg(feature = "postgres")]
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
use std::env;
|
||||
|
||||
/// PostgreSQL storage configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PostgresConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub user: String,
|
||||
pub password: Option<String>,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl PostgresConfig {
|
||||
/// Create config from environment variables
|
||||
pub fn from_env() -> Option<Self> {
|
||||
// Try RUVECTOR_POSTGRES_URL first, then DATABASE_URL
|
||||
if let Ok(url) = env::var("RUVECTOR_POSTGRES_URL").or_else(|_| env::var("DATABASE_URL")) {
|
||||
return Self::from_url(&url);
|
||||
}
|
||||
|
||||
// Try individual environment variables
|
||||
let host = env::var("RUVECTOR_PG_HOST").unwrap_or_else(|_| "localhost".to_string());
|
||||
let port = env::var("RUVECTOR_PG_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(5432);
|
||||
let user = env::var("RUVECTOR_PG_USER").ok()?;
|
||||
let password = env::var("RUVECTOR_PG_PASSWORD").ok();
|
||||
let dbname = env::var("RUVECTOR_PG_DATABASE").unwrap_or_else(|_| "ruvector".to_string());
|
||||
|
||||
Some(Self {
|
||||
host,
|
||||
port,
|
||||
user,
|
||||
password,
|
||||
dbname,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse PostgreSQL connection URL
|
||||
pub fn from_url(url: &str) -> Option<Self> {
|
||||
// Parse postgres://user:password@host:port/dbname
|
||||
let url = url
|
||||
.strip_prefix("postgres://")
|
||||
.or_else(|| url.strip_prefix("postgresql://"))?;
|
||||
|
||||
let (auth, rest) = url.split_once('@')?;
|
||||
let (user, password) = if auth.contains(':') {
|
||||
let (u, p) = auth.split_once(':')?;
|
||||
(u.to_string(), Some(p.to_string()))
|
||||
} else {
|
||||
(auth.to_string(), None)
|
||||
};
|
||||
|
||||
let (host_port, dbname) = rest.split_once('/')?;
|
||||
let dbname = dbname.split('?').next()?.to_string();
|
||||
|
||||
let (host, port) = if host_port.contains(':') {
|
||||
let (h, p) = host_port.split_once(':')?;
|
||||
(h.to_string(), p.parse().ok()?)
|
||||
} else {
|
||||
(host_port.to_string(), 5432)
|
||||
};
|
||||
|
||||
Some(Self {
|
||||
host,
|
||||
port,
|
||||
user,
|
||||
password,
|
||||
dbname,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgreSQL storage backend for hooks
|
||||
#[cfg(feature = "postgres")]
|
||||
pub struct PostgresStorage {
|
||||
pool: Pool,
|
||||
}
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
impl PostgresStorage {
|
||||
/// Create a new PostgreSQL storage backend
|
||||
pub async fn new(config: PostgresConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut cfg = Config::new();
|
||||
cfg.host = Some(config.host);
|
||||
cfg.port = Some(config.port);
|
||||
cfg.user = Some(config.user);
|
||||
cfg.password = config.password;
|
||||
cfg.dbname = Some(config.dbname);
|
||||
|
||||
let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
/// Update Q-value for state-action pair
|
||||
pub async fn update_q(
|
||||
&self,
|
||||
state: &str,
|
||||
action: &str,
|
||||
reward: f32,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
"SELECT ruvector_hooks_update_q($1, $2, $3)",
|
||||
&[&state, &action, &reward],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get best action for state
|
||||
pub async fn best_action(
|
||||
&self,
|
||||
state: &str,
|
||||
actions: &[String],
|
||||
) -> Result<Option<(String, f32, f32)>, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_opt(
|
||||
"SELECT action, q_value, confidence FROM ruvector_hooks_best_action($1, $2)",
|
||||
&[&state, &actions],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|r| (r.get(0), r.get(1), r.get(2))))
|
||||
}
|
||||
|
||||
/// Store content in semantic memory
|
||||
pub async fn remember(
|
||||
&self,
|
||||
memory_type: &str,
|
||||
content: &str,
|
||||
embedding: Option<&[f32]>,
|
||||
metadata: &serde_json::Value,
|
||||
) -> Result<i32, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let metadata_str = serde_json::to_string(metadata)?;
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT ruvector_hooks_remember($1, $2, $3, $4::jsonb)",
|
||||
&[&memory_type, &content, &embedding, &metadata_str],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(row.get(0))
|
||||
}
|
||||
|
||||
/// Search memory semantically
|
||||
pub async fn recall(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
limit: i32,
|
||||
) -> Result<Vec<MemoryResult>, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let rows = client
|
||||
.query(
|
||||
"SELECT id, memory_type, content, metadata::text, similarity
|
||||
FROM ruvector_hooks_recall($1, $2)",
|
||||
&[&query_embedding, &limit],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let metadata_str: String = r.get(3);
|
||||
MemoryResult {
|
||||
id: r.get(0),
|
||||
memory_type: r.get(1),
|
||||
content: r.get(2),
|
||||
metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
|
||||
similarity: r.get(4),
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Record file sequence
|
||||
pub async fn record_sequence(
|
||||
&self,
|
||||
from_file: &str,
|
||||
to_file: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
"SELECT ruvector_hooks_record_sequence($1, $2)",
|
||||
&[&from_file, &to_file],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get suggested next files
|
||||
pub async fn suggest_next(
|
||||
&self,
|
||||
file: &str,
|
||||
limit: i32,
|
||||
) -> Result<Vec<(String, i32)>, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let rows = client
|
||||
.query(
|
||||
"SELECT to_file, count FROM ruvector_hooks_suggest_next($1, $2)",
|
||||
&[&file, &limit],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(rows.iter().map(|r| (r.get(0), r.get(1))).collect())
|
||||
}
|
||||
|
||||
/// Record error pattern
|
||||
pub async fn record_error(
|
||||
&self,
|
||||
code: &str,
|
||||
error_type: &str,
|
||||
message: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
"SELECT ruvector_hooks_record_error($1, $2, $3)",
|
||||
&[&code, &error_type, &message],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register agent in swarm
|
||||
pub async fn swarm_register(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
agent_type: &str,
|
||||
capabilities: &[String],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
"SELECT ruvector_hooks_swarm_register($1, $2, $3)",
|
||||
&[&agent_id, &agent_type, &capabilities],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record coordination between agents
|
||||
pub async fn swarm_coordinate(
|
||||
&self,
|
||||
source: &str,
|
||||
target: &str,
|
||||
weight: f32,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
"SELECT ruvector_hooks_swarm_coordinate($1, $2, $3)",
|
||||
&[&source, &target, &weight],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get swarm statistics
|
||||
pub async fn swarm_stats(&self) -> Result<SwarmStats, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one("SELECT * FROM ruvector_hooks_swarm_stats()", &[])
|
||||
.await?;
|
||||
|
||||
Ok(SwarmStats {
|
||||
total_agents: row.get(0),
|
||||
active_agents: row.get(1),
|
||||
total_edges: row.get(2),
|
||||
avg_success_rate: row.get(3),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get overall statistics
|
||||
pub async fn get_stats(&self) -> Result<IntelligenceStats, Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one("SELECT * FROM ruvector_hooks_get_stats()", &[])
|
||||
.await?;
|
||||
|
||||
Ok(IntelligenceStats {
|
||||
total_patterns: row.get(0),
|
||||
total_memories: row.get(1),
|
||||
total_trajectories: row.get(2),
|
||||
total_errors: row.get(3),
|
||||
session_count: row.get(4),
|
||||
})
|
||||
}
|
||||
|
||||
/// Start a new session
|
||||
pub async fn session_start(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = self.pool.get().await?;
|
||||
client
|
||||
.execute("SELECT ruvector_hooks_session_start()", &[])
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory search result
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryResult {
|
||||
pub id: i32,
|
||||
pub memory_type: String,
|
||||
pub content: String,
|
||||
pub metadata: serde_json::Value,
|
||||
pub similarity: f32,
|
||||
}
|
||||
|
||||
/// Swarm statistics
|
||||
#[derive(Debug)]
|
||||
pub struct SwarmStats {
|
||||
pub total_agents: i64,
|
||||
pub active_agents: i64,
|
||||
pub total_edges: i64,
|
||||
pub avg_success_rate: f32,
|
||||
}
|
||||
|
||||
/// Intelligence statistics
|
||||
#[derive(Debug)]
|
||||
pub struct IntelligenceStats {
|
||||
pub total_patterns: i64,
|
||||
pub total_memories: i64,
|
||||
pub total_trajectories: i64,
|
||||
pub total_errors: i64,
|
||||
pub session_count: i64,
|
||||
}
|
||||
|
||||
/// Check if PostgreSQL is available
|
||||
pub fn is_postgres_available() -> bool {
|
||||
PostgresConfig::from_env().is_some()
|
||||
}
|
||||
|
||||
/// Storage backend selector
|
||||
pub enum StorageBackend {
|
||||
#[cfg(feature = "postgres")]
|
||||
Postgres(PostgresStorage),
|
||||
Json(super::Intelligence),
|
||||
}
|
||||
|
||||
impl StorageBackend {
|
||||
/// Create storage backend from environment
|
||||
#[cfg(feature = "postgres")]
|
||||
pub async fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
if let Some(config) = PostgresConfig::from_env() {
|
||||
match PostgresStorage::new(config).await {
|
||||
Ok(pg) => return Ok(Self::Postgres(pg)),
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"Warning: PostgreSQL unavailable ({}), using JSON fallback",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Self::Json(super::Intelligence::new(
|
||||
super::get_intelligence_path(),
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "postgres"))]
|
||||
pub fn from_env() -> Self {
|
||||
Self::Json(super::Intelligence::new(super::get_intelligence_path()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_config_from_url() {
|
||||
let config =
|
||||
PostgresConfig::from_url("postgres://user:pass@localhost:5432/ruvector").unwrap();
|
||||
assert_eq!(config.host, "localhost");
|
||||
assert_eq!(config.port, 5432);
|
||||
assert_eq!(config.user, "user");
|
||||
assert_eq!(config.password, Some("pass".to_string()));
|
||||
assert_eq!(config.dbname, "ruvector");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_from_url_no_password() {
|
||||
let config = PostgresConfig::from_url("postgres://user@localhost/ruvector").unwrap();
|
||||
assert_eq!(config.user, "user");
|
||||
assert_eq!(config.password, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_from_url_with_query() {
|
||||
let config = PostgresConfig::from_url(
|
||||
"postgres://user:pass@localhost:5432/ruvector?sslmode=require",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(config.dbname, "ruvector");
|
||||
}
|
||||
}
|
||||
15
vendor/ruvector/crates/ruvector-cli/src/cli/mod.rs
vendored
Normal file
15
vendor/ruvector/crates/ruvector-cli/src/cli/mod.rs
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
//! CLI module for Ruvector
|
||||
|
||||
pub mod commands;
|
||||
pub mod format;
|
||||
pub mod graph;
|
||||
pub mod hooks;
|
||||
#[cfg(feature = "postgres")]
|
||||
pub mod hooks_postgres;
|
||||
pub mod progress;
|
||||
|
||||
pub use commands::*;
|
||||
pub use format::*;
|
||||
pub use graph::*;
|
||||
pub use hooks::*;
|
||||
pub use progress::ProgressTracker;
|
||||
56
vendor/ruvector/crates/ruvector-cli/src/cli/progress.rs
vendored
Normal file
56
vendor/ruvector/crates/ruvector-cli/src/cli/progress.rs
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
// ! Progress tracking for CLI operations
|
||||
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Progress tracker for long-running operations
|
||||
pub struct ProgressTracker {
|
||||
multi: MultiProgress,
|
||||
}
|
||||
|
||||
impl ProgressTracker {
|
||||
/// Create a new progress tracker
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
multi: MultiProgress::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a progress bar for an operation
|
||||
pub fn create_bar(&self, total: u64, message: &str) -> ProgressBar {
|
||||
let pb = self.multi.add(ProgressBar::new(total));
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})")
|
||||
.unwrap()
|
||||
.progress_chars("#>-")
|
||||
);
|
||||
pb.set_message(message.to_string());
|
||||
pb.enable_steady_tick(Duration::from_millis(100));
|
||||
pb
|
||||
}
|
||||
|
||||
/// Create a spinner for indeterminate operations
|
||||
pub fn create_spinner(&self, message: &str) -> ProgressBar {
|
||||
let pb = self.multi.add(ProgressBar::new_spinner());
|
||||
pb.set_style(
|
||||
ProgressStyle::default_spinner()
|
||||
.template("{spinner:.green} {msg}")
|
||||
.unwrap(),
|
||||
);
|
||||
pb.set_message(message.to_string());
|
||||
pb.enable_steady_tick(Duration::from_millis(100));
|
||||
pb
|
||||
}
|
||||
|
||||
/// Finish all progress bars
|
||||
pub fn finish_all(&self) {
|
||||
// Progress bars auto-finish when dropped
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProgressTracker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
280
vendor/ruvector/crates/ruvector-cli/src/config.rs
vendored
Normal file
280
vendor/ruvector/crates/ruvector-cli/src/config.rs
vendored
Normal file
@@ -0,0 +1,280 @@
|
||||
//! Configuration management for Ruvector CLI
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, QuantizationConfig};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Ruvector CLI configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
/// Database options
|
||||
#[serde(default)]
|
||||
pub database: DatabaseConfig,
|
||||
|
||||
/// CLI options
|
||||
#[serde(default)]
|
||||
pub cli: CliConfig,
|
||||
|
||||
/// MCP server options
|
||||
#[serde(default)]
|
||||
pub mcp: McpConfig,
|
||||
}
|
||||
|
||||
/// Database configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DatabaseConfig {
|
||||
/// Default storage path
|
||||
#[serde(default = "default_storage_path")]
|
||||
pub storage_path: String,
|
||||
|
||||
/// Default dimensions
|
||||
#[serde(default = "default_dimensions")]
|
||||
pub dimensions: usize,
|
||||
|
||||
/// Distance metric
|
||||
#[serde(default = "default_distance_metric")]
|
||||
pub distance_metric: DistanceMetric,
|
||||
|
||||
/// HNSW configuration
|
||||
#[serde(default)]
|
||||
pub hnsw: Option<HnswConfig>,
|
||||
|
||||
/// Quantization configuration
|
||||
#[serde(default)]
|
||||
pub quantization: Option<QuantizationConfig>,
|
||||
}
|
||||
|
||||
/// CLI configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CliConfig {
|
||||
/// Show progress bars
|
||||
#[serde(default = "default_true")]
|
||||
pub progress: bool,
|
||||
|
||||
/// Use colors in output
|
||||
#[serde(default = "default_true")]
|
||||
pub colors: bool,
|
||||
|
||||
/// Default batch size for operations
|
||||
#[serde(default = "default_batch_size")]
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
/// MCP server configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpConfig {
|
||||
/// Server host for SSE transport
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
|
||||
/// Server port for SSE transport
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
/// Enable CORS
|
||||
#[serde(default = "default_true")]
|
||||
pub cors: bool,
|
||||
|
||||
/// Allowed data directory for MCP file operations (path confinement)
|
||||
/// All db_path and backup_path values must resolve within this directory.
|
||||
/// Defaults to the current working directory.
|
||||
#[serde(default = "default_data_dir")]
|
||||
pub data_dir: String,
|
||||
}
|
||||
|
||||
// Default value functions
|
||||
fn default_storage_path() -> String {
|
||||
"./ruvector.db".to_string()
|
||||
}
|
||||
|
||||
fn default_dimensions() -> usize {
|
||||
384
|
||||
}
|
||||
|
||||
fn default_distance_metric() -> DistanceMetric {
|
||||
DistanceMetric::Cosine
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_batch_size() -> usize {
|
||||
1000
|
||||
}
|
||||
|
||||
fn default_data_dir() -> String {
|
||||
std::env::current_dir()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| ".".to_string())
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
database: DatabaseConfig::default(),
|
||||
cli: CliConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
storage_path: default_storage_path(),
|
||||
dimensions: default_dimensions(),
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
hnsw: Some(HnswConfig::default()),
|
||||
quantization: Some(QuantizationConfig::Scalar),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CliConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
progress: true,
|
||||
colors: true,
|
||||
batch_size: default_batch_size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for McpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_host(),
|
||||
port: default_port(),
|
||||
cors: true,
|
||||
data_dir: default_data_dir(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load configuration from file
|
||||
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content =
|
||||
std::fs::read_to_string(path.as_ref()).context("Failed to read config file")?;
|
||||
let config: Config = toml::from_str(&content).context("Failed to parse config file")?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load configuration with precedence: CLI args > env vars > config file > defaults
|
||||
pub fn load(config_path: Option<PathBuf>) -> Result<Self> {
|
||||
let mut config = if let Some(path) = config_path {
|
||||
Self::from_file(&path).unwrap_or_default()
|
||||
} else {
|
||||
// Try default locations
|
||||
Self::try_default_locations().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Override with environment variables
|
||||
config.apply_env_vars()?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Try loading from default locations
|
||||
fn try_default_locations() -> Option<Self> {
|
||||
let paths = vec![
|
||||
"ruvector.toml",
|
||||
".ruvector.toml",
|
||||
"~/.config/ruvector/config.toml",
|
||||
"/etc/ruvector/config.toml",
|
||||
];
|
||||
|
||||
for path in paths {
|
||||
let expanded = shellexpand::tilde(path).to_string();
|
||||
if let Ok(config) = Self::from_file(&expanded) {
|
||||
return Some(config);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Apply environment variable overrides
|
||||
fn apply_env_vars(&mut self) -> Result<()> {
|
||||
if let Ok(path) = std::env::var("RUVECTOR_STORAGE_PATH") {
|
||||
self.database.storage_path = path;
|
||||
}
|
||||
|
||||
if let Ok(dims) = std::env::var("RUVECTOR_DIMENSIONS") {
|
||||
self.database.dimensions = dims.parse().context("Invalid RUVECTOR_DIMENSIONS")?;
|
||||
}
|
||||
|
||||
if let Ok(metric) = std::env::var("RUVECTOR_DISTANCE_METRIC") {
|
||||
self.database.distance_metric = match metric.to_lowercase().as_str() {
|
||||
"euclidean" => DistanceMetric::Euclidean,
|
||||
"cosine" => DistanceMetric::Cosine,
|
||||
"dotproduct" => DistanceMetric::DotProduct,
|
||||
"manhattan" => DistanceMetric::Manhattan,
|
||||
_ => return Err(anyhow::anyhow!("Invalid distance metric: {}", metric)),
|
||||
};
|
||||
}
|
||||
|
||||
if let Ok(host) = std::env::var("RUVECTOR_MCP_HOST") {
|
||||
self.mcp.host = host;
|
||||
}
|
||||
|
||||
if let Ok(port) = std::env::var("RUVECTOR_MCP_PORT") {
|
||||
self.mcp.port = port.parse().context("Invalid RUVECTOR_MCP_PORT")?;
|
||||
}
|
||||
|
||||
if let Ok(data_dir) = std::env::var("RUVECTOR_MCP_DATA_DIR") {
|
||||
self.mcp.data_dir = data_dir;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert to DbOptions
|
||||
pub fn to_db_options(&self) -> DbOptions {
|
||||
DbOptions {
|
||||
dimensions: self.database.dimensions,
|
||||
distance_metric: self.database.distance_metric,
|
||||
storage_path: self.database.storage_path.clone(),
|
||||
hnsw_config: self.database.hnsw.clone(),
|
||||
quantization: self.database.quantization.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save configuration to file
|
||||
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
|
||||
std::fs::write(path, content).context("Failed to write config file")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = Config::default();
|
||||
assert_eq!(config.database.dimensions, 384);
|
||||
assert_eq!(config.cli.batch_size, 1000);
|
||||
assert_eq!(config.mcp.port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = Config::default();
|
||||
let toml_str = toml::to_string(&config).unwrap();
|
||||
let parsed: Config = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(config.database.dimensions, parsed.database.dimensions);
|
||||
}
|
||||
}
|
||||
416
vendor/ruvector/crates/ruvector-cli/src/main.rs
vendored
Normal file
416
vendor/ruvector/crates/ruvector-cli/src/main.rs
vendored
Normal file
@@ -0,0 +1,416 @@
|
||||
//! Ruvector CLI - High-performance vector database command-line interface
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use colored::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
mod cli;
|
||||
mod config;
|
||||
|
||||
use crate::cli::commands::*;
|
||||
use crate::config::Config;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "ruvector")]
|
||||
#[command(about = "High-performance Rust vector database CLI", long_about = None)]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
/// Configuration file path
|
||||
#[arg(short, long, global = true)]
|
||||
config: Option<PathBuf>,
|
||||
|
||||
/// Enable debug mode
|
||||
#[arg(short, long, global = true)]
|
||||
debug: bool,
|
||||
|
||||
/// Disable colors
|
||||
#[arg(long, global = true)]
|
||||
no_color: bool,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Commands {
|
||||
/// Create a new vector database
|
||||
Create {
|
||||
/// Database file path
|
||||
#[arg(short, long, default_value = "./ruvector.db")]
|
||||
path: String,
|
||||
|
||||
/// Vector dimensions
|
||||
#[arg(short = 'D', long)]
|
||||
dimensions: usize,
|
||||
},
|
||||
|
||||
/// Insert vectors from a file
|
||||
Insert {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
|
||||
/// Input file path
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
|
||||
/// Input format (json, csv, npy)
|
||||
#[arg(short, long, default_value = "json")]
|
||||
format: String,
|
||||
|
||||
/// Hide progress bar
|
||||
#[arg(long)]
|
||||
no_progress: bool,
|
||||
},
|
||||
|
||||
/// Search for similar vectors
|
||||
Search {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
|
||||
/// Query vector (comma-separated floats or JSON array)
|
||||
#[arg(short, long)]
|
||||
query: String,
|
||||
|
||||
/// Number of results
|
||||
#[arg(short = 'k', long, default_value = "10")]
|
||||
top_k: usize,
|
||||
|
||||
/// Show full vectors in results
|
||||
#[arg(long)]
|
||||
show_vectors: bool,
|
||||
},
|
||||
|
||||
/// Show database information
|
||||
Info {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
},
|
||||
|
||||
/// Run a quick performance benchmark
|
||||
Benchmark {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
|
||||
/// Number of queries to run
|
||||
#[arg(short = 'n', long, default_value = "1000")]
|
||||
queries: usize,
|
||||
},
|
||||
|
||||
/// Export database to file
|
||||
Export {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
|
||||
/// Output file path
|
||||
#[arg(short, long)]
|
||||
output: String,
|
||||
|
||||
/// Output format (json, csv)
|
||||
#[arg(short, long, default_value = "json")]
|
||||
format: String,
|
||||
},
|
||||
|
||||
/// Import from other vector databases
|
||||
Import {
|
||||
/// Database file path
|
||||
#[arg(short = 'b', long, default_value = "./ruvector.db")]
|
||||
db: String,
|
||||
|
||||
/// Source database type (faiss, pinecone, weaviate)
|
||||
#[arg(short, long)]
|
||||
source: String,
|
||||
|
||||
/// Source file or connection path
|
||||
#[arg(short = 'p', long)]
|
||||
source_path: String,
|
||||
},
|
||||
|
||||
/// Graph database operations (Neo4j-compatible)
|
||||
Graph {
|
||||
#[command(subcommand)]
|
||||
action: cli::graph::GraphCommands,
|
||||
},
|
||||
|
||||
/// Self-learning intelligence hooks for Claude Code
|
||||
Hooks {
|
||||
#[command(subcommand)]
|
||||
action: cli::hooks::HooksCommands,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
if cli.debug {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("ruvector=debug")
|
||||
.init();
|
||||
}
|
||||
|
||||
// Disable colors if requested
|
||||
if cli.no_color {
|
||||
colored::control::set_override(false);
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
let config = Config::load(cli.config)?;
|
||||
|
||||
// Execute command
|
||||
let result = match cli.command {
|
||||
Commands::Create { path, dimensions } => create_database(&path, dimensions, &config),
|
||||
Commands::Insert {
|
||||
db,
|
||||
input,
|
||||
format,
|
||||
no_progress,
|
||||
} => insert_vectors(&db, &input, &format, &config, !no_progress),
|
||||
Commands::Search {
|
||||
db,
|
||||
query,
|
||||
top_k,
|
||||
show_vectors,
|
||||
} => {
|
||||
let query_vec = parse_query_vector(&query)?;
|
||||
search_vectors(&db, query_vec, top_k, &config, show_vectors)
|
||||
}
|
||||
Commands::Info { db } => show_info(&db, &config),
|
||||
Commands::Benchmark { db, queries } => run_benchmark(&db, &config, queries),
|
||||
Commands::Export { db, output, format } => export_database(&db, &output, &format, &config),
|
||||
Commands::Import {
|
||||
db,
|
||||
source,
|
||||
source_path,
|
||||
} => import_from_external(&db, &source, &source_path, &config),
|
||||
Commands::Graph { action } => {
|
||||
use cli::graph::GraphCommands;
|
||||
match action {
|
||||
GraphCommands::Create {
|
||||
path,
|
||||
name,
|
||||
indexed,
|
||||
} => cli::graph::create_graph(&path, &name, indexed, &config),
|
||||
GraphCommands::Query {
|
||||
db,
|
||||
cypher,
|
||||
format,
|
||||
explain,
|
||||
} => cli::graph::execute_query(&db, &cypher, &format, explain, &config),
|
||||
GraphCommands::Shell { db, multiline } => {
|
||||
cli::graph::run_shell(&db, multiline, &config)
|
||||
}
|
||||
GraphCommands::Import {
|
||||
db,
|
||||
input,
|
||||
format,
|
||||
graph,
|
||||
skip_errors,
|
||||
} => cli::graph::import_graph(&db, &input, &format, &graph, skip_errors, &config),
|
||||
GraphCommands::Export {
|
||||
db,
|
||||
output,
|
||||
format,
|
||||
graph,
|
||||
} => cli::graph::export_graph(&db, &output, &format, &graph, &config),
|
||||
GraphCommands::Info { db, detailed } => {
|
||||
cli::graph::show_graph_info(&db, detailed, &config)
|
||||
}
|
||||
GraphCommands::Benchmark {
|
||||
db,
|
||||
queries,
|
||||
bench_type,
|
||||
} => cli::graph::run_graph_benchmark(&db, queries, &bench_type, &config),
|
||||
GraphCommands::Serve {
|
||||
db,
|
||||
host,
|
||||
http_port,
|
||||
grpc_port,
|
||||
graphql,
|
||||
} => cli::graph::serve_graph(&db, &host, http_port, grpc_port, graphql, &config),
|
||||
}
|
||||
}
|
||||
Commands::Hooks { action } => {
|
||||
use cli::hooks::HooksCommands;
|
||||
match action {
|
||||
HooksCommands::Init { force, postgres } => {
|
||||
cli::hooks::init_hooks(force, postgres, &config)
|
||||
}
|
||||
HooksCommands::Install { settings_dir } => {
|
||||
cli::hooks::install_hooks(&settings_dir, &config)
|
||||
}
|
||||
HooksCommands::Stats => cli::hooks::show_stats(&config),
|
||||
HooksCommands::Remember {
|
||||
memory_type,
|
||||
content,
|
||||
} => cli::hooks::remember_content(&memory_type, &content.join(" "), &config),
|
||||
HooksCommands::Recall { query, top_k } => {
|
||||
cli::hooks::recall_content(&query.join(" "), top_k, &config)
|
||||
}
|
||||
HooksCommands::Learn {
|
||||
state,
|
||||
action,
|
||||
reward,
|
||||
} => cli::hooks::learn_trajectory(&state, &action, reward, &config),
|
||||
HooksCommands::Suggest { state, actions } => {
|
||||
cli::hooks::suggest_action(&state, &actions, &config)
|
||||
}
|
||||
HooksCommands::Route {
|
||||
task,
|
||||
file,
|
||||
crate_name,
|
||||
operation,
|
||||
} => cli::hooks::route_task(
|
||||
&task.join(" "),
|
||||
file.as_deref(),
|
||||
crate_name.as_deref(),
|
||||
&operation,
|
||||
&config,
|
||||
),
|
||||
HooksCommands::PreEdit { file } => cli::hooks::pre_edit_hook(&file, &config),
|
||||
HooksCommands::PostEdit { file, success } => {
|
||||
cli::hooks::post_edit_hook(&file, success, &config)
|
||||
}
|
||||
HooksCommands::PreCommand { command } => {
|
||||
cli::hooks::pre_command_hook(&command.join(" "), &config)
|
||||
}
|
||||
HooksCommands::PostCommand {
|
||||
command,
|
||||
success,
|
||||
stderr,
|
||||
} => cli::hooks::post_command_hook(
|
||||
&command.join(" "),
|
||||
success,
|
||||
stderr.as_deref(),
|
||||
&config,
|
||||
),
|
||||
HooksCommands::SessionStart { session_id, resume } => {
|
||||
cli::hooks::session_start_hook(session_id.as_deref(), resume, &config)
|
||||
}
|
||||
HooksCommands::SessionEnd { export_metrics } => {
|
||||
cli::hooks::session_end_hook(export_metrics, &config)
|
||||
}
|
||||
HooksCommands::PreCompact { length, auto } => {
|
||||
cli::hooks::pre_compact_hook(length, auto, &config)
|
||||
}
|
||||
HooksCommands::SuggestContext => cli::hooks::suggest_context_cmd(&config),
|
||||
HooksCommands::TrackNotification { notification_type } => {
|
||||
cli::hooks::track_notification_cmd(notification_type.as_deref(), &config)
|
||||
}
|
||||
// Claude Code v2.0.55+ features
|
||||
HooksCommands::LspDiagnostic {
|
||||
file,
|
||||
severity,
|
||||
message,
|
||||
} => cli::hooks::lsp_diagnostic_cmd(
|
||||
file.as_deref(),
|
||||
severity.as_deref(),
|
||||
message.as_deref(),
|
||||
&config,
|
||||
),
|
||||
HooksCommands::SuggestUltrathink { task, file } => {
|
||||
cli::hooks::suggest_ultrathink_cmd(&task.join(" "), file.as_deref(), &config)
|
||||
}
|
||||
HooksCommands::AsyncAgent {
|
||||
action,
|
||||
agent_id,
|
||||
task,
|
||||
} => cli::hooks::async_agent_cmd(
|
||||
&action,
|
||||
agent_id.as_deref(),
|
||||
task.as_deref(),
|
||||
&config,
|
||||
),
|
||||
HooksCommands::RecordError { command, stderr } => {
|
||||
cli::hooks::record_error_cmd(&command, &stderr, &config)
|
||||
}
|
||||
HooksCommands::SuggestFix { error_code } => {
|
||||
cli::hooks::suggest_fix_cmd(&error_code, &config)
|
||||
}
|
||||
HooksCommands::SuggestNext { file, count } => {
|
||||
cli::hooks::suggest_next_cmd(&file, count, &config)
|
||||
}
|
||||
HooksCommands::ShouldTest { file } => cli::hooks::should_test_cmd(&file, &config),
|
||||
HooksCommands::SwarmRegister {
|
||||
agent_id,
|
||||
agent_type,
|
||||
capabilities,
|
||||
} => cli::hooks::swarm_register_cmd(
|
||||
&agent_id,
|
||||
&agent_type,
|
||||
capabilities.as_deref(),
|
||||
&config,
|
||||
),
|
||||
HooksCommands::SwarmCoordinate {
|
||||
source,
|
||||
target,
|
||||
weight,
|
||||
} => cli::hooks::swarm_coordinate_cmd(&source, &target, weight, &config),
|
||||
HooksCommands::SwarmOptimize { tasks } => {
|
||||
cli::hooks::swarm_optimize_cmd(&tasks, &config)
|
||||
}
|
||||
HooksCommands::SwarmRecommend { task_type } => {
|
||||
cli::hooks::swarm_recommend_cmd(&task_type, &config)
|
||||
}
|
||||
HooksCommands::SwarmHeal { agent_id } => {
|
||||
cli::hooks::swarm_heal_cmd(&agent_id, &config)
|
||||
}
|
||||
HooksCommands::SwarmStats => cli::hooks::swarm_stats_cmd(&config),
|
||||
HooksCommands::Completions { shell } => cli::hooks::generate_completions(shell),
|
||||
HooksCommands::Compress => cli::hooks::compress_storage(&config),
|
||||
HooksCommands::CacheStats => cli::hooks::cache_stats(&config),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Handle errors
|
||||
if let Err(e) = result {
|
||||
eprintln!("{}", cli::format::format_error(&e.to_string()));
|
||||
if cli.debug {
|
||||
eprintln!("\n{:#?}", e);
|
||||
} else {
|
||||
eprintln!("\n{}", "Run with --debug for more details".dimmed());
|
||||
}
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse query vector from string
|
||||
fn parse_query_vector(s: &str) -> Result<Vec<f32>> {
|
||||
// Try JSON first
|
||||
if s.trim().starts_with('[') {
|
||||
return serde_json::from_str(s)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse query vector as JSON: {}", e));
|
||||
}
|
||||
|
||||
// Try comma-separated
|
||||
s.split(',')
|
||||
.map(|s| s.trim().parse::<f32>())
|
||||
.collect::<std::result::Result<Vec<f32>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse query vector: {}", e))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_query_vector_json() {
|
||||
let vec = parse_query_vector("[1.0, 2.0, 3.0]").unwrap();
|
||||
assert_eq!(vec, vec![1.0, 2.0, 3.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_query_vector_csv() {
|
||||
let vec = parse_query_vector("1.0, 2.0, 3.0").unwrap();
|
||||
assert_eq!(vec, vec![1.0, 2.0, 3.0]);
|
||||
}
|
||||
}
|
||||
463
vendor/ruvector/crates/ruvector-cli/src/mcp/gnn_cache.rs
vendored
Normal file
463
vendor/ruvector/crates/ruvector-cli/src/mcp/gnn_cache.rs
vendored
Normal file
@@ -0,0 +1,463 @@
|
||||
//! GNN Layer Caching for Performance Optimization
|
||||
//!
|
||||
//! This module provides persistent caching for GNN layers and query results,
|
||||
//! eliminating the ~2.5s overhead per operation from process initialization,
|
||||
//! database loading, and index deserialization.
|
||||
//!
|
||||
//! ## Performance Impact
|
||||
//!
|
||||
//! | Operation | Before | After | Improvement |
|
||||
//! |-----------|--------|-------|-------------|
|
||||
//! | Layer init | ~2.5s | ~5-10ms | 250-500x |
|
||||
//! | Query | ~2.5s | ~5-10ms | 250-500x |
|
||||
//! | Batch query | ~2.5s * N | ~5-10ms | Amortized |
|
||||
|
||||
use lru::LruCache;
|
||||
use ruvector_gnn::layer::RuvectorLayer;
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Cache entry with metadata for monitoring
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CacheEntry<T> {
|
||||
pub value: T,
|
||||
pub created_at: Instant,
|
||||
pub last_accessed: Instant,
|
||||
pub access_count: u64,
|
||||
}
|
||||
|
||||
impl<T: Clone> CacheEntry<T> {
|
||||
pub fn new(value: T) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
value,
|
||||
created_at: now,
|
||||
last_accessed: now,
|
||||
access_count: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn access(&mut self) -> &T {
|
||||
self.last_accessed = Instant::now();
|
||||
self.access_count += 1;
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the GNN cache
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GnnCacheConfig {
|
||||
/// Maximum number of GNN layers to cache
|
||||
pub max_layers: usize,
|
||||
/// Maximum number of query results to cache
|
||||
pub max_query_results: usize,
|
||||
/// TTL for cached query results (in seconds)
|
||||
pub query_result_ttl_secs: u64,
|
||||
/// Whether to preload common layer configurations
|
||||
pub preload_common: bool,
|
||||
}
|
||||
|
||||
impl Default for GnnCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_layers: 32,
|
||||
max_query_results: 1000,
|
||||
query_result_ttl_secs: 300, // 5 minutes
|
||||
preload_common: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query result cache key
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct QueryCacheKey {
|
||||
/// Layer configuration hash
|
||||
pub layer_hash: String,
|
||||
/// Query vector hash (first 8 floats as u64 bits)
|
||||
pub query_hash: u64,
|
||||
/// Number of results requested
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl QueryCacheKey {
|
||||
pub fn new(layer_id: &str, query: &[f32], k: usize) -> Self {
|
||||
// Simple hash of query vector
|
||||
let query_hash = query
|
||||
.iter()
|
||||
.take(8)
|
||||
.fold(0u64, |acc, &v| acc.wrapping_add(v.to_bits() as u64));
|
||||
|
||||
Self {
|
||||
layer_hash: layer_id.to_string(),
|
||||
query_hash,
|
||||
k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached query result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedQueryResult {
|
||||
pub result: Vec<f32>,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
|
||||
/// GNN Layer cache with LRU eviction and TTL support
|
||||
pub struct GnnCache {
|
||||
/// Cached GNN layers by configuration hash
|
||||
layers: Arc<RwLock<HashMap<String, CacheEntry<RuvectorLayer>>>>,
|
||||
/// LRU cache for query results
|
||||
query_results: Arc<RwLock<LruCache<QueryCacheKey, CachedQueryResult>>>,
|
||||
/// Configuration
|
||||
config: GnnCacheConfig,
|
||||
/// Cache statistics
|
||||
stats: Arc<RwLock<CacheStats>>,
|
||||
}
|
||||
|
||||
/// Cache statistics for monitoring
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CacheStats {
|
||||
pub layer_hits: u64,
|
||||
pub layer_misses: u64,
|
||||
pub query_hits: u64,
|
||||
pub query_misses: u64,
|
||||
pub evictions: u64,
|
||||
pub total_queries: u64,
|
||||
}
|
||||
|
||||
impl CacheStats {
|
||||
pub fn layer_hit_rate(&self) -> f64 {
|
||||
let total = self.layer_hits + self.layer_misses;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.layer_hits as f64 / total as f64
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query_hit_rate(&self) -> f64 {
|
||||
let total = self.query_hits + self.query_misses;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.query_hits as f64 / total as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GnnCache {
|
||||
/// Create a new GNN cache with the given configuration
|
||||
pub fn new(config: GnnCacheConfig) -> Self {
|
||||
let query_cache_size =
|
||||
NonZeroUsize::new(config.max_query_results).unwrap_or(NonZeroUsize::new(1000).unwrap());
|
||||
|
||||
Self {
|
||||
layers: Arc::new(RwLock::new(HashMap::new())),
|
||||
query_results: Arc::new(RwLock::new(LruCache::new(query_cache_size))),
|
||||
config,
|
||||
stats: Arc::new(RwLock::new(CacheStats::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create a GNN layer with the specified configuration
|
||||
pub async fn get_or_create_layer(
|
||||
&self,
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
heads: usize,
|
||||
dropout: f32,
|
||||
) -> RuvectorLayer {
|
||||
let key = format!(
|
||||
"{}_{}_{}_{}",
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
heads,
|
||||
(dropout * 1000.0) as u32
|
||||
);
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let mut layers = self.layers.write().await;
|
||||
if let Some(entry) = layers.get_mut(&key) {
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.layer_hits += 1;
|
||||
return entry.access().clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Create new layer
|
||||
let layer = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout)
|
||||
.expect("GNN layer cache: invalid layer configuration");
|
||||
|
||||
// Cache it
|
||||
{
|
||||
let mut layers = self.layers.write().await;
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.layer_misses += 1;
|
||||
|
||||
// Evict if necessary
|
||||
if layers.len() >= self.config.max_layers {
|
||||
// Simple eviction: remove oldest entry
|
||||
if let Some(oldest_key) = layers
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.last_accessed)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
layers.remove(&oldest_key);
|
||||
stats.evictions += 1;
|
||||
}
|
||||
}
|
||||
|
||||
layers.insert(key, CacheEntry::new(layer.clone()));
|
||||
}
|
||||
|
||||
layer
|
||||
}
|
||||
|
||||
/// Get cached query result if available and not expired
|
||||
pub async fn get_query_result(&self, key: &QueryCacheKey) -> Option<Vec<f32>> {
|
||||
let mut results = self.query_results.write().await;
|
||||
|
||||
if let Some(cached) = results.get(key) {
|
||||
let ttl = Duration::from_secs(self.config.query_result_ttl_secs);
|
||||
if cached.cached_at.elapsed() < ttl {
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.query_hits += 1;
|
||||
stats.total_queries += 1;
|
||||
return Some(cached.result.clone());
|
||||
}
|
||||
// Expired, remove it
|
||||
results.pop(key);
|
||||
}
|
||||
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.query_misses += 1;
|
||||
stats.total_queries += 1;
|
||||
None
|
||||
}
|
||||
|
||||
/// Cache a query result
|
||||
pub async fn cache_query_result(&self, key: QueryCacheKey, result: Vec<f32>) {
|
||||
let mut results = self.query_results.write().await;
|
||||
results.put(
|
||||
key,
|
||||
CachedQueryResult {
|
||||
result,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Get current cache statistics
|
||||
pub async fn stats(&self) -> CacheStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear all caches
|
||||
pub async fn clear(&self) {
|
||||
self.layers.write().await.clear();
|
||||
self.query_results.write().await.clear();
|
||||
}
|
||||
|
||||
/// Preload common layer configurations for faster first access
|
||||
pub async fn preload_common_layers(&self) {
|
||||
// Common configurations used in practice
|
||||
let common_configs = [
|
||||
(128, 256, 4, 0.1), // Small model
|
||||
(256, 512, 8, 0.1), // Medium model
|
||||
(384, 768, 8, 0.1), // Base model (BERT-like)
|
||||
(768, 1024, 16, 0.1), // Large model
|
||||
];
|
||||
|
||||
for (input, hidden, heads, dropout) in common_configs {
|
||||
let _ = self
|
||||
.get_or_create_layer(input, hidden, heads, dropout)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get number of cached layers
|
||||
pub async fn layer_count(&self) -> usize {
|
||||
self.layers.read().await.len()
|
||||
}
|
||||
|
||||
/// Get number of cached query results
|
||||
pub async fn query_result_count(&self) -> usize {
|
||||
self.query_results.read().await.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch operation for multiple GNN forward passes
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BatchGnnRequest {
|
||||
pub layer_config: LayerConfig,
|
||||
pub operations: Vec<GnnOperation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LayerConfig {
|
||||
pub input_dim: usize,
|
||||
pub hidden_dim: usize,
|
||||
pub heads: usize,
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GnnOperation {
|
||||
pub node_embedding: Vec<f32>,
|
||||
pub neighbor_embeddings: Vec<Vec<f32>>,
|
||||
pub edge_weights: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BatchGnnResult {
|
||||
pub results: Vec<Vec<f32>>,
|
||||
pub cached_count: usize,
|
||||
pub computed_count: usize,
|
||||
pub total_time_ms: f64,
|
||||
}
|
||||
|
||||
impl GnnCache {
|
||||
/// Execute batch GNN operations with caching
|
||||
pub async fn batch_forward(&self, request: BatchGnnRequest) -> BatchGnnResult {
|
||||
let start = Instant::now();
|
||||
|
||||
// Get or create the layer
|
||||
let layer = self
|
||||
.get_or_create_layer(
|
||||
request.layer_config.input_dim,
|
||||
request.layer_config.hidden_dim,
|
||||
request.layer_config.heads,
|
||||
request.layer_config.dropout,
|
||||
)
|
||||
.await;
|
||||
|
||||
let layer_id = format!(
|
||||
"{}_{}_{}",
|
||||
request.layer_config.input_dim,
|
||||
request.layer_config.hidden_dim,
|
||||
request.layer_config.heads
|
||||
);
|
||||
|
||||
let mut results = Vec::with_capacity(request.operations.len());
|
||||
let mut cached_count = 0;
|
||||
let mut computed_count = 0;
|
||||
|
||||
for op in &request.operations {
|
||||
// Check cache
|
||||
let cache_key = QueryCacheKey::new(&layer_id, &op.node_embedding, 1);
|
||||
|
||||
if let Some(cached) = self.get_query_result(&cache_key).await {
|
||||
results.push(cached);
|
||||
cached_count += 1;
|
||||
} else {
|
||||
// Compute forward pass
|
||||
let result = layer.forward(
|
||||
&op.node_embedding,
|
||||
&op.neighbor_embeddings,
|
||||
&op.edge_weights,
|
||||
);
|
||||
|
||||
// Cache the result
|
||||
self.cache_query_result(cache_key, result.clone()).await;
|
||||
results.push(result);
|
||||
computed_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
BatchGnnResult {
|
||||
results,
|
||||
cached_count,
|
||||
computed_count,
|
||||
total_time_ms: start.elapsed().as_secs_f64() * 1000.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_layer_caching() {
|
||||
let cache = GnnCache::new(GnnCacheConfig::default());
|
||||
|
||||
// First access - miss
|
||||
let layer1 = cache.get_or_create_layer(128, 256, 4, 0.1).await;
|
||||
let stats = cache.stats().await;
|
||||
assert_eq!(stats.layer_misses, 1);
|
||||
assert_eq!(stats.layer_hits, 0);
|
||||
|
||||
// Second access - hit
|
||||
let _layer2 = cache.get_or_create_layer(128, 256, 4, 0.1).await;
|
||||
let stats = cache.stats().await;
|
||||
assert_eq!(stats.layer_misses, 1);
|
||||
assert_eq!(stats.layer_hits, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_result_caching() {
|
||||
let cache = GnnCache::new(GnnCacheConfig::default());
|
||||
|
||||
let key = QueryCacheKey::new("test", &[1.0, 2.0, 3.0], 10);
|
||||
let result = vec![0.1, 0.2, 0.3];
|
||||
|
||||
// Cache miss
|
||||
assert!(cache.get_query_result(&key).await.is_none());
|
||||
|
||||
// Cache the result
|
||||
cache.cache_query_result(key.clone(), result.clone()).await;
|
||||
|
||||
// Cache hit
|
||||
let cached = cache.get_query_result(&key).await;
|
||||
assert!(cached.is_some());
|
||||
assert_eq!(cached.unwrap(), result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_forward() {
|
||||
let cache = GnnCache::new(GnnCacheConfig::default());
|
||||
|
||||
let request = BatchGnnRequest {
|
||||
layer_config: LayerConfig {
|
||||
input_dim: 4,
|
||||
hidden_dim: 8,
|
||||
heads: 2,
|
||||
dropout: 0.1,
|
||||
},
|
||||
operations: vec![
|
||||
GnnOperation {
|
||||
node_embedding: vec![1.0, 2.0, 3.0, 4.0],
|
||||
neighbor_embeddings: vec![vec![0.5, 1.0, 1.5, 2.0]],
|
||||
edge_weights: vec![1.0],
|
||||
},
|
||||
GnnOperation {
|
||||
node_embedding: vec![2.0, 3.0, 4.0, 5.0],
|
||||
neighbor_embeddings: vec![vec![1.0, 1.5, 2.0, 2.5]],
|
||||
edge_weights: vec![1.0],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let result = cache.batch_forward(request).await;
|
||||
assert_eq!(result.results.len(), 2);
|
||||
assert_eq!(result.computed_count, 2);
|
||||
assert_eq!(result.cached_count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_preload_common_layers() {
|
||||
let cache = GnnCache::new(GnnCacheConfig {
|
||||
preload_common: true,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
cache.preload_common_layers().await;
|
||||
|
||||
// Should have 4 preloaded layers
|
||||
assert_eq!(cache.layer_count().await, 4);
|
||||
}
|
||||
}
|
||||
927
vendor/ruvector/crates/ruvector-cli/src/mcp/handlers.rs
vendored
Normal file
927
vendor/ruvector/crates/ruvector-cli/src/mcp/handlers.rs
vendored
Normal file
@@ -0,0 +1,927 @@
|
||||
//! MCP request handlers
|
||||
|
||||
use super::gnn_cache::{BatchGnnRequest, GnnCache, GnnCacheConfig, GnnOperation, LayerConfig};
|
||||
use super::protocol::*;
|
||||
use crate::config::Config;
|
||||
use anyhow::{Context, Result};
|
||||
use ruvector_core::{
|
||||
types::{DbOptions, DistanceMetric, SearchQuery, VectorEntry},
|
||||
VectorDB,
|
||||
};
|
||||
use ruvector_gnn::{compress::TensorCompress, search::differentiable_search};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// MCP handler state with GNN caching for performance optimization
|
||||
pub struct McpHandler {
|
||||
config: Config,
|
||||
databases: Arc<RwLock<HashMap<String, Arc<VectorDB>>>>,
|
||||
/// GNN layer cache for eliminating ~2.5s initialization overhead
|
||||
gnn_cache: Arc<GnnCache>,
|
||||
/// Tensor compressor for GNN operations
|
||||
tensor_compress: Arc<TensorCompress>,
|
||||
/// Allowed base directory for all file operations (path confinement)
|
||||
allowed_data_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl McpHandler {
|
||||
pub fn new(config: Config) -> Self {
|
||||
let gnn_cache = Arc::new(GnnCache::new(GnnCacheConfig::default()));
|
||||
let allowed_data_dir = PathBuf::from(&config.mcp.data_dir);
|
||||
// Canonicalize at startup so all later comparisons are absolute
|
||||
let allowed_data_dir = std::fs::canonicalize(&allowed_data_dir)
|
||||
.unwrap_or_else(|_| std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/")));
|
||||
|
||||
Self {
|
||||
config,
|
||||
databases: Arc::new(RwLock::new(HashMap::new())),
|
||||
gnn_cache,
|
||||
tensor_compress: Arc::new(TensorCompress::new()),
|
||||
allowed_data_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize with preloaded GNN layers for optimal performance
|
||||
pub async fn with_preload(config: Config) -> Self {
|
||||
let handler = Self::new(config);
|
||||
handler.gnn_cache.preload_common_layers().await;
|
||||
handler
|
||||
}
|
||||
|
||||
/// Validate that a user-supplied path resolves within the allowed data directory.
|
||||
///
|
||||
/// Prevents CWE-22 path traversal by:
|
||||
/// 1. Resolving the path relative to `allowed_data_dir` (not cwd)
|
||||
/// 2. Canonicalizing to eliminate `..`, symlinks, and other tricks
|
||||
/// 3. Checking that the canonical path starts with the allowed directory
|
||||
fn validate_path(&self, user_path: &str) -> Result<PathBuf> {
|
||||
// Reject obviously malicious absolute paths outside data dir
|
||||
let path = Path::new(user_path);
|
||||
|
||||
// If relative, resolve against allowed_data_dir
|
||||
let resolved = if path.is_absolute() {
|
||||
PathBuf::from(user_path)
|
||||
} else {
|
||||
self.allowed_data_dir.join(user_path)
|
||||
};
|
||||
|
||||
// For existing paths, canonicalize resolves symlinks and ..
|
||||
// For non-existing paths, canonicalize the parent and append the filename
|
||||
let canonical = if resolved.exists() {
|
||||
std::fs::canonicalize(&resolved)
|
||||
.with_context(|| format!("Failed to resolve path: {}", user_path))?
|
||||
} else {
|
||||
// Canonicalize the parent directory (must exist), then append filename
|
||||
let parent = resolved.parent().unwrap_or(Path::new("/"));
|
||||
let parent_canonical = if parent.exists() {
|
||||
std::fs::canonicalize(parent).with_context(|| {
|
||||
format!("Parent directory does not exist: {}", parent.display())
|
||||
})?
|
||||
} else {
|
||||
// Create the parent directory within allowed_data_dir if it doesn't exist
|
||||
anyhow::bail!(
|
||||
"Path '{}' references non-existent directory '{}'",
|
||||
user_path,
|
||||
parent.display()
|
||||
);
|
||||
};
|
||||
let filename = resolved
|
||||
.file_name()
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid path: no filename in '{}'", user_path))?;
|
||||
parent_canonical.join(filename)
|
||||
};
|
||||
|
||||
// Security check: canonical path must be inside allowed_data_dir
|
||||
if !canonical.starts_with(&self.allowed_data_dir) {
|
||||
anyhow::bail!(
|
||||
"Access denied: path '{}' resolves to '{}' which is outside the allowed data directory '{}'",
|
||||
user_path,
|
||||
canonical.display(),
|
||||
self.allowed_data_dir.display()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
/// Handle MCP request
|
||||
pub async fn handle_request(&self, request: McpRequest) -> McpResponse {
|
||||
match request.method.as_str() {
|
||||
"initialize" => self.handle_initialize(request.id).await,
|
||||
"tools/list" => self.handle_tools_list(request.id).await,
|
||||
"tools/call" => self.handle_tools_call(request.id, request.params).await,
|
||||
"resources/list" => self.handle_resources_list(request.id).await,
|
||||
"resources/read" => self.handle_resources_read(request.id, request.params).await,
|
||||
"prompts/list" => self.handle_prompts_list(request.id).await,
|
||||
"prompts/get" => self.handle_prompts_get(request.id, request.params).await,
|
||||
_ => McpResponse::error(
|
||||
request.id,
|
||||
McpError::new(error_codes::METHOD_NOT_FOUND, "Method not found"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_initialize(&self, id: Option<Value>) -> McpResponse {
|
||||
McpResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {},
|
||||
"resources": {},
|
||||
"prompts": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "ruvector-mcp",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_tools_list(&self, id: Option<Value>) -> McpResponse {
|
||||
let tools = vec![
|
||||
McpTool {
|
||||
name: "vector_db_create".to_string(),
|
||||
description: "Create a new vector database".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Database file path"},
|
||||
"dimensions": {"type": "integer", "description": "Vector dimensions"},
|
||||
"distance_metric": {"type": "string", "enum": ["euclidean", "cosine", "dotproduct", "manhattan"]}
|
||||
},
|
||||
"required": ["path", "dimensions"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "vector_db_insert".to_string(),
|
||||
description: "Insert vectors into database".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"db_path": {"type": "string"},
|
||||
"vectors": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string"},
|
||||
"vector": {"type": "array", "items": {"type": "number"}},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["db_path", "vectors"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "vector_db_search".to_string(),
|
||||
description: "Search for similar vectors".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"db_path": {"type": "string"},
|
||||
"query": {"type": "array", "items": {"type": "number"}},
|
||||
"k": {"type": "integer", "default": 10},
|
||||
"filter": {"type": "object"}
|
||||
},
|
||||
"required": ["db_path", "query"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "vector_db_stats".to_string(),
|
||||
description: "Get database statistics".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"db_path": {"type": "string"}
|
||||
},
|
||||
"required": ["db_path"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "vector_db_backup".to_string(),
|
||||
description: "Backup database to file".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"db_path": {"type": "string"},
|
||||
"backup_path": {"type": "string"}
|
||||
},
|
||||
"required": ["db_path", "backup_path"]
|
||||
}),
|
||||
},
|
||||
// GNN Tools with persistent caching (~250-500x faster)
|
||||
McpTool {
|
||||
name: "gnn_layer_create".to_string(),
|
||||
description: "Create/cache a GNN layer (eliminates ~2.5s init overhead)"
|
||||
.to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_dim": {"type": "integer", "description": "Input embedding dimension"},
|
||||
"hidden_dim": {"type": "integer", "description": "Hidden layer dimension"},
|
||||
"heads": {"type": "integer", "description": "Number of attention heads"},
|
||||
"dropout": {"type": "number", "default": 0.1, "description": "Dropout rate"}
|
||||
},
|
||||
"required": ["input_dim", "hidden_dim", "heads"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_forward".to_string(),
|
||||
description: "Forward pass through cached GNN layer (~5-10ms vs ~2.5s)".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"layer_id": {"type": "string", "description": "Layer config: input_hidden_heads"},
|
||||
"node_embedding": {"type": "array", "items": {"type": "number"}},
|
||||
"neighbor_embeddings": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
|
||||
"edge_weights": {"type": "array", "items": {"type": "number"}}
|
||||
},
|
||||
"required": ["layer_id", "node_embedding", "neighbor_embeddings", "edge_weights"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_batch_forward".to_string(),
|
||||
description: "Batch GNN forward passes with result caching (amortized cost)"
|
||||
.to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"layer_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_dim": {"type": "integer"},
|
||||
"hidden_dim": {"type": "integer"},
|
||||
"heads": {"type": "integer"},
|
||||
"dropout": {"type": "number", "default": 0.1}
|
||||
},
|
||||
"required": ["input_dim", "hidden_dim", "heads"]
|
||||
},
|
||||
"operations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"node_embedding": {"type": "array", "items": {"type": "number"}},
|
||||
"neighbor_embeddings": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
|
||||
"edge_weights": {"type": "array", "items": {"type": "number"}}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["layer_config", "operations"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_cache_stats".to_string(),
|
||||
description: "Get GNN cache statistics (hit rates, counts)".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"include_details": {"type": "boolean", "default": false}
|
||||
}
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_compress".to_string(),
|
||||
description: "Compress embedding based on access frequency".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"embedding": {"type": "array", "items": {"type": "number"}},
|
||||
"access_freq": {"type": "number", "description": "Access frequency 0.0-1.0"}
|
||||
},
|
||||
"required": ["embedding", "access_freq"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_decompress".to_string(),
|
||||
description: "Decompress a compressed tensor".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"compressed_json": {"type": "string", "description": "Compressed tensor JSON"}
|
||||
},
|
||||
"required": ["compressed_json"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "gnn_search".to_string(),
|
||||
description: "Differentiable search with soft attention".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "array", "items": {"type": "number"}},
|
||||
"candidates": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}},
|
||||
"k": {"type": "integer", "description": "Number of results"},
|
||||
"temperature": {"type": "number", "default": 1.0}
|
||||
},
|
||||
"required": ["query", "candidates", "k"]
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
McpResponse::success(id, json!({ "tools": tools }))
|
||||
}
|
||||
|
||||
async fn handle_tools_call(&self, id: Option<Value>, params: Option<Value>) -> McpResponse {
|
||||
let params = match params {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return McpResponse::error(
|
||||
id,
|
||||
McpError::new(error_codes::INVALID_PARAMS, "Missing params"),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let tool_name = params["name"].as_str().unwrap_or("");
|
||||
let arguments = ¶ms["arguments"];
|
||||
|
||||
let result = match tool_name {
|
||||
// Vector DB tools
|
||||
"vector_db_create" => self.tool_create_db(arguments).await,
|
||||
"vector_db_insert" => self.tool_insert(arguments).await,
|
||||
"vector_db_search" => self.tool_search(arguments).await,
|
||||
"vector_db_stats" => self.tool_stats(arguments).await,
|
||||
"vector_db_backup" => self.tool_backup(arguments).await,
|
||||
// GNN tools with caching
|
||||
"gnn_layer_create" => self.tool_gnn_layer_create(arguments).await,
|
||||
"gnn_forward" => self.tool_gnn_forward(arguments).await,
|
||||
"gnn_batch_forward" => self.tool_gnn_batch_forward(arguments).await,
|
||||
"gnn_cache_stats" => self.tool_gnn_cache_stats(arguments).await,
|
||||
"gnn_compress" => self.tool_gnn_compress(arguments).await,
|
||||
"gnn_decompress" => self.tool_gnn_decompress(arguments).await,
|
||||
"gnn_search" => self.tool_gnn_search(arguments).await,
|
||||
_ => Err(anyhow::anyhow!("Unknown tool: {}", tool_name)),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(value) => {
|
||||
McpResponse::success(id, json!({ "content": [{"type": "text", "text": value}] }))
|
||||
}
|
||||
Err(e) => McpResponse::error(
|
||||
id,
|
||||
McpError::new(error_codes::INTERNAL_ERROR, e.to_string()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_resources_list(&self, id: Option<Value>) -> McpResponse {
|
||||
McpResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"resources": [
|
||||
{
|
||||
"uri": "database://local/default",
|
||||
"name": "Default Database",
|
||||
"description": "Default vector database",
|
||||
"mimeType": "application/x-ruvector-db"
|
||||
}
|
||||
]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_resources_read(
|
||||
&self,
|
||||
id: Option<Value>,
|
||||
_params: Option<Value>,
|
||||
) -> McpResponse {
|
||||
McpResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"contents": [{
|
||||
"uri": "database://local/default",
|
||||
"mimeType": "application/json",
|
||||
"text": "{\"status\": \"available\"}"
|
||||
}]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_prompts_list(&self, id: Option<Value>) -> McpResponse {
|
||||
McpResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"prompts": [
|
||||
{
|
||||
"name": "semantic-search",
|
||||
"description": "Generate a semantic search query",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "query",
|
||||
"description": "Natural language query",
|
||||
"required": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_prompts_get(&self, id: Option<Value>, _params: Option<Value>) -> McpResponse {
|
||||
McpResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"description": "Semantic search template",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Search for vectors related to: {{query}}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// Tool implementations
|
||||
async fn tool_create_db(&self, args: &Value) -> Result<String> {
|
||||
let params: CreateDbParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
// Validate path to prevent directory traversal (CWE-22)
|
||||
let validated_path = self.validate_path(¶ms.path)?;
|
||||
|
||||
let mut db_options = self.config.to_db_options();
|
||||
db_options.storage_path = validated_path.to_string_lossy().to_string();
|
||||
db_options.dimensions = params.dimensions;
|
||||
|
||||
if let Some(metric) = params.distance_metric {
|
||||
db_options.distance_metric = match metric.as_str() {
|
||||
"euclidean" => DistanceMetric::Euclidean,
|
||||
"cosine" => DistanceMetric::Cosine,
|
||||
"dotproduct" => DistanceMetric::DotProduct,
|
||||
"manhattan" => DistanceMetric::Manhattan,
|
||||
_ => DistanceMetric::Cosine,
|
||||
};
|
||||
}
|
||||
|
||||
let db = VectorDB::new(db_options)?;
|
||||
let path_str = validated_path.to_string_lossy().to_string();
|
||||
self.databases
|
||||
.write()
|
||||
.await
|
||||
.insert(path_str.clone(), Arc::new(db));
|
||||
|
||||
Ok(format!("Database created at: {}", path_str))
|
||||
}
|
||||
|
||||
async fn tool_insert(&self, args: &Value) -> Result<String> {
|
||||
let params: InsertParams = serde_json::from_value(args.clone())?;
|
||||
let db = self.get_or_open_db(¶ms.db_path).await?;
|
||||
|
||||
let entries: Vec<VectorEntry> = params
|
||||
.vectors
|
||||
.into_iter()
|
||||
.map(|v| VectorEntry {
|
||||
id: v.id,
|
||||
vector: v.vector,
|
||||
metadata: v.metadata.and_then(|m| serde_json::from_value(m).ok()),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ids = db.insert_batch(entries)?;
|
||||
Ok(format!("Inserted {} vectors", ids.len()))
|
||||
}
|
||||
|
||||
async fn tool_search(&self, args: &Value) -> Result<String> {
|
||||
let params: SearchParams = serde_json::from_value(args.clone())?;
|
||||
let db = self.get_or_open_db(¶ms.db_path).await?;
|
||||
|
||||
let results = db.search(SearchQuery {
|
||||
vector: params.query,
|
||||
k: params.k,
|
||||
filter: params.filter.and_then(|f| serde_json::from_value(f).ok()),
|
||||
ef_search: None,
|
||||
})?;
|
||||
|
||||
serde_json::to_string_pretty(&results).context("Failed to serialize results")
|
||||
}
|
||||
|
||||
async fn tool_stats(&self, args: &Value) -> Result<String> {
|
||||
let params: StatsParams = serde_json::from_value(args.clone())?;
|
||||
let db = self.get_or_open_db(¶ms.db_path).await?;
|
||||
|
||||
let count = db.len()?;
|
||||
let options = db.options();
|
||||
|
||||
Ok(json!({
|
||||
"count": count,
|
||||
"dimensions": options.dimensions,
|
||||
"distance_metric": format!("{:?}", options.distance_metric),
|
||||
"hnsw_enabled": options.hnsw_config.is_some()
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
async fn tool_backup(&self, args: &Value) -> Result<String> {
|
||||
let params: BackupParams = serde_json::from_value(args.clone())?;
|
||||
|
||||
// Validate both paths to prevent directory traversal (CWE-22)
|
||||
let validated_db_path = self.validate_path(¶ms.db_path)?;
|
||||
let validated_backup_path = self.validate_path(¶ms.backup_path)?;
|
||||
|
||||
std::fs::copy(&validated_db_path, &validated_backup_path)
|
||||
.context("Failed to backup database")?;
|
||||
|
||||
Ok(format!("Backed up to: {}", validated_backup_path.display()))
|
||||
}
|
||||
|
||||
async fn get_or_open_db(&self, path: &str) -> Result<Arc<VectorDB>> {
|
||||
// Validate path to prevent directory traversal (CWE-22)
|
||||
let validated_path = self.validate_path(path)?;
|
||||
let path_str = validated_path.to_string_lossy().to_string();
|
||||
|
||||
let databases = self.databases.read().await;
|
||||
if let Some(db) = databases.get(&path_str) {
|
||||
return Ok(db.clone());
|
||||
}
|
||||
drop(databases);
|
||||
|
||||
// Open new database
|
||||
let mut db_options = self.config.to_db_options();
|
||||
db_options.storage_path = path_str.clone();
|
||||
|
||||
let db = Arc::new(VectorDB::new(db_options)?);
|
||||
self.databases.write().await.insert(path_str, db.clone());
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
// ==================== GNN Tool Implementations ====================
|
||||
// These tools eliminate ~2.5s overhead per operation via persistent caching
|
||||
|
||||
/// Create or retrieve a cached GNN layer
|
||||
async fn tool_gnn_layer_create(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnLayerCreateParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let _layer = self
|
||||
.gnn_cache
|
||||
.get_or_create_layer(
|
||||
params.input_dim,
|
||||
params.hidden_dim,
|
||||
params.heads,
|
||||
params.dropout,
|
||||
)
|
||||
.await;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
let layer_id = format!(
|
||||
"{}_{}_{}_{}",
|
||||
params.input_dim,
|
||||
params.hidden_dim,
|
||||
params.heads,
|
||||
(params.dropout * 1000.0) as u32
|
||||
);
|
||||
|
||||
Ok(json!({
|
||||
"layer_id": layer_id,
|
||||
"input_dim": params.input_dim,
|
||||
"hidden_dim": params.hidden_dim,
|
||||
"heads": params.heads,
|
||||
"dropout": params.dropout,
|
||||
"creation_time_ms": elapsed.as_secs_f64() * 1000.0,
|
||||
"cached": elapsed.as_millis() < 50 // <50ms indicates cache hit
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Forward pass through a cached GNN layer
|
||||
async fn tool_gnn_forward(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnForwardParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Parse layer_id format: "input_hidden_heads_dropout"
|
||||
let parts: Vec<&str> = params.layer_id.split('_').collect();
|
||||
if parts.len() < 3 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid layer_id format. Expected: input_hidden_heads[_dropout]"
|
||||
));
|
||||
}
|
||||
|
||||
let input_dim: usize = parts[0].parse()?;
|
||||
let hidden_dim: usize = parts[1].parse()?;
|
||||
let heads: usize = parts[2].parse()?;
|
||||
let dropout: f32 = parts
|
||||
.get(3)
|
||||
.map(|s| s.parse::<u32>().unwrap_or(100) as f32 / 1000.0)
|
||||
.unwrap_or(0.1);
|
||||
|
||||
let layer = self
|
||||
.gnn_cache
|
||||
.get_or_create_layer(input_dim, hidden_dim, heads, dropout)
|
||||
.await;
|
||||
|
||||
// Convert f64 to f32
|
||||
let node_f32: Vec<f32> = params.node_embedding.iter().map(|&x| x as f32).collect();
|
||||
let neighbors_f32: Vec<Vec<f32>> = params
|
||||
.neighbor_embeddings
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|&x| x as f32).collect())
|
||||
.collect();
|
||||
let weights_f32: Vec<f32> = params.edge_weights.iter().map(|&x| x as f32).collect();
|
||||
|
||||
let result = layer.forward(&node_f32, &neighbors_f32, &weights_f32);
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Convert back to f64 for JSON
|
||||
let result_f64: Vec<f64> = result.iter().map(|&x| x as f64).collect();
|
||||
|
||||
Ok(json!({
|
||||
"result": result_f64,
|
||||
"output_dim": result.len(),
|
||||
"latency_ms": elapsed.as_secs_f64() * 1000.0
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Batch forward passes with caching
|
||||
async fn tool_gnn_batch_forward(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnBatchForwardParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let request = BatchGnnRequest {
|
||||
layer_config: LayerConfig {
|
||||
input_dim: params.layer_config.input_dim,
|
||||
hidden_dim: params.layer_config.hidden_dim,
|
||||
heads: params.layer_config.heads,
|
||||
dropout: params.layer_config.dropout,
|
||||
},
|
||||
operations: params
|
||||
.operations
|
||||
.into_iter()
|
||||
.map(|op| GnnOperation {
|
||||
node_embedding: op.node_embedding.iter().map(|&x| x as f32).collect(),
|
||||
neighbor_embeddings: op
|
||||
.neighbor_embeddings
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|&x| x as f32).collect())
|
||||
.collect(),
|
||||
edge_weights: op.edge_weights.iter().map(|&x| x as f32).collect(),
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let batch_result = self.gnn_cache.batch_forward(request).await;
|
||||
|
||||
// Convert results to f64
|
||||
let results_f64: Vec<Vec<f64>> = batch_result
|
||||
.results
|
||||
.iter()
|
||||
.map(|r| r.iter().map(|&x| x as f64).collect())
|
||||
.collect();
|
||||
|
||||
Ok(json!({
|
||||
"results": results_f64,
|
||||
"cached_count": batch_result.cached_count,
|
||||
"computed_count": batch_result.computed_count,
|
||||
"total_time_ms": batch_result.total_time_ms,
|
||||
"avg_time_per_op_ms": batch_result.total_time_ms / (batch_result.cached_count + batch_result.computed_count) as f64
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get GNN cache statistics
|
||||
async fn tool_gnn_cache_stats(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnCacheStatsParams =
|
||||
serde_json::from_value(args.clone()).unwrap_or(GnnCacheStatsParams {
|
||||
include_details: false,
|
||||
});
|
||||
|
||||
let stats = self.gnn_cache.stats().await;
|
||||
let layer_count = self.gnn_cache.layer_count().await;
|
||||
let query_count = self.gnn_cache.query_result_count().await;
|
||||
|
||||
let mut result = json!({
|
||||
"layer_hits": stats.layer_hits,
|
||||
"layer_misses": stats.layer_misses,
|
||||
"layer_hit_rate": format!("{:.2}%", stats.layer_hit_rate() * 100.0),
|
||||
"query_hits": stats.query_hits,
|
||||
"query_misses": stats.query_misses,
|
||||
"query_hit_rate": format!("{:.2}%", stats.query_hit_rate() * 100.0),
|
||||
"total_queries": stats.total_queries,
|
||||
"evictions": stats.evictions,
|
||||
"cached_layers": layer_count,
|
||||
"cached_queries": query_count
|
||||
});
|
||||
|
||||
if params.include_details {
|
||||
result["estimated_memory_saved_ms"] = json!((stats.layer_hits as f64) * 2500.0);
|
||||
// ~2.5s per hit
|
||||
}
|
||||
|
||||
Ok(result.to_string())
|
||||
}
|
||||
|
||||
/// Compress embedding based on access frequency
|
||||
async fn tool_gnn_compress(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnCompressParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let embedding_f32: Vec<f32> = params.embedding.iter().map(|&x| x as f32).collect();
|
||||
|
||||
let compressed = self
|
||||
.tensor_compress
|
||||
.compress(&embedding_f32, params.access_freq as f32)
|
||||
.map_err(|e| anyhow::anyhow!("Compression error: {}", e))?;
|
||||
|
||||
let compressed_json = serde_json::to_string(&compressed)?;
|
||||
|
||||
Ok(json!({
|
||||
"compressed_json": compressed_json,
|
||||
"original_size": params.embedding.len() * 4,
|
||||
"compressed_size": compressed_json.len(),
|
||||
"compression_ratio": (params.embedding.len() * 4) as f64 / compressed_json.len() as f64
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Decompress a compressed tensor
|
||||
async fn tool_gnn_decompress(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnDecompressParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let compressed: ruvector_gnn::compress::CompressedTensor =
|
||||
serde_json::from_str(¶ms.compressed_json)
|
||||
.context("Invalid compressed tensor JSON")?;
|
||||
|
||||
let decompressed = self
|
||||
.tensor_compress
|
||||
.decompress(&compressed)
|
||||
.map_err(|e| anyhow::anyhow!("Decompression error: {}", e))?;
|
||||
|
||||
let decompressed_f64: Vec<f64> = decompressed.iter().map(|&x| x as f64).collect();
|
||||
|
||||
Ok(json!({
|
||||
"embedding": decompressed_f64,
|
||||
"dimensions": decompressed.len()
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Differentiable search with soft attention
|
||||
async fn tool_gnn_search(&self, args: &Value) -> Result<String> {
|
||||
let params: GnnSearchParams =
|
||||
serde_json::from_value(args.clone()).context("Invalid parameters")?;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let query_f32: Vec<f32> = params.query.iter().map(|&x| x as f32).collect();
|
||||
let candidates_f32: Vec<Vec<f32>> = params
|
||||
.candidates
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|&x| x as f32).collect())
|
||||
.collect();
|
||||
|
||||
let (indices, weights) = differentiable_search(
|
||||
&query_f32,
|
||||
&candidates_f32,
|
||||
params.k,
|
||||
params.temperature as f32,
|
||||
);
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
Ok(json!({
|
||||
"indices": indices,
|
||||
"weights": weights.iter().map(|&w| w as f64).collect::<Vec<f64>>(),
|
||||
"k": params.k,
|
||||
"latency_ms": elapsed.as_secs_f64() * 1000.0
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn handler_with_data_dir(data_dir: &Path) -> McpHandler {
|
||||
let mut config = Config::default();
|
||||
config.mcp.data_dir = data_dir.to_string_lossy().to_string();
|
||||
McpHandler::new(config)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_allows_relative_within_data_dir() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
// Create a file to validate against
|
||||
std::fs::write(dir.path().join("test.db"), b"test").unwrap();
|
||||
|
||||
let result = handler.validate_path("test.db");
|
||||
assert!(result.is_ok(), "Should allow relative path within data dir");
|
||||
assert!(result.unwrap().starts_with(dir.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_blocks_absolute_outside_data_dir() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
let result = handler.validate_path("/etc/passwd");
|
||||
assert!(result.is_err(), "Should block /etc/passwd");
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("outside the allowed data directory"),
|
||||
"Error should mention path confinement: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_blocks_dot_dot_traversal() {
|
||||
let dir = tempdir().unwrap();
|
||||
// Create a subdir so ../.. resolves to something real
|
||||
let subdir = dir.path().join("sub");
|
||||
std::fs::create_dir_all(&subdir).unwrap();
|
||||
let handler = handler_with_data_dir(&subdir);
|
||||
|
||||
let result = handler.validate_path("../../../etc/passwd");
|
||||
assert!(result.is_err(), "Should block ../ traversal: {:?}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_blocks_dot_dot_in_middle() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
// Create the inner directory
|
||||
std::fs::create_dir_all(dir.path().join("a")).unwrap();
|
||||
|
||||
let result = handler.validate_path("a/../../etc/passwd");
|
||||
assert!(result.is_err(), "Should block ../ in the middle of path");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_allows_subdirectory_within_data_dir() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
// Create subdirectory
|
||||
std::fs::create_dir_all(dir.path().join("backups")).unwrap();
|
||||
|
||||
let result = handler.validate_path("backups/mydb.bak");
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should allow path in subdirectory: {:?}",
|
||||
result
|
||||
);
|
||||
assert!(result.unwrap().starts_with(dir.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_allows_new_file_in_data_dir() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
let result = handler.validate_path("new_database.db");
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should allow new file in data dir: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_blocks_absolute_path_to_etc() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
// Test all 3 POCs from the issue
|
||||
for path in &["/etc/passwd", "/etc/shadow", "/etc/hosts"] {
|
||||
let result = handler.validate_path(path);
|
||||
assert!(result.is_err(), "Should block {}", path);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_path_blocks_home_ssh_keys() {
|
||||
let dir = tempdir().unwrap();
|
||||
let handler = handler_with_data_dir(dir.path());
|
||||
|
||||
let result = handler.validate_path("~/.ssh/id_rsa");
|
||||
// This is a relative path so it won't expand ~, but test the principle
|
||||
let result2 = handler.validate_path("/root/.ssh/id_rsa");
|
||||
assert!(result2.is_err(), "Should block /root/.ssh/id_rsa");
|
||||
}
|
||||
}
|
||||
11
vendor/ruvector/crates/ruvector-cli/src/mcp/mod.rs
vendored
Normal file
11
vendor/ruvector/crates/ruvector-cli/src/mcp/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Model Context Protocol (MCP) implementation for Ruvector
|
||||
|
||||
pub mod gnn_cache;
|
||||
pub mod handlers;
|
||||
pub mod protocol;
|
||||
pub mod transport;
|
||||
|
||||
pub use gnn_cache::*;
|
||||
pub use handlers::*;
|
||||
pub use protocol::*;
|
||||
pub use transport::*;
|
||||
238
vendor/ruvector/crates/ruvector-cli/src/mcp/protocol.rs
vendored
Normal file
238
vendor/ruvector/crates/ruvector-cli/src/mcp/protocol.rs
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
//! MCP protocol types and utilities
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// MCP request message
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: Option<Value>,
|
||||
pub method: String,
|
||||
pub params: Option<Value>,
|
||||
}
|
||||
|
||||
/// MCP response message
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<McpError>,
|
||||
}
|
||||
|
||||
/// MCP error
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
}
|
||||
|
||||
impl McpError {
|
||||
pub fn new(code: i32, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_data(mut self, data: Value) -> Self {
|
||||
self.data = Some(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Standard MCP error codes
|
||||
pub mod error_codes {
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
}
|
||||
|
||||
impl McpResponse {
|
||||
pub fn success(id: Option<Value>, result: Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(id: Option<Value>, error: McpError) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP Tool definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
/// MCP Resource definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpResource {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
#[serde(rename = "mimeType")]
|
||||
pub mime_type: String,
|
||||
}
|
||||
|
||||
/// MCP Prompt definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpPrompt {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub arguments: Option<Vec<PromptArgument>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptArgument {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub required: bool,
|
||||
}
|
||||
|
||||
/// Tool call parameters for vector_db_create
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateDbParams {
|
||||
pub path: String,
|
||||
pub dimensions: usize,
|
||||
#[serde(default)]
|
||||
pub distance_metric: Option<String>,
|
||||
}
|
||||
|
||||
/// Tool call parameters for vector_db_insert
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InsertParams {
|
||||
pub db_path: String,
|
||||
pub vectors: Vec<VectorInsert>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorInsert {
|
||||
pub id: Option<String>,
|
||||
pub vector: Vec<f32>,
|
||||
pub metadata: Option<Value>,
|
||||
}
|
||||
|
||||
/// Tool call parameters for vector_db_search
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchParams {
|
||||
pub db_path: String,
|
||||
pub query: Vec<f32>,
|
||||
pub k: usize,
|
||||
pub filter: Option<Value>,
|
||||
}
|
||||
|
||||
/// Tool call parameters for vector_db_stats
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StatsParams {
|
||||
pub db_path: String,
|
||||
}
|
||||
|
||||
/// Tool call parameters for vector_db_backup
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BackupParams {
|
||||
pub db_path: String,
|
||||
pub backup_path: String,
|
||||
}
|
||||
|
||||
// ==================== GNN Tool Parameters ====================
|
||||
|
||||
/// Tool call parameters for gnn_layer_create
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnLayerCreateParams {
|
||||
pub input_dim: usize,
|
||||
pub hidden_dim: usize,
|
||||
pub heads: usize,
|
||||
#[serde(default = "default_dropout")]
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
fn default_dropout() -> f32 {
|
||||
0.1
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_forward
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnForwardParams {
|
||||
pub layer_id: String,
|
||||
pub node_embedding: Vec<f64>,
|
||||
pub neighbor_embeddings: Vec<Vec<f64>>,
|
||||
pub edge_weights: Vec<f64>,
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_batch_forward
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnBatchForwardParams {
|
||||
pub layer_config: GnnLayerConfigParams,
|
||||
pub operations: Vec<GnnOperationParams>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnLayerConfigParams {
|
||||
pub input_dim: usize,
|
||||
pub hidden_dim: usize,
|
||||
pub heads: usize,
|
||||
#[serde(default = "default_dropout")]
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnOperationParams {
|
||||
pub node_embedding: Vec<f64>,
|
||||
pub neighbor_embeddings: Vec<Vec<f64>>,
|
||||
pub edge_weights: Vec<f64>,
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_cache_stats
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnCacheStatsParams {
|
||||
#[serde(default)]
|
||||
pub include_details: bool,
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_compress
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnCompressParams {
|
||||
pub embedding: Vec<f64>,
|
||||
pub access_freq: f64,
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_decompress
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnDecompressParams {
|
||||
pub compressed_json: String,
|
||||
}
|
||||
|
||||
/// Tool call parameters for gnn_search
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GnnSearchParams {
|
||||
pub query: Vec<f64>,
|
||||
pub candidates: Vec<Vec<f64>>,
|
||||
pub k: usize,
|
||||
#[serde(default = "default_temperature")]
|
||||
pub temperature: f64,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f64 {
|
||||
1.0
|
||||
}
|
||||
186
vendor/ruvector/crates/ruvector-cli/src/mcp/transport.rs
vendored
Normal file
186
vendor/ruvector/crates/ruvector-cli/src/mcp/transport.rs
vendored
Normal file
@@ -0,0 +1,186 @@
|
||||
//! MCP transport layers (STDIO and SSE)
|
||||
|
||||
use super::{handlers::McpHandler, protocol::*};
|
||||
use anyhow::Result;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{header, StatusCode},
|
||||
response::{sse::Event, IntoResponse, Sse},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
|
||||
/// STDIO transport for local MCP communication
|
||||
pub struct StdioTransport {
|
||||
handler: Arc<McpHandler>,
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub fn new(handler: Arc<McpHandler>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
|
||||
/// Run STDIO transport loop
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut stdout = tokio::io::stdout();
|
||||
let mut reader = BufReader::new(stdin);
|
||||
let mut line = String::new();
|
||||
|
||||
tracing::info!("MCP STDIO transport started");
|
||||
|
||||
loop {
|
||||
line.clear();
|
||||
let n = reader.read_line(&mut line).await?;
|
||||
|
||||
if n == 0 {
|
||||
// EOF
|
||||
break;
|
||||
}
|
||||
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse request
|
||||
let request: McpRequest = match serde_json::from_str(trimmed) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
let error_response = McpResponse::error(
|
||||
None,
|
||||
McpError::new(error_codes::PARSE_ERROR, e.to_string()),
|
||||
);
|
||||
let response_json = serde_json::to_string(&error_response)?;
|
||||
stdout.write_all(response_json.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Handle request
|
||||
let response = self.handler.handle_request(request).await;
|
||||
|
||||
// Send response
|
||||
let response_json = serde_json::to_string(&response)?;
|
||||
stdout.write_all(response_json.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
|
||||
tracing::info!("MCP STDIO transport stopped");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// SSE (Server-Sent Events) transport for HTTP streaming
|
||||
pub struct SseTransport {
|
||||
handler: Arc<McpHandler>,
|
||||
host: String,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
pub fn new(handler: Arc<McpHandler>, host: String, port: u16) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
host,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run SSE transport server
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
// Use restrictive CORS: only allow localhost origins by default
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(AllowOrigin::predicate(|origin, _| {
|
||||
if let Ok(origin_str) = origin.to_str() {
|
||||
origin_str.starts_with("http://127.0.0.1")
|
||||
|| origin_str.starts_with("http://localhost")
|
||||
|| origin_str.starts_with("https://127.0.0.1")
|
||||
|| origin_str.starts_with("https://localhost")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}))
|
||||
.allow_methods([axum::http::Method::GET, axum::http::Method::POST])
|
||||
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/mcp", post(mcp_handler))
|
||||
.route("/mcp/sse", get(mcp_sse_handler))
|
||||
.layer(cors)
|
||||
.with_state(self.handler.clone());
|
||||
|
||||
let addr = format!("{}:{}", self.host, self.port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
|
||||
tracing::info!("MCP SSE transport listening on http://{}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP handlers
|
||||
|
||||
async fn root() -> &'static str {
|
||||
"Ruvector MCP Server"
|
||||
}
|
||||
|
||||
async fn mcp_handler(
|
||||
State(handler): State<Arc<McpHandler>>,
|
||||
Json(request): Json<McpRequest>,
|
||||
) -> Json<McpResponse> {
|
||||
let response = handler.handle_request(request).await;
|
||||
Json(response)
|
||||
}
|
||||
|
||||
async fn mcp_sse_handler(
|
||||
State(handler): State<Arc<McpHandler>>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
|
||||
let stream = async_stream::stream! {
|
||||
// Send initial connection event
|
||||
yield Ok(Event::default().data("connected"));
|
||||
|
||||
// Keep connection alive with periodic pings
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
yield Ok(Event::default().event("ping").data("keep-alive"));
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(
|
||||
axum::response::sse::KeepAlive::new()
|
||||
.interval(tokio::time::Duration::from_secs(30))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stdio_transport_creation() {
|
||||
let config = Config::default();
|
||||
let handler = Arc::new(McpHandler::new(config));
|
||||
let _transport = StdioTransport::new(handler);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_transport_creation() {
|
||||
let config = Config::default();
|
||||
let handler = Arc::new(McpHandler::new(config));
|
||||
let _transport = SseTransport::new(handler, "127.0.0.1".to_string(), 3000);
|
||||
}
|
||||
}
|
||||
90
vendor/ruvector/crates/ruvector-cli/src/mcp_server.rs
vendored
Normal file
90
vendor/ruvector/crates/ruvector-cli/src/mcp_server.rs
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
//! MCP Server for Ruvector - Main entry point
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber;
|
||||
|
||||
mod config;
|
||||
mod mcp;
|
||||
|
||||
use config::Config;
|
||||
use mcp::{
|
||||
handlers::McpHandler,
|
||||
transport::{SseTransport, StdioTransport},
|
||||
};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "ruvector-mcp")]
|
||||
#[command(about = "Ruvector MCP Server", long_about = None)]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
/// Configuration file path
|
||||
#[arg(short, long)]
|
||||
config: Option<PathBuf>,
|
||||
|
||||
/// Transport type (stdio or sse)
|
||||
#[arg(short, long, default_value = "stdio")]
|
||||
transport: String,
|
||||
|
||||
/// Host for SSE transport
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
|
||||
/// Port for SSE transport
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Enable debug logging
|
||||
#[arg(short, long)]
|
||||
debug: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
if cli.debug {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("ruvector=debug")
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("ruvector=info")
|
||||
.init();
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
let config = Config::load(cli.config)?;
|
||||
|
||||
// Create MCP handler
|
||||
let handler = Arc::new(McpHandler::new(config.clone()));
|
||||
|
||||
// Run appropriate transport
|
||||
match cli.transport.as_str() {
|
||||
"stdio" => {
|
||||
tracing::info!("Starting MCP server with STDIO transport");
|
||||
let transport = StdioTransport::new(handler);
|
||||
transport.run().await?;
|
||||
}
|
||||
"sse" => {
|
||||
let host = cli.host.unwrap_or(config.mcp.host.clone());
|
||||
let port = cli.port.unwrap_or(config.mcp.port);
|
||||
|
||||
tracing::info!(
|
||||
"Starting MCP server with SSE transport on {}:{}",
|
||||
host,
|
||||
port
|
||||
);
|
||||
let transport = SseTransport::new(handler, host, port);
|
||||
transport.run().await?;
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!("Invalid transport type: {}", cli.transport));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user