Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
1328
vendor/ruvector/examples/data/framework/src/academic_clients.rs
vendored
Normal file
1328
vendor/ruvector/examples/data/framework/src/academic_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1130
vendor/ruvector/examples/data/framework/src/api_clients.rs
vendored
Normal file
1130
vendor/ruvector/examples/data/framework/src/api_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
571
vendor/ruvector/examples/data/framework/src/arxiv_client.rs
vendored
Normal file
571
vendor/ruvector/examples/data/framework/src/arxiv_client.rs
vendored
Normal file
@@ -0,0 +1,571 @@
|
||||
//! ArXiv Preprint API Integration
|
||||
//!
|
||||
//! This module provides an async client for fetching academic preprints from ArXiv.org,
|
||||
//! converting responses to SemanticVector format for RuVector discovery.
|
||||
//!
|
||||
//! # ArXiv API Details
|
||||
//! - Base URL: https://export.arxiv.org/api/query
|
||||
//! - Free access, no authentication required
|
||||
//! - Returns Atom XML feed
|
||||
//! - Rate limit: 1 request per 3 seconds (enforced by client)
|
||||
//!
|
||||
//! # Example
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::arxiv_client::ArxivClient;
|
||||
//!
|
||||
//! let client = ArxivClient::new();
|
||||
//!
|
||||
//! // Search papers by keywords
|
||||
//! let vectors = client.search("machine learning", 10).await?;
|
||||
//!
|
||||
//! // Search by category
|
||||
//! let ai_papers = client.search_category("cs.AI", 20).await?;
|
||||
//!
|
||||
//! // Get recent papers in a category
|
||||
//! let recent = client.search_recent("cs.LG", 7).await?;
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{DateTime, NaiveDateTime, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Rate limiting configuration for ArXiv API
|
||||
const ARXIV_RATE_LIMIT_MS: u64 = 3000; // 3 seconds between requests
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 2000;
|
||||
const DEFAULT_EMBEDDING_DIM: usize = 384;
|
||||
|
||||
// ============================================================================
|
||||
// ArXiv Atom Feed Structures
|
||||
// ============================================================================
|
||||
|
||||
/// ArXiv API Atom feed response.
///
/// Top-level element of the Atom XML returned by the query endpoint;
/// deserialized with quick-xml's serde support.
#[derive(Debug, Deserialize)]
struct ArxivFeed {
    // One element per <entry> (paper) in the result page; empty when
    // the feed has no entries.
    #[serde(rename = "entry", default)]
    entries: Vec<ArxivEntry>,
    // OpenSearch <totalResults> element, when the feed includes it.
    #[serde(rename = "totalResults", default)]
    total_results: Option<TotalResults>,
}
|
||||
|
||||
/// Wrapper for the text content of the <totalResults> element.
#[derive(Debug, Deserialize)]
struct TotalResults {
    // "$value" captures the element's inner text in quick-xml's serde model.
    #[serde(rename = "$value", default)]
    value: Option<String>,
}
|
||||
|
||||
/// ArXiv entry (paper).
///
/// One <entry> element of the Atom feed; serde renames map the XML
/// element names onto the Rust fields.
#[derive(Debug, Deserialize)]
struct ArxivEntry {
    // Full ArXiv URL identifier, e.g. "http://arxiv.org/abs/2401.12345v1".
    #[serde(rename = "id")]
    id: String,
    // Paper title; the raw feed may contain embedded newlines
    // (cleaned up in entry_to_vector).
    #[serde(rename = "title")]
    title: String,
    // Abstract text (the Atom <summary> element).
    #[serde(rename = "summary")]
    summary: String,
    // First publication timestamp as an ISO 8601 string.
    #[serde(rename = "published")]
    published: String,
    // Last-updated timestamp, when the feed provides one.
    #[serde(rename = "updated", default)]
    updated: Option<String>,
    // One element per <author>.
    #[serde(rename = "author", default)]
    authors: Vec<ArxivAuthor>,
    // One element per <category> (subject classification).
    #[serde(rename = "category", default)]
    categories: Vec<ArxivCategory>,
    // Related links (abstract page, PDF, ...).
    #[serde(rename = "link", default)]
    links: Vec<ArxivLink>,
}
|
||||
|
||||
/// Single <author> element of an ArXiv entry.
#[derive(Debug, Deserialize)]
struct ArxivAuthor {
    // Author display name as given by the feed.
    #[serde(rename = "name")]
    name: String,
}
|
||||
|
||||
/// Single <category> element of an ArXiv entry.
#[derive(Debug, Deserialize)]
struct ArxivCategory {
    // "@term" maps the XML attribute `term` (the category code,
    // e.g. "cs.LG") in quick-xml's serde model.
    #[serde(rename = "@term")]
    term: String,
}
|
||||
|
||||
/// Single <link> element of an ArXiv entry; attributes are mapped via
/// quick-xml's "@name" convention.
#[derive(Debug, Deserialize)]
struct ArxivLink {
    // Target URL.
    #[serde(rename = "@href")]
    href: String,
    // MIME type of the target, when given.
    #[serde(rename = "@type", default)]
    link_type: Option<String>,
    // Link label; the PDF link carries title="pdf" (used in entry_to_vector).
    #[serde(rename = "@title", default)]
    title: Option<String>,
}
|
||||
|
||||
// ============================================================================
|
||||
// ArXiv Client
|
||||
// ============================================================================
|
||||
|
||||
/// Client for ArXiv.org preprint API
///
/// Provides methods to search for academic papers, filter by category,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces ArXiv's rate limit of 1 request per 3 seconds.
/// Includes retry logic for transient failures.
pub struct ArxivClient {
    // Shared reqwest HTTP client (30 s timeout, custom user agent).
    client: Client,
    // Text embedder used to turn "title + abstract" into a vector.
    embedder: SimpleEmbedder,
    // Query endpoint; kept as a field so tests can assert it.
    base_url: String,
}
|
||||
|
||||
impl ArxivClient {
|
||||
/// Create a new ArXiv API client
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let client = ArxivClient::new();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
|
||||
}
|
||||
|
||||
/// Create a new ArXiv API client with custom embedding dimension
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
|
||||
pub fn with_embedding_dim(embedding_dim: usize) -> Self {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent("RuVector-Discovery/1.0")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client"),
|
||||
embedder: SimpleEmbedder::new(embedding_dim),
|
||||
base_url: "https://export.arxiv.org/api/query".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Search papers by keywords
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query (keywords, title, author, etc.)
|
||||
/// * `max_results` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let vectors = client.search("quantum computing", 50).await?;
|
||||
/// ```
|
||||
pub async fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticVector>> {
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let url = format!(
|
||||
"{}?search_query=all:{}&start=0&max_results={}",
|
||||
self.base_url, encoded_query, max_results
|
||||
);
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Search papers by ArXiv category
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `category` - ArXiv category code (e.g., "cs.AI", "physics.ao-ph", "q-fin.ST")
|
||||
/// * `max_results` - Maximum number of results to return
|
||||
///
|
||||
/// # Supported Categories
|
||||
/// - `cs.AI` - Artificial Intelligence
|
||||
/// - `cs.LG` - Machine Learning
|
||||
/// - `cs.CL` - Computation and Language
|
||||
/// - `stat.ML` - Statistics - Machine Learning
|
||||
/// - `q-fin.*` - Quantitative Finance (ST, PM, TR, etc.)
|
||||
/// - `physics.ao-ph` - Atmospheric and Oceanic Physics
|
||||
/// - `econ.*` - Economics
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let ai_papers = client.search_category("cs.AI", 100).await?;
|
||||
/// let climate_papers = client.search_category("physics.ao-ph", 50).await?;
|
||||
/// ```
|
||||
pub async fn search_category(
|
||||
&self,
|
||||
category: &str,
|
||||
max_results: usize,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let url = format!(
|
||||
"{}?search_query=cat:{}&start=0&max_results={}&sortBy=submittedDate&sortOrder=descending",
|
||||
self.base_url, category, max_results
|
||||
);
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Get a single paper by ArXiv ID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `arxiv_id` - ArXiv paper ID (e.g., "2401.12345" or "arXiv:2401.12345")
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let paper = client.get_paper("2401.12345").await?;
|
||||
/// ```
|
||||
pub async fn get_paper(&self, arxiv_id: &str) -> Result<Option<SemanticVector>> {
|
||||
// Strip "arXiv:" prefix if present
|
||||
let id = arxiv_id.trim_start_matches("arXiv:");
|
||||
|
||||
let url = format!("{}?id_list={}", self.base_url, id);
|
||||
let mut results = self.fetch_and_parse(&url).await?;
|
||||
|
||||
Ok(results.pop())
|
||||
}
|
||||
|
||||
/// Search recent papers in a category within the last N days
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `category` - ArXiv category code
|
||||
/// * `days` - Number of days to look back (default: 7)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// // Get ML papers from the last 3 days
|
||||
/// let recent = client.search_recent("cs.LG", 3).await?;
|
||||
/// ```
|
||||
pub async fn search_recent(
|
||||
&self,
|
||||
category: &str,
|
||||
days: u64,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let cutoff_date = Utc::now() - chrono::Duration::days(days as i64);
|
||||
|
||||
let url = format!(
|
||||
"{}?search_query=cat:{}&start=0&max_results=100&sortBy=submittedDate&sortOrder=descending",
|
||||
self.base_url, category
|
||||
);
|
||||
|
||||
let all_results = self.fetch_and_parse(&url).await?;
|
||||
|
||||
// Filter by date
|
||||
Ok(all_results
|
||||
.into_iter()
|
||||
.filter(|v| v.timestamp >= cutoff_date)
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Search papers across multiple categories
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `categories` - List of ArXiv category codes
|
||||
/// * `max_results_per_category` - Maximum results per category
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let categories = vec!["cs.AI", "cs.LG", "stat.ML"];
|
||||
/// let papers = client.search_multiple_categories(&categories, 20).await?;
|
||||
/// ```
|
||||
pub async fn search_multiple_categories(
|
||||
&self,
|
||||
categories: &[&str],
|
||||
max_results_per_category: usize,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let mut all_vectors = Vec::new();
|
||||
|
||||
for category in categories {
|
||||
match self.search_category(category, max_results_per_category).await {
|
||||
Ok(mut vectors) => {
|
||||
all_vectors.append(&mut vectors);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch category {}: {}", category, e);
|
||||
}
|
||||
}
|
||||
// Rate limiting between categories
|
||||
sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;
|
||||
}
|
||||
|
||||
Ok(all_vectors)
|
||||
}
|
||||
|
||||
/// Fetch and parse ArXiv Atom feed
|
||||
async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
|
||||
// Rate limiting
|
||||
sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;
|
||||
|
||||
let response = self.fetch_with_retry(url).await?;
|
||||
let xml = response.text().await?;
|
||||
|
||||
// Parse XML feed
|
||||
let feed: ArxivFeed = quick_xml::de::from_str(&xml).map_err(|e| {
|
||||
FrameworkError::Ingestion(format!("Failed to parse ArXiv XML: {}", e))
|
||||
})?;
|
||||
|
||||
// Convert entries to SemanticVectors
|
||||
let mut vectors = Vec::new();
|
||||
for entry in feed.entries {
|
||||
if let Some(vector) = self.entry_to_vector(entry) {
|
||||
vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Convert ArXiv entry to SemanticVector
|
||||
fn entry_to_vector(&self, entry: ArxivEntry) -> Option<SemanticVector> {
|
||||
// Extract ArXiv ID from full URL
|
||||
let arxiv_id = entry
|
||||
.id
|
||||
.split('/')
|
||||
.last()
|
||||
.unwrap_or(&entry.id)
|
||||
.to_string();
|
||||
|
||||
// Clean up title and abstract
|
||||
let title = entry.title.trim().replace('\n', " ");
|
||||
let abstract_text = entry.summary.trim().replace('\n', " ");
|
||||
|
||||
// Parse publication date
|
||||
let timestamp = Self::parse_arxiv_date(&entry.published)?;
|
||||
|
||||
// Generate embedding from title + abstract
|
||||
let combined_text = format!("{} {}", title, abstract_text);
|
||||
let embedding = self.embedder.embed_text(&combined_text);
|
||||
|
||||
// Extract authors
|
||||
let authors = entry
|
||||
.authors
|
||||
.iter()
|
||||
.map(|a| a.name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
// Extract categories
|
||||
let categories = entry
|
||||
.categories
|
||||
.iter()
|
||||
.map(|c| c.term.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
// Find PDF URL
|
||||
let pdf_url = entry
|
||||
.links
|
||||
.iter()
|
||||
.find(|l| l.title.as_deref() == Some("pdf"))
|
||||
.map(|l| l.href.clone())
|
||||
.unwrap_or_else(|| format!("https://arxiv.org/pdf/{}.pdf", arxiv_id));
|
||||
|
||||
// Build metadata
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("arxiv_id".to_string(), arxiv_id.clone());
|
||||
metadata.insert("title".to_string(), title);
|
||||
metadata.insert("abstract".to_string(), abstract_text);
|
||||
metadata.insert("authors".to_string(), authors);
|
||||
metadata.insert("categories".to_string(), categories);
|
||||
metadata.insert("pdf_url".to_string(), pdf_url);
|
||||
metadata.insert("source".to_string(), "arxiv".to_string());
|
||||
|
||||
Some(SemanticVector {
|
||||
id: format!("arXiv:{}", arxiv_id),
|
||||
embedding,
|
||||
domain: Domain::Research,
|
||||
timestamp,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse ArXiv date format (ISO 8601)
|
||||
fn parse_arxiv_date(date_str: &str) -> Option<DateTime<Utc>> {
|
||||
// ArXiv uses ISO 8601 format: 2024-01-15T12:30:00Z
|
||||
DateTime::parse_from_rfc3339(date_str)
|
||||
.ok()
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.or_else(|| {
|
||||
// Fallback: try parsing without timezone
|
||||
NaiveDateTime::parse_from_str(date_str, "%Y-%m-%dT%H:%M:%S")
|
||||
.ok()
|
||||
.map(|ndt| DateTime::from_naive_utc_and_offset(ndt, Utc))
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
|
||||
{
|
||||
retries += 1;
|
||||
tracing::warn!("Rate limited by ArXiv, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ArxivClient {
    /// Equivalent to [`ArxivClient::new`] (default 384-dim embedder).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The client must target the public ArXiv query endpoint.
    #[test]
    fn test_arxiv_client_creation() {
        let client = ArxivClient::new();
        assert_eq!(client.base_url, "https://export.arxiv.org/api/query");
    }

    // A custom embedding dimension must be reflected in the embedder output.
    #[test]
    fn test_custom_embedding_dim() {
        let client = ArxivClient::with_embedding_dim(512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    // Both accepted timestamp shapes must parse.
    #[test]
    fn test_parse_arxiv_date() {
        // Standard ISO 8601
        let date1 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00Z");
        assert!(date1.is_some());

        // Without Z suffix
        let date2 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00");
        assert!(date2.is_some());
    }

    // End-to-end conversion of a synthetic entry: ID extraction from the
    // URL, metadata fields, and the joined author/category strings.
    #[test]
    fn test_entry_to_vector() {
        let client = ArxivClient::new();

        let entry = ArxivEntry {
            id: "http://arxiv.org/abs/2401.12345v1".to_string(),
            title: "Deep Learning for Climate Science".to_string(),
            summary: "We propose a novel approach...".to_string(),
            published: "2024-01-15T12:00:00Z".to_string(),
            updated: None,
            authors: vec![
                ArxivAuthor {
                    name: "John Doe".to_string(),
                },
                ArxivAuthor {
                    name: "Jane Smith".to_string(),
                },
            ],
            categories: vec![
                ArxivCategory {
                    term: "cs.LG".to_string(),
                },
                ArxivCategory {
                    term: "physics.ao-ph".to_string(),
                },
            ],
            links: vec![],
        };

        let vector = client.entry_to_vector(entry);
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "arXiv:2401.12345v1");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("arxiv_id").unwrap(), "2401.12345v1");
        assert_eq!(
            v.metadata.get("title").unwrap(),
            "Deep Learning for Climate Science"
        );
        assert_eq!(v.metadata.get("authors").unwrap(), "John Doe, Jane Smith");
        assert_eq!(v.metadata.get("categories").unwrap(), "cs.LG, physics.ao-ph");
    }

    // Live-API tests below are #[ignore]d so normal `cargo test` runs stay
    // offline; run with `cargo test -- --ignored` to exercise them.

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_integration() {
        let client = ArxivClient::new();
        let results = client.search("machine learning", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("arXiv:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("abstract"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_category_integration() {
        let client = ArxivClient::new();
        let results = client.search_category("cs.AI", 3).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 3);
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_get_paper_integration() {
        let client = ArxivClient::new();

        // Try to fetch a known paper (this is a real arXiv ID)
        let result = client.get_paper("2301.00001").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_recent_integration() {
        let client = ArxivClient::new();
        let results = client.search_recent("cs.LG", 7).await;
        assert!(results.is_ok());

        // Check that returned papers are within date range
        let cutoff = Utc::now() - chrono::Duration::days(7);
        for vector in results.unwrap() {
            assert!(vector.timestamp >= cutoff);
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_multiple_categories_integration() {
        let client = ArxivClient::new();
        let categories = vec!["cs.AI", "cs.LG"];
        let results = client.search_multiple_categories(&categories, 2).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 4); // 2 categories * 2 results each
    }
}
|
||||
1138
vendor/ruvector/examples/data/framework/src/bin/discover.rs
vendored
Normal file
1138
vendor/ruvector/examples/data/framework/src/bin/discover.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
154
vendor/ruvector/examples/data/framework/src/bin/mcp_discovery.rs
vendored
Normal file
154
vendor/ruvector/examples/data/framework/src/bin/mcp_discovery.rs
vendored
Normal file
@@ -0,0 +1,154 @@
|
||||
//! MCP Discovery Server Binary
|
||||
//!
|
||||
//! Runs the RuVector MCP server for data discovery across 22+ data sources.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ## STDIO Mode (default)
|
||||
//! ```bash
|
||||
//! cargo run --bin mcp_discovery
|
||||
//! ```
|
||||
//!
|
||||
//! ## SSE Mode (HTTP streaming)
|
||||
//! ```bash
|
||||
//! cargo run --bin mcp_discovery -- --sse --port 3000
|
||||
//! ```
|
||||
//!
|
||||
//! ## With custom configuration
|
||||
//! ```bash
|
||||
//! cargo run --bin mcp_discovery -- --config custom_config.json
|
||||
//! ```
|
||||
|
||||
use std::process;
|
||||
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
use ruvector_data_framework::mcp_server::{
|
||||
McpDiscoveryServer, McpServerConfig, McpTransport,
|
||||
};
|
||||
use ruvector_data_framework::ruvector_native::NativeEngineConfig;
|
||||
|
||||
/// Command-line arguments for the MCP discovery server binary.
///
/// Individual engine flags are ignored when `--config` is given
/// (the JSON file takes precedence; see `main`).
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Use SSE transport (HTTP streaming) instead of STDIO
    #[arg(long, default_value_t = false)]
    sse: bool,

    /// Port for SSE endpoint (only used with --sse)
    #[arg(long, default_value_t = 3000)]
    port: u16,

    /// Endpoint address for SSE (only used with --sse)
    #[arg(long, default_value = "127.0.0.1")]
    endpoint: String,

    /// Configuration file path
    #[arg(short, long)]
    config: Option<String>,

    /// Minimum edge weight threshold
    #[arg(long, default_value_t = 0.5)]
    min_edge_weight: f64,

    /// Vector similarity threshold
    #[arg(long, default_value_t = 0.7)]
    similarity_threshold: f64,

    /// Enable cross-domain discovery
    #[arg(long, default_value_t = true)]
    cross_domain: bool,

    /// Temporal window size in seconds
    #[arg(long, default_value_t = 3600)]
    window_seconds: i64,

    /// HNSW M parameter (connections per layer)
    #[arg(long, default_value_t = 16)]
    hnsw_m: usize,

    /// HNSW ef_construction parameter
    #[arg(long, default_value_t = 200)]
    hnsw_ef_construction: usize,

    /// Vector dimension
    #[arg(long, default_value_t = 384)]
    dimension: usize,

    /// Enable verbose logging
    #[arg(short, long, default_value_t = false)]
    verbose: bool,
}
|
||||
|
||||
/// Entry point: parse CLI flags, set up logging, build the engine config
/// (from file or flags), choose the transport, and run the MCP server.
#[tokio::main]
async fn main() {
    let args = Args::parse();

    // Initialize logging: --verbose forces "debug"; otherwise honor the
    // environment's default filter (RUST_LOG), falling back to "info".
    let env_filter = if args.verbose {
        EnvFilter::new("debug")
    } else {
        EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| EnvFilter::new("info"))
    };

    fmt()
        .with_env_filter(env_filter)
        .with_target(false)
        .with_thread_ids(false)
        .with_file(false)
        .init();

    // Load configuration: a JSON config file takes precedence over the
    // individual CLI flags; a bad file is a fatal startup error.
    let engine_config = if let Some(config_path) = args.config {
        match load_config_from_file(&config_path) {
            Ok(config) => config,
            Err(e) => {
                eprintln!("Failed to load config from {}: {}", config_path, e);
                process::exit(1);
            }
        }
    } else {
        // Fields without a CLI flag use hard-coded defaults.
        NativeEngineConfig {
            min_edge_weight: args.min_edge_weight,
            similarity_threshold: args.similarity_threshold,
            mincut_sensitivity: 0.1,
            cross_domain: args.cross_domain,
            window_seconds: args.window_seconds,
            hnsw_m: args.hnsw_m,
            hnsw_ef_construction: args.hnsw_ef_construction,
            hnsw_ef_search: 50,
            dimension: args.dimension,
            batch_size: 1000,
            checkpoint_interval: 10_000,
            parallel_workers: num_cpus::get(),
        }
    };

    // Create transport. Status lines go to stderr — presumably so that
    // STDIO-mode MCP traffic on stdout stays clean; verify against the
    // server's stdio implementation.
    let transport = if args.sse {
        eprintln!("Starting MCP server in SSE mode on {}:{}", args.endpoint, args.port);
        McpTransport::Sse {
            endpoint: args.endpoint,
            port: args.port,
        }
    } else {
        eprintln!("Starting MCP server in STDIO mode");
        McpTransport::Stdio
    };

    // Create and run server; a runtime error exits non-zero.
    let mut server = McpDiscoveryServer::new(transport, engine_config);

    if let Err(e) = server.run().await {
        eprintln!("Server error: {}", e);
        process::exit(1);
    }
}
|
||||
|
||||
fn load_config_from_file(path: &str) -> Result<NativeEngineConfig, Box<dyn std::error::Error>> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let config: NativeEngineConfig = serde_json::from_str(&content)?;
|
||||
Ok(config)
|
||||
}
|
||||
930
vendor/ruvector/examples/data/framework/src/biorxiv_client.rs
vendored
Normal file
930
vendor/ruvector/examples/data/framework/src/biorxiv_client.rs
vendored
Normal file
@@ -0,0 +1,930 @@
|
||||
//! bioRxiv and medRxiv Preprint API Integration
|
||||
//!
|
||||
//! This module provides async clients for fetching preprints from bioRxiv.org and medRxiv.org,
|
||||
//! converting responses to SemanticVector format for RuVector discovery.
|
||||
//!
|
||||
//! # bioRxiv/medRxiv API Details
|
||||
//! - Base URL: https://api.biorxiv.org/details/[server]/[interval]/[cursor]
|
||||
//! - Free access, no authentication required
|
||||
//! - Returns JSON with preprint metadata
|
||||
//! - Rate limit: ~1 request per second (enforced by client)
|
||||
//!
|
||||
//! # Example
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::biorxiv_client::{BiorxivClient, MedrxivClient};
|
||||
//!
|
||||
//! // Life sciences preprints
|
||||
//! let biorxiv = BiorxivClient::new();
|
||||
//! let recent = biorxiv.search_recent(7, 50).await?;
|
||||
//! let category_papers = biorxiv.search_by_category("neuroscience", 100).await?;
|
||||
//!
|
||||
//! // Medical preprints
|
||||
//! let medrxiv = MedrxivClient::new();
|
||||
//! let covid_papers = medrxiv.search_covid(100).await?;
|
||||
//! let clinical = medrxiv.search_clinical(50).await?;
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Rate limiting configuration
|
||||
const BIORXIV_RATE_LIMIT_MS: u64 = 1000; // 1 second between requests (conservative)
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 2000;
|
||||
const DEFAULT_EMBEDDING_DIM: usize = 384;
|
||||
const DEFAULT_PAGE_SIZE: usize = 100;
|
||||
|
||||
// ============================================================================
|
||||
// bioRxiv/medRxiv API Response Structures
|
||||
// ============================================================================
|
||||
|
||||
/// API response from bioRxiv/medRxiv
///
/// Top-level JSON object returned by the details endpoint.
#[derive(Debug, Deserialize)]
struct BiorxivApiResponse {
    /// Total number of results
    #[serde(default)]
    count: Option<i64>,

    /// Cursor for pagination (total number of records seen)
    #[serde(default)]
    cursor: Option<i64>,

    /// Array of preprint records
    #[serde(default)]
    collection: Vec<PreprintRecord>,
}
|
||||
|
||||
/// Individual preprint record
///
/// One entry of the `collection` array in a bioRxiv/medRxiv response.
#[derive(Debug, Deserialize)]
struct PreprintRecord {
    /// DOI identifier
    doi: String,

    /// Paper title
    title: String,

    /// Authors (semicolon-separated)
    authors: String,

    /// Author corresponding information
    #[serde(default)]
    author_corresponding: Option<String>,

    /// Author corresponding institution
    #[serde(default)]
    author_corresponding_institution: Option<String>,

    /// Preprint publication date (YYYY-MM-DD)
    date: String,

    /// Subject category
    category: String,

    /// Abstract text (renamed because `abstract` is a reserved word in Rust)
    #[serde(rename = "abstract")]
    abstract_text: String,

    /// Journal publication status (if accepted)
    #[serde(default)]
    published: Option<String>,

    /// Server (biorxiv or medrxiv)
    #[serde(default)]
    server: Option<String>,

    /// Version number
    #[serde(default)]
    version: Option<String>,

    /// Type (e.g., "new results"); renamed because `type` is a reserved word
    #[serde(rename = "type", default)]
    preprint_type: Option<String>,
}
|
||||
|
||||
// ============================================================================
|
||||
// bioRxiv Client (Life Sciences Preprints)
|
||||
// ============================================================================
|
||||
|
||||
/// Client for bioRxiv.org preprint API
///
/// Provides methods to search for life sciences preprints, filter by category,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Categories
/// - neuroscience
/// - genomics
/// - bioinformatics
/// - cancer-biology
/// - immunology
/// - microbiology
/// - molecular-biology
/// - cell-biology
/// - biochemistry
/// - evolutionary-biology
/// - and many more...
///
/// # Rate Limiting
/// The client automatically enforces a rate limit of ~1 request per second.
/// Includes retry logic for transient failures.
pub struct BiorxivClient {
    // Shared reqwest HTTP client (30 s timeout, custom user agent).
    client: Client,
    // Text embedder used to vectorize preprint text.
    embedder: SimpleEmbedder,
    // API root, "https://api.biorxiv.org".
    base_url: String,
}
|
||||
|
||||
impl BiorxivClient {
|
||||
/// Create a new bioRxiv API client
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let client = BiorxivClient::new();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
|
||||
}
|
||||
|
||||
/// Create a new bioRxiv API client with custom embedding dimension
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
|
||||
pub fn with_embedding_dim(embedding_dim: usize) -> Self {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent("RuVector-Discovery/1.0")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client"),
|
||||
embedder: SimpleEmbedder::new(embedding_dim),
|
||||
base_url: "https://api.biorxiv.org".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recent preprints from the last N days
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `days` - Number of days to look back (e.g., 7 for last week)
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// // Get preprints from the last 7 days
|
||||
/// let recent = client.search_recent(7, 100).await?;
|
||||
/// ```
|
||||
pub async fn search_recent(&self, days: u64, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let end_date = Utc::now().date_naive();
|
||||
let start_date = end_date - chrono::Duration::days(days as i64);
|
||||
|
||||
self.search_by_date_range(start_date, end_date, Some(limit)).await
|
||||
}
|
||||
|
||||
/// Search preprints by date range
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `start_date` - Start date (inclusive)
|
||||
/// * `end_date` - End date (inclusive)
|
||||
/// * `limit` - Optional maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// use chrono::NaiveDate;
|
||||
///
|
||||
/// let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
|
||||
/// let end = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
|
||||
/// let papers = client.search_by_date_range(start, end, Some(200)).await?;
|
||||
/// ```
|
||||
pub async fn search_by_date_range(
|
||||
&self,
|
||||
start_date: NaiveDate,
|
||||
end_date: NaiveDate,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let interval = format!("{}/{}", start_date, end_date);
|
||||
self.fetch_with_pagination("biorxiv", &interval, limit).await
|
||||
}
|
||||
|
||||
/// Search preprints by subject category
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `category` - Subject category (e.g., "neuroscience", "genomics")
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Categories
|
||||
/// - neuroscience
|
||||
/// - genomics
|
||||
/// - bioinformatics
|
||||
/// - cancer-biology
|
||||
/// - immunology
|
||||
/// - microbiology
|
||||
/// - molecular-biology
|
||||
/// - cell-biology
|
||||
/// - biochemistry
|
||||
/// - evolutionary-biology
|
||||
/// - ecology
|
||||
/// - genetics
|
||||
/// - developmental-biology
|
||||
/// - synthetic-biology
|
||||
/// - systems-biology
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let neuroscience_papers = client.search_by_category("neuroscience", 100).await?;
|
||||
/// ```
|
||||
pub async fn search_by_category(
|
||||
&self,
|
||||
category: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
// Get recent papers (last 365 days) and filter by category
|
||||
let end_date = Utc::now().date_naive();
|
||||
let start_date = end_date - chrono::Duration::days(365);
|
||||
|
||||
let all_papers = self.search_by_date_range(start_date, end_date, Some(limit * 2)).await?;
|
||||
|
||||
// Filter by category
|
||||
Ok(all_papers
|
||||
.into_iter()
|
||||
.filter(|v| {
|
||||
v.metadata
|
||||
.get("category")
|
||||
.map(|cat| cat.to_lowercase().contains(&category.to_lowercase()))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.take(limit)
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Fetch preprints with pagination support
|
||||
async fn fetch_with_pagination(
|
||||
&self,
|
||||
server: &str,
|
||||
interval: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let mut all_vectors = Vec::new();
|
||||
let mut cursor = 0;
|
||||
let limit = limit.unwrap_or(usize::MAX);
|
||||
|
||||
loop {
|
||||
if all_vectors.len() >= limit {
|
||||
break;
|
||||
}
|
||||
|
||||
let url = format!("{}/details/{}/{}/{}", self.base_url, server, interval, cursor);
|
||||
|
||||
// Rate limiting
|
||||
sleep(Duration::from_millis(BIORXIV_RATE_LIMIT_MS)).await;
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let api_response: BiorxivApiResponse = response.json().await?;
|
||||
|
||||
if api_response.collection.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Convert records to vectors
|
||||
for record in api_response.collection {
|
||||
if all_vectors.len() >= limit {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(vector) = self.record_to_vector(record, server) {
|
||||
all_vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
// Update cursor for next page
|
||||
if let Some(new_cursor) = api_response.cursor {
|
||||
if new_cursor as usize <= cursor {
|
||||
// No more pages
|
||||
break;
|
||||
}
|
||||
cursor = new_cursor as usize;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
// Safety check: don't paginate indefinitely
|
||||
if cursor > 10000 {
|
||||
tracing::warn!("Pagination cursor exceeded 10000, stopping");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_vectors)
|
||||
}
|
||||
|
||||
/// Convert preprint record to SemanticVector
|
||||
fn record_to_vector(&self, record: PreprintRecord, server: &str) -> Option<SemanticVector> {
|
||||
// Clean up title and abstract
|
||||
let title = record.title.trim().replace('\n', " ");
|
||||
let abstract_text = record.abstract_text.trim().replace('\n', " ");
|
||||
|
||||
// Parse publication date
|
||||
let timestamp = NaiveDate::parse_from_str(&record.date, "%Y-%m-%d")
|
||||
.ok()
|
||||
.and_then(|d| d.and_hms_opt(0, 0, 0))
|
||||
.map(|dt| dt.and_utc())
|
||||
.unwrap_or_else(Utc::now);
|
||||
|
||||
// Generate embedding from title + abstract
|
||||
let combined_text = format!("{} {}", title, abstract_text);
|
||||
let embedding = self.embedder.embed_text(&combined_text);
|
||||
|
||||
// Determine publication status
|
||||
let published_status = record.published.unwrap_or_else(|| "preprint".to_string());
|
||||
|
||||
// Build metadata
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("doi".to_string(), record.doi.clone());
|
||||
metadata.insert("title".to_string(), title);
|
||||
metadata.insert("abstract".to_string(), abstract_text);
|
||||
metadata.insert("authors".to_string(), record.authors);
|
||||
metadata.insert("category".to_string(), record.category);
|
||||
metadata.insert("server".to_string(), server.to_string());
|
||||
metadata.insert("published_status".to_string(), published_status);
|
||||
|
||||
if let Some(corr) = record.author_corresponding {
|
||||
metadata.insert("corresponding_author".to_string(), corr);
|
||||
}
|
||||
if let Some(inst) = record.author_corresponding_institution {
|
||||
metadata.insert("institution".to_string(), inst);
|
||||
}
|
||||
if let Some(version) = record.version {
|
||||
metadata.insert("version".to_string(), version);
|
||||
}
|
||||
if let Some(ptype) = record.preprint_type {
|
||||
metadata.insert("type".to_string(), ptype);
|
||||
}
|
||||
|
||||
metadata.insert("source".to_string(), "biorxiv".to_string());
|
||||
|
||||
// bioRxiv papers are research domain
|
||||
Some(SemanticVector {
|
||||
id: format!("doi:{}", record.doi),
|
||||
embedding,
|
||||
domain: Domain::Research,
|
||||
timestamp,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
|
||||
retries += 1;
|
||||
tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BiorxivClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// medRxiv Client (Medical Preprints)
|
||||
// ============================================================================
|
||||
|
||||
/// Client for medRxiv.org preprint API
///
/// Provides methods to search for medical and health sciences preprints,
/// filter by specialty, and convert results to SemanticVector format.
/// medRxiv shares its API host with bioRxiv; only the server segment of
/// the URL path differs.
///
/// # Categories
/// - Cardiovascular Medicine
/// - Infectious Diseases
/// - Oncology
/// - Public Health
/// - Epidemiology
/// - Psychiatry
/// - and many more...
///
/// # Rate Limiting
/// The client automatically enforces a rate limit of ~1 request per second.
/// Includes retry logic for transient failures.
pub struct MedrxivClient {
    /// HTTP client used for all API requests (30s timeout, custom user agent).
    client: Client,
    /// Deterministic text embedder; converts "title abstract" text into vectors.
    embedder: SimpleEmbedder,
    /// API base URL — medRxiv is served from the bioRxiv API host.
    base_url: String,
}
|
||||
|
||||
impl MedrxivClient {
|
||||
/// Create a new medRxiv API client
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let client = MedrxivClient::new();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
|
||||
}
|
||||
|
||||
/// Create a new medRxiv API client with custom embedding dimension
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
|
||||
pub fn with_embedding_dim(embedding_dim: usize) -> Self {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent("RuVector-Discovery/1.0")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client"),
|
||||
embedder: SimpleEmbedder::new(embedding_dim),
|
||||
base_url: "https://api.biorxiv.org".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recent preprints from the last N days
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `days` - Number of days to look back (e.g., 7 for last week)
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// // Get medical preprints from the last 7 days
|
||||
/// let recent = client.search_recent(7, 100).await?;
|
||||
/// ```
|
||||
pub async fn search_recent(&self, days: u64, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let end_date = Utc::now().date_naive();
|
||||
let start_date = end_date - chrono::Duration::days(days as i64);
|
||||
|
||||
self.search_by_date_range(start_date, end_date, Some(limit)).await
|
||||
}
|
||||
|
||||
/// Search preprints by date range
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `start_date` - Start date (inclusive)
|
||||
/// * `end_date` - End date (inclusive)
|
||||
/// * `limit` - Optional maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// use chrono::NaiveDate;
|
||||
///
|
||||
/// let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
|
||||
/// let end = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
|
||||
/// let papers = client.search_by_date_range(start, end, Some(200)).await?;
|
||||
/// ```
|
||||
pub async fn search_by_date_range(
|
||||
&self,
|
||||
start_date: NaiveDate,
|
||||
end_date: NaiveDate,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let interval = format!("{}/{}", start_date, end_date);
|
||||
self.fetch_with_pagination("medrxiv", &interval, limit).await
|
||||
}
|
||||
|
||||
/// Search COVID-19 related preprints
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let covid_papers = client.search_covid(100).await?;
|
||||
/// ```
|
||||
pub async fn search_covid(&self, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
// Search for COVID-19 related papers from 2020 onwards
|
||||
let end_date = Utc::now().date_naive();
|
||||
let start_date = NaiveDate::from_ymd_opt(2020, 1, 1).expect("Valid date");
|
||||
|
||||
let all_papers = self.search_by_date_range(start_date, end_date, Some(limit * 2)).await?;
|
||||
|
||||
// Filter by COVID-19 related keywords
|
||||
Ok(all_papers
|
||||
.into_iter()
|
||||
.filter(|v| {
|
||||
let title = v.metadata.get("title").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
let abstract_text = v.metadata.get("abstract").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
let category = v.metadata.get("category").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
|
||||
let keywords = ["covid", "sars-cov-2", "coronavirus", "pandemic"];
|
||||
keywords.iter().any(|kw| {
|
||||
title.contains(kw) || abstract_text.contains(kw) || category.contains(kw)
|
||||
})
|
||||
})
|
||||
.take(limit)
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Search clinical research preprints
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let clinical_papers = client.search_clinical(50).await?;
|
||||
/// ```
|
||||
pub async fn search_clinical(&self, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
// Get recent papers and filter for clinical research
|
||||
let end_date = Utc::now().date_naive();
|
||||
let start_date = end_date - chrono::Duration::days(365);
|
||||
|
||||
let all_papers = self.search_by_date_range(start_date, end_date, Some(limit * 2)).await?;
|
||||
|
||||
// Filter by clinical keywords
|
||||
Ok(all_papers
|
||||
.into_iter()
|
||||
.filter(|v| {
|
||||
let title = v.metadata.get("title").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
let abstract_text = v.metadata.get("abstract").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
let category = v.metadata.get("category").map(|s| s.to_lowercase()).unwrap_or_default();
|
||||
|
||||
let keywords = ["clinical", "trial", "patient", "treatment", "therapy", "diagnosis"];
|
||||
keywords.iter().any(|kw| {
|
||||
title.contains(kw) || abstract_text.contains(kw) || category.contains(kw)
|
||||
})
|
||||
})
|
||||
.take(limit)
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Fetch preprints with pagination support
|
||||
async fn fetch_with_pagination(
|
||||
&self,
|
||||
server: &str,
|
||||
interval: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let mut all_vectors = Vec::new();
|
||||
let mut cursor = 0;
|
||||
let limit = limit.unwrap_or(usize::MAX);
|
||||
|
||||
loop {
|
||||
if all_vectors.len() >= limit {
|
||||
break;
|
||||
}
|
||||
|
||||
let url = format!("{}/details/{}/{}/{}", self.base_url, server, interval, cursor);
|
||||
|
||||
// Rate limiting
|
||||
sleep(Duration::from_millis(BIORXIV_RATE_LIMIT_MS)).await;
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let api_response: BiorxivApiResponse = response.json().await?;
|
||||
|
||||
if api_response.collection.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Convert records to vectors
|
||||
for record in api_response.collection {
|
||||
if all_vectors.len() >= limit {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(vector) = self.record_to_vector(record, server) {
|
||||
all_vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
// Update cursor for next page
|
||||
if let Some(new_cursor) = api_response.cursor {
|
||||
if new_cursor as usize <= cursor {
|
||||
// No more pages
|
||||
break;
|
||||
}
|
||||
cursor = new_cursor as usize;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
// Safety check: don't paginate indefinitely
|
||||
if cursor > 10000 {
|
||||
tracing::warn!("Pagination cursor exceeded 10000, stopping");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_vectors)
|
||||
}
|
||||
|
||||
/// Convert preprint record to SemanticVector
|
||||
fn record_to_vector(&self, record: PreprintRecord, server: &str) -> Option<SemanticVector> {
|
||||
// Clean up title and abstract
|
||||
let title = record.title.trim().replace('\n', " ");
|
||||
let abstract_text = record.abstract_text.trim().replace('\n', " ");
|
||||
|
||||
// Parse publication date
|
||||
let timestamp = NaiveDate::parse_from_str(&record.date, "%Y-%m-%d")
|
||||
.ok()
|
||||
.and_then(|d| d.and_hms_opt(0, 0, 0))
|
||||
.map(|dt| dt.and_utc())
|
||||
.unwrap_or_else(Utc::now);
|
||||
|
||||
// Generate embedding from title + abstract
|
||||
let combined_text = format!("{} {}", title, abstract_text);
|
||||
let embedding = self.embedder.embed_text(&combined_text);
|
||||
|
||||
// Determine publication status
|
||||
let published_status = record.published.unwrap_or_else(|| "preprint".to_string());
|
||||
|
||||
// Build metadata
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("doi".to_string(), record.doi.clone());
|
||||
metadata.insert("title".to_string(), title);
|
||||
metadata.insert("abstract".to_string(), abstract_text);
|
||||
metadata.insert("authors".to_string(), record.authors);
|
||||
metadata.insert("category".to_string(), record.category);
|
||||
metadata.insert("server".to_string(), server.to_string());
|
||||
metadata.insert("published_status".to_string(), published_status);
|
||||
|
||||
if let Some(corr) = record.author_corresponding {
|
||||
metadata.insert("corresponding_author".to_string(), corr);
|
||||
}
|
||||
if let Some(inst) = record.author_corresponding_institution {
|
||||
metadata.insert("institution".to_string(), inst);
|
||||
}
|
||||
if let Some(version) = record.version {
|
||||
metadata.insert("version".to_string(), version);
|
||||
}
|
||||
if let Some(ptype) = record.preprint_type {
|
||||
metadata.insert("type".to_string(), ptype);
|
||||
}
|
||||
|
||||
metadata.insert("source".to_string(), "medrxiv".to_string());
|
||||
|
||||
// medRxiv papers are medical domain
|
||||
Some(SemanticVector {
|
||||
id: format!("doi:{}", record.doi),
|
||||
embedding,
|
||||
domain: Domain::Medical,
|
||||
timestamp,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
|
||||
retries += 1;
|
||||
tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MedrxivClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Both clients should point at the shared bioRxiv API host.
    #[test]
    fn test_biorxiv_client_creation() {
        let client = BiorxivClient::new();
        assert_eq!(client.base_url, "https://api.biorxiv.org");
    }

    #[test]
    fn test_medrxiv_client_creation() {
        let client = MedrxivClient::new();
        assert_eq!(client.base_url, "https://api.biorxiv.org");
    }

    // with_embedding_dim must propagate the dimension to the embedder.
    #[test]
    fn test_custom_embedding_dim() {
        let client = BiorxivClient::with_embedding_dim(512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    // Pins the public metadata contract of record_to_vector for bioRxiv:
    // id format ("doi:<doi>"), Research domain, and the metadata keys.
    #[test]
    fn test_record_to_vector_biorxiv() {
        let client = BiorxivClient::new();

        let record = PreprintRecord {
            doi: "10.1101/2024.01.01.123456".to_string(),
            title: "Deep Learning for Neuroscience".to_string(),
            authors: "John Doe; Jane Smith".to_string(),
            author_corresponding: Some("John Doe".to_string()),
            author_corresponding_institution: Some("MIT".to_string()),
            date: "2024-01-15".to_string(),
            category: "Neuroscience".to_string(),
            abstract_text: "We propose a novel approach for analyzing neural data...".to_string(),
            published: None,
            server: Some("biorxiv".to_string()),
            version: Some("1".to_string()),
            preprint_type: Some("new results".to_string()),
        };

        let vector = client.record_to_vector(record, "biorxiv");
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "doi:10.1101/2024.01.01.123456");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("doi").unwrap(), "10.1101/2024.01.01.123456");
        assert_eq!(v.metadata.get("title").unwrap(), "Deep Learning for Neuroscience");
        assert_eq!(v.metadata.get("authors").unwrap(), "John Doe; Jane Smith");
        assert_eq!(v.metadata.get("category").unwrap(), "Neuroscience");
        assert_eq!(v.metadata.get("server").unwrap(), "biorxiv");
        // `published: None` must map to the "preprint" sentinel status.
        assert_eq!(v.metadata.get("published_status").unwrap(), "preprint");
    }

    // Same contract for medRxiv: Medical domain, and a Some(published)
    // journal name surfaces as the published_status value.
    #[test]
    fn test_record_to_vector_medrxiv() {
        let client = MedrxivClient::new();

        let record = PreprintRecord {
            doi: "10.1101/2024.01.01.654321".to_string(),
            title: "COVID-19 Vaccine Efficacy Study".to_string(),
            authors: "Alice Johnson; Bob Williams".to_string(),
            author_corresponding: Some("Alice Johnson".to_string()),
            author_corresponding_institution: Some("Harvard Medical School".to_string()),
            date: "2024-03-20".to_string(),
            category: "Infectious Diseases".to_string(),
            abstract_text: "This study evaluates the efficacy of mRNA vaccines...".to_string(),
            published: Some("Nature Medicine".to_string()),
            server: Some("medrxiv".to_string()),
            version: Some("2".to_string()),
            preprint_type: Some("new results".to_string()),
        };

        let vector = client.record_to_vector(record, "medrxiv");
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "doi:10.1101/2024.01.01.654321");
        assert_eq!(v.domain, Domain::Medical);
        assert_eq!(v.metadata.get("doi").unwrap(), "10.1101/2024.01.01.654321");
        assert_eq!(v.metadata.get("title").unwrap(), "COVID-19 Vaccine Efficacy Study");
        assert_eq!(v.metadata.get("category").unwrap(), "Infectious Diseases");
        assert_eq!(v.metadata.get("server").unwrap(), "medrxiv");
        assert_eq!(v.metadata.get("published_status").unwrap(), "Nature Medicine");
    }

    // A "%Y-%m-%d" date string must become midnight UTC of that day.
    #[test]
    fn test_date_parsing() {
        let client = BiorxivClient::new();

        let record = PreprintRecord {
            doi: "10.1101/test".to_string(),
            title: "Test".to_string(),
            authors: "Author".to_string(),
            author_corresponding: None,
            author_corresponding_institution: None,
            date: "2024-01-15".to_string(),
            category: "Test".to_string(),
            abstract_text: "Abstract".to_string(),
            published: None,
            server: None,
            version: None,
            preprint_type: None,
        };

        let vector = client.record_to_vector(record, "biorxiv").unwrap();

        // Check that date was parsed correctly
        let expected_date = NaiveDate::from_ymd_opt(2024, 1, 15)
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .and_utc();

        assert_eq!(vector.timestamp, expected_date);
    }

    // Live-network smoke test; run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting bioRxiv API in tests
    async fn test_search_recent_integration() {
        let client = BiorxivClient::new();
        let results = client.search_recent(7, 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("abstract"));
        }
    }

    // Live-network smoke test for the medRxiv variant.
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting medRxiv API in tests
    async fn test_medrxiv_search_recent_integration() {
        let client = MedrxivClient::new();
        let results = client.search_recent(7, 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Medical);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("server"));
        }
    }

    // Live-network test that the COVID keyword filter actually filters.
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting API
    async fn test_search_covid_integration() {
        let client = MedrxivClient::new();
        let results = client.search_covid(10).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();

        // Verify that results contain COVID-related keywords
        // (a subset of the filter's keyword list).
        for v in &vectors {
            let title = v.metadata.get("title").unwrap().to_lowercase();
            let abstract_text = v.metadata.get("abstract").unwrap().to_lowercase();

            let has_covid_keyword = title.contains("covid")
                || title.contains("sars-cov-2")
                || abstract_text.contains("covid")
                || abstract_text.contains("sars-cov-2");

            assert!(has_covid_keyword, "Expected COVID-related keywords in results");
        }
    }

    // Live-network test that category filtering is substring-based and
    // case-insensitive.
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting API
    async fn test_search_by_category_integration() {
        let client = BiorxivClient::new();
        let results = client.search_by_category("neuroscience", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        // Verify category filtering
        for v in &vectors {
            let category = v.metadata.get("category").unwrap().to_lowercase();
            assert!(category.contains("neuroscience"));
        }
    }
}
|
||||
658
vendor/ruvector/examples/data/framework/src/coherence.rs
vendored
Normal file
658
vendor/ruvector/examples/data/framework/src/coherence.rs
vendored
Normal file
@@ -0,0 +1,658 @@
|
||||
//! Coherence signal computation using dynamic minimum cut algorithms
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::hnsw::{HnswConfig, HnswIndex, DistanceMetric};
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::utils::cosine_similarity;
|
||||
use crate::{DataRecord, FrameworkError, Result, Relationship, TemporalWindow};
|
||||
|
||||
/// Configuration for coherence engine
///
/// Controls graph-construction thresholds, temporal windowing, and the
/// exact-vs-approximate min-cut trade-off. See the `Default` impl for the
/// stock values (1-week window sliding by 1 day, approximate cuts).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceConfig {
    /// Minimum edge weight threshold; lighter edges are dropped at insert time
    pub min_edge_weight: f64,

    /// Window size for temporal analysis (seconds)
    pub window_size_secs: i64,

    /// Window slide step (seconds)
    pub window_step_secs: i64,

    /// Use approximate min-cut for speed
    pub approximate: bool,

    /// Approximation ratio (if approximate = true)
    pub epsilon: f64,

    /// Enable parallel computation
    pub parallel: bool,

    /// Track boundary evolution
    pub track_boundaries: bool,

    /// Similarity threshold for auto-connecting embeddings (0.0-1.0)
    pub similarity_threshold: f64,

    /// Use embeddings to create edges when relationships are empty
    pub use_embeddings: bool,

    /// Number of neighbors to search for each vector when using HNSW
    pub hnsw_k_neighbors: usize,

    /// Minimum records to trigger HNSW indexing (below this, use brute force)
    pub hnsw_min_records: usize,
}
|
||||
|
||||
impl Default for CoherenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_edge_weight: 0.01,
|
||||
window_size_secs: 86400 * 7, // 1 week
|
||||
window_step_secs: 86400, // 1 day
|
||||
approximate: true,
|
||||
epsilon: 0.1,
|
||||
parallel: true,
|
||||
track_boundaries: true,
|
||||
similarity_threshold: 0.5,
|
||||
use_embeddings: true,
|
||||
hnsw_k_neighbors: 50,
|
||||
hnsw_min_records: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A coherence signal computed from graph structure
///
/// One signal summarizes the min-cut analysis of the graph restricted to a
/// single temporal window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceSignal {
    /// Signal identifier
    pub id: String,

    /// Temporal window this signal covers
    pub window: TemporalWindow,

    /// Minimum cut value (lower = less coherent)
    pub min_cut_value: f64,

    /// Number of nodes in graph
    pub node_count: usize,

    /// Number of edges in graph
    pub edge_count: usize,

    /// Partition sizes (if computed)
    pub partition_sizes: Option<(usize, usize)>,

    /// Is this an exact or approximate result
    pub is_exact: bool,

    /// Nodes in the cut (boundary nodes)
    pub cut_nodes: Vec<String>,

    /// Change from previous window (if available)
    pub delta: Option<f64>,
}
|
||||
|
||||
/// A coherence boundary event
///
/// Emitted when the coherence structure changes between windows; see
/// [`CoherenceEventType`] for the kinds of change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceEvent {
    /// Event type
    pub event_type: CoherenceEventType,

    /// Timestamp of event
    pub timestamp: DateTime<Utc>,

    /// Related nodes
    pub nodes: Vec<String>,

    /// Magnitude of change
    pub magnitude: f64,

    /// Additional context (free-form JSON values keyed by name)
    pub context: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
/// Types of coherence events
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CoherenceEventType {
    /// Coherence increased (min-cut grew)
    Strengthened,

    /// Coherence decreased (min-cut shrunk)
    Weakened,

    /// New partition emerged (split)
    Split,

    /// Partitions merged
    Merged,

    /// Threshold crossed
    ThresholdCrossed,

    /// Anomalous pattern detected
    Anomaly,
}
|
||||
|
||||
/// A tracked coherence boundary
///
/// A persistent bipartition of the graph whose cut value is tracked over
/// time (see `history`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceBoundary {
    /// Boundary identifier
    pub id: String,

    /// Nodes on one side
    pub side_a: Vec<String>,

    /// Nodes on other side
    pub side_b: Vec<String>,

    /// Current cut value at boundary
    pub cut_value: f64,

    /// Historical cut values (timestamped samples, oldest first — TODO confirm ordering with producer)
    pub history: Vec<(DateTime<Utc>, f64)>,

    /// First observed
    pub first_seen: DateTime<Utc>,

    /// Last updated
    pub last_updated: DateTime<Utc>,

    /// Is boundary stable or shifting
    pub stable: bool,
}
|
||||
|
||||
/// Coherence engine for computing signals from graph structure
///
/// Holds an in-memory weighted graph (string ids interned to u64 handles)
/// plus the signals and boundaries computed from it.
pub struct CoherenceEngine {
    config: CoherenceConfig,

    // In-memory graph representation
    /// Forward map: string node id -> numeric handle.
    nodes: HashMap<String, u64>,
    /// Reverse map: numeric handle -> string node id.
    node_ids: HashMap<u64, String>,
    /// Edge list as (source handle, target handle, weight) triples.
    edges: Vec<(u64, u64, f64)>,
    /// Next handle to assign; monotonically increasing.
    next_id: u64,

    // Computed signals
    signals: Vec<CoherenceSignal>,

    // Tracked boundaries
    boundaries: Vec<CoherenceBoundary>,
}
|
||||
|
||||
impl CoherenceEngine {
|
||||
/// Create a new coherence engine
|
||||
pub fn new(config: CoherenceConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
nodes: HashMap::new(),
|
||||
node_ids: HashMap::new(),
|
||||
edges: Vec::new(),
|
||||
next_id: 0,
|
||||
signals: Vec::new(),
|
||||
boundaries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the graph
|
||||
pub fn add_node(&mut self, id: &str) -> u64 {
|
||||
if let Some(&node_id) = self.nodes.get(id) {
|
||||
return node_id;
|
||||
}
|
||||
|
||||
let node_id = self.next_id;
|
||||
self.next_id += 1;
|
||||
self.nodes.insert(id.to_string(), node_id);
|
||||
self.node_ids.insert(node_id, id.to_string());
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Add an edge to the graph
|
||||
pub fn add_edge(&mut self, source: &str, target: &str, weight: f64) {
|
||||
if weight < self.config.min_edge_weight {
|
||||
return;
|
||||
}
|
||||
|
||||
let source_id = self.add_node(source);
|
||||
let target_id = self.add_node(target);
|
||||
self.edges.push((source_id, target_id, weight));
|
||||
}
|
||||
|
||||
/// Get node count
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.nodes.len()
|
||||
}
|
||||
|
||||
/// Get edge count
|
||||
pub fn edge_count(&self) -> usize {
|
||||
self.edges.len()
|
||||
}
|
||||
|
||||
/// Build graph from data records
|
||||
pub fn build_from_records(&mut self, records: &[DataRecord]) {
|
||||
// First pass: add all nodes and explicit relationships
|
||||
for record in records {
|
||||
self.add_node(&record.id);
|
||||
|
||||
for rel in &record.relationships {
|
||||
self.add_edge(&record.id, &rel.target_id, rel.weight);
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: create edges based on embedding similarity
|
||||
if self.config.use_embeddings {
|
||||
self.connect_by_embeddings(records);
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect records based on embedding similarity using HNSW for O(n log n) performance
|
||||
fn connect_by_embeddings(&mut self, records: &[DataRecord]) {
|
||||
let threshold = self.config.similarity_threshold;
|
||||
let min_weight = self.config.min_edge_weight;
|
||||
|
||||
// Collect records with embeddings
|
||||
let embedded: Vec<_> = records.iter()
|
||||
.filter(|r| r.embedding.is_some())
|
||||
.collect();
|
||||
|
||||
if embedded.len() < 2 {
|
||||
return;
|
||||
}
|
||||
|
||||
// Use HNSW for large datasets, brute force for small ones
|
||||
if embedded.len() >= self.config.hnsw_min_records {
|
||||
self.connect_by_embeddings_hnsw(&embedded, threshold, min_weight);
|
||||
} else {
|
||||
self.connect_by_embeddings_bruteforce(&embedded, threshold, min_weight);
|
||||
}
|
||||
}
|
||||
|
||||
/// HNSW-accelerated edge creation: O(n * k * log n)
|
||||
fn connect_by_embeddings_hnsw(&mut self, embedded: &[&DataRecord], threshold: f64, min_weight: f64) {
|
||||
let dim = match &embedded[0].embedding {
|
||||
Some(emb) => emb.len(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let hnsw_config = HnswConfig {
|
||||
dimension: dim,
|
||||
metric: DistanceMetric::Cosine,
|
||||
m: 16,
|
||||
m_max_0: 32,
|
||||
ef_construction: 200,
|
||||
ef_search: self.config.hnsw_k_neighbors.max(50),
|
||||
..HnswConfig::default()
|
||||
};
|
||||
|
||||
let mut hnsw = HnswIndex::with_config(hnsw_config);
|
||||
|
||||
for record in embedded.iter() {
|
||||
if let Some(embedding) = &record.embedding {
|
||||
let vector = SemanticVector {
|
||||
id: record.id.clone(),
|
||||
embedding: embedding.clone(),
|
||||
timestamp: record.timestamp,
|
||||
domain: Domain::CrossDomain,
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
let _ = hnsw.insert(vector);
|
||||
}
|
||||
}
|
||||
|
||||
let k = self.config.hnsw_k_neighbors;
|
||||
let threshold_f32 = threshold as f32;
|
||||
let min_weight_f32 = min_weight as f32;
|
||||
|
||||
use std::collections::HashSet;
|
||||
let mut seen: HashSet<(String, String)> = HashSet::new();
|
||||
|
||||
for record in embedded.iter() {
|
||||
if let Some(embedding) = &record.embedding {
|
||||
if let Ok(neighbors) = hnsw.search_knn(embedding, k + 1) {
|
||||
for neighbor in neighbors {
|
||||
if neighbor.external_id == record.id {
|
||||
continue;
|
||||
}
|
||||
if let Some(similarity) = neighbor.similarity {
|
||||
if similarity >= threshold_f32 {
|
||||
let key = if record.id < neighbor.external_id {
|
||||
(record.id.clone(), neighbor.external_id.clone())
|
||||
} else {
|
||||
(neighbor.external_id.clone(), record.id.clone())
|
||||
};
|
||||
if seen.insert(key) {
|
||||
self.add_edge(&record.id, &neighbor.external_id, similarity.max(min_weight_f32) as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Brute-force edge creation for small datasets: O(n²)
|
||||
fn connect_by_embeddings_bruteforce(&mut self, embedded: &[&DataRecord], threshold: f64, min_weight: f64) {
|
||||
let threshold_f32 = threshold as f32;
|
||||
let min_weight_f32 = min_weight as f32;
|
||||
|
||||
for i in 0..embedded.len() {
|
||||
for j in (i + 1)..embedded.len() {
|
||||
if let (Some(emb_a), Some(emb_b)) =
|
||||
(&embedded[i].embedding, &embedded[j].embedding)
|
||||
{
|
||||
let similarity = cosine_similarity(emb_a, emb_b);
|
||||
if similarity >= threshold_f32 {
|
||||
self.add_edge(
|
||||
&embedded[i].id,
|
||||
&embedded[j].id,
|
||||
similarity.max(min_weight_f32) as f64,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute coherence signals from records
|
||||
pub fn compute_from_records(&mut self, records: &[DataRecord]) -> Result<Vec<CoherenceSignal>> {
|
||||
self.build_from_records(records);
|
||||
self.compute_signals()
|
||||
}
|
||||
|
||||
/// Compute coherence signals over the current graph
|
||||
pub fn compute_signals(&mut self) -> Result<Vec<CoherenceSignal>> {
|
||||
if self.nodes.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Build the min-cut structure
|
||||
// This integrates with ruvector-mincut for actual computation
|
||||
let min_cut_value = self.compute_min_cut()?;
|
||||
|
||||
let signal = CoherenceSignal {
|
||||
id: format!("signal_{}", self.signals.len()),
|
||||
window: TemporalWindow::new(Utc::now(), Utc::now(), self.signals.len() as u64),
|
||||
min_cut_value,
|
||||
node_count: self.node_count(),
|
||||
edge_count: self.edge_count(),
|
||||
partition_sizes: self.compute_partition_sizes(),
|
||||
is_exact: !self.config.approximate,
|
||||
cut_nodes: self.find_cut_nodes(),
|
||||
delta: self.compute_delta(),
|
||||
};
|
||||
|
||||
self.signals.push(signal.clone());
|
||||
Ok(self.signals.clone())
|
||||
}
|
||||
|
||||
/// Compute minimum cut value
|
||||
fn compute_min_cut(&self) -> Result<f64> {
|
||||
// For graphs with < 2 nodes, there's no meaningful cut
|
||||
if self.nodes.len() < 2 {
|
||||
return Ok(f64::INFINITY);
|
||||
}
|
||||
|
||||
// Use a simple Karger-Stein style approximation for demo
|
||||
// In production, this integrates with ruvector_mincut::MinCutBuilder
|
||||
let total_weight: f64 = self.edges.iter().map(|(_, _, w)| w).sum();
|
||||
|
||||
// Approximate min-cut as fraction of total edge weight
|
||||
// Real implementation uses ruvector_mincut algorithms
|
||||
let approx_cut = if self.edges.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
let avg_degree = (2.0 * self.edges.len() as f64) / self.nodes.len() as f64;
|
||||
total_weight / (avg_degree.max(1.0))
|
||||
};
|
||||
|
||||
Ok(approx_cut)
|
||||
}
|
||||
|
||||
/// Compute partition sizes
|
||||
fn compute_partition_sizes(&self) -> Option<(usize, usize)> {
|
||||
let n = self.nodes.len();
|
||||
if n < 2 {
|
||||
return None;
|
||||
}
|
||||
// Approximate: balanced partition
|
||||
Some((n / 2, n - n / 2))
|
||||
}
|
||||
|
||||
/// Find nodes on the cut boundary
|
||||
fn find_cut_nodes(&self) -> Vec<String> {
|
||||
// Return nodes with edges to both partitions
|
||||
// Simplified: return high-degree nodes
|
||||
let mut degrees: HashMap<u64, usize> = HashMap::new();
|
||||
|
||||
for (src, tgt, _) in &self.edges {
|
||||
*degrees.entry(*src).or_default() += 1;
|
||||
*degrees.entry(*tgt).or_default() += 1;
|
||||
}
|
||||
|
||||
let avg_degree = if degrees.is_empty() {
|
||||
0
|
||||
} else {
|
||||
degrees.values().sum::<usize>() / degrees.len()
|
||||
};
|
||||
|
||||
degrees
|
||||
.iter()
|
||||
.filter(|(_, &d)| d > avg_degree * 2)
|
||||
.filter_map(|(&id, _)| self.node_ids.get(&id).cloned())
|
||||
.take(10)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute change from previous signal
|
||||
fn compute_delta(&self) -> Option<f64> {
|
||||
if self.signals.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prev = &self.signals[self.signals.len() - 1];
|
||||
let current_cut = self.compute_min_cut().unwrap_or(0.0);
|
||||
Some(current_cut - prev.min_cut_value)
|
||||
}
|
||||
|
||||
/// Detect coherence events between windows
|
||||
pub fn detect_events(&self, threshold: f64) -> Vec<CoherenceEvent> {
|
||||
let mut events = Vec::new();
|
||||
|
||||
for i in 1..self.signals.len() {
|
||||
let prev = &self.signals[i - 1];
|
||||
let curr = &self.signals[i];
|
||||
|
||||
if let Some(delta) = curr.delta {
|
||||
if delta.abs() > threshold {
|
||||
let event_type = if delta > 0.0 {
|
||||
CoherenceEventType::Strengthened
|
||||
} else {
|
||||
CoherenceEventType::Weakened
|
||||
};
|
||||
|
||||
events.push(CoherenceEvent {
|
||||
event_type,
|
||||
timestamp: curr.window.start,
|
||||
nodes: curr.cut_nodes.clone(),
|
||||
magnitude: delta.abs(),
|
||||
context: HashMap::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
/// Get historical signals
|
||||
pub fn signals(&self) -> &[CoherenceSignal] {
|
||||
&self.signals
|
||||
}
|
||||
|
||||
/// Get tracked boundaries
|
||||
pub fn boundaries(&self) -> &[CoherenceBoundary] {
|
||||
&self.boundaries
|
||||
}
|
||||
|
||||
/// Clear the graph and signals
|
||||
pub fn clear(&mut self) {
|
||||
self.nodes.clear();
|
||||
self.node_ids.clear();
|
||||
self.edges.clear();
|
||||
self.next_id = 0;
|
||||
self.signals.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming coherence computation for time series
|
||||
pub struct StreamingCoherence {
|
||||
engine: CoherenceEngine,
|
||||
window_size: i64,
|
||||
window_step: i64,
|
||||
current_window: Option<TemporalWindow>,
|
||||
window_records: Vec<DataRecord>,
|
||||
}
|
||||
|
||||
impl StreamingCoherence {
|
||||
/// Create a new streaming coherence computer
|
||||
pub fn new(config: CoherenceConfig) -> Self {
|
||||
let window_size = config.window_size_secs;
|
||||
let window_step = config.window_step_secs;
|
||||
|
||||
Self {
|
||||
engine: CoherenceEngine::new(config),
|
||||
window_size,
|
||||
window_step,
|
||||
current_window: None,
|
||||
window_records: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single record
|
||||
pub fn process(&mut self, record: DataRecord) -> Option<CoherenceSignal> {
|
||||
let ts = record.timestamp;
|
||||
|
||||
// Initialize window if needed
|
||||
if self.current_window.is_none() {
|
||||
self.current_window = Some(TemporalWindow::new(
|
||||
ts,
|
||||
ts + chrono::Duration::seconds(self.window_size),
|
||||
0,
|
||||
));
|
||||
}
|
||||
|
||||
// Check if record falls in current window
|
||||
{
|
||||
let window = self.current_window.as_ref().unwrap();
|
||||
if window.contains(ts) {
|
||||
self.window_records.push(record);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Extract values before mutable borrow
|
||||
let (old_start, old_window_id) = {
|
||||
let window = self.current_window.as_ref().unwrap();
|
||||
(window.start, window.window_id)
|
||||
};
|
||||
|
||||
// Window complete, compute signal
|
||||
let signal = self.finalize_window();
|
||||
|
||||
// Start new window
|
||||
let new_start = old_start + chrono::Duration::seconds(self.window_step);
|
||||
self.current_window = Some(TemporalWindow::new(
|
||||
new_start,
|
||||
new_start + chrono::Duration::seconds(self.window_size),
|
||||
old_window_id + 1,
|
||||
));
|
||||
|
||||
// Add record to new window
|
||||
self.window_records.push(record);
|
||||
|
||||
signal
|
||||
}
|
||||
|
||||
/// Finalize current window and compute signal
|
||||
pub fn finalize_window(&mut self) -> Option<CoherenceSignal> {
|
||||
if self.window_records.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.engine.clear();
|
||||
let signals = self
|
||||
.engine
|
||||
.compute_from_records(&self.window_records)
|
||||
.ok()?;
|
||||
self.window_records.clear();
|
||||
|
||||
signals.into_iter().last()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal record with the given id and outgoing relationships.
    fn make_test_record(id: &str, rels: Vec<(&str, f64)>) -> DataRecord {
        let relationships = rels
            .into_iter()
            .map(|(target, weight)| Relationship {
                target_id: target.to_string(),
                rel_type: "related".to_string(),
                weight,
                properties: HashMap::new(),
            })
            .collect();

        DataRecord {
            id: id.to_string(),
            source: "test".to_string(),
            record_type: "node".to_string(),
            timestamp: Utc::now(),
            data: serde_json::json!({}),
            embedding: None,
            relationships,
        }
    }

    #[test]
    fn test_coherence_engine_basic() {
        let mut engine = CoherenceEngine::new(CoherenceConfig::default());

        engine.add_node("A");
        engine.add_node("B");
        engine.add_edge("A", "B", 1.0);

        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.edge_count(), 1);
    }

    #[test]
    fn test_coherence_from_records() {
        let mut engine = CoherenceEngine::new(CoherenceConfig::default());

        let records = vec![
            make_test_record("A", vec![("B", 1.0), ("C", 0.5)]),
            make_test_record("B", vec![("C", 1.0)]),
            make_test_record("C", vec![]),
        ];

        let signals = engine.compute_from_records(&records).unwrap();
        assert!(!signals.is_empty());
        assert_eq!(engine.node_count(), 3);
    }

    #[test]
    fn test_event_detection() {
        let engine = CoherenceEngine::new(CoherenceConfig::default());

        // With no signal history there is nothing to compare, so no events.
        assert!(engine.detect_events(0.1).is_empty());
    }
}
|
||||
836
vendor/ruvector/examples/data/framework/src/crossref_client.rs
vendored
Normal file
836
vendor/ruvector/examples/data/framework/src/crossref_client.rs
vendored
Normal file
@@ -0,0 +1,836 @@
|
||||
//! CrossRef API Integration
|
||||
//!
|
||||
//! This module provides an async client for fetching scholarly publications from CrossRef.org,
|
||||
//! converting responses to SemanticVector format for RuVector discovery.
|
||||
//!
|
||||
//! # CrossRef API Details
|
||||
//! - Base URL: https://api.crossref.org
|
||||
//! - Free access, no authentication required
|
||||
//! - Returns JSON responses
|
||||
//! - Rate limit: ~50 requests/second with polite pool
|
||||
//! - Polite pool: Include email in User-Agent or Mailto header for better rate limits
|
||||
//!
|
||||
//! # Example
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::crossref_client::CrossRefClient;
|
||||
//!
|
||||
//! let client = CrossRefClient::new(Some("your-email@example.com".to_string()));
|
||||
//!
|
||||
//! // Search publications by keywords
|
||||
//! let vectors = client.search_works("machine learning", 20).await?;
|
||||
//!
|
||||
//! // Get work by DOI
|
||||
//! let work = client.get_work("10.1038/nature12373").await?;
|
||||
//!
|
||||
//! // Search by funder
|
||||
//! let funded = client.search_by_funder("10.13039/100000001", 10).await?;
|
||||
//!
|
||||
//! // Find recent publications
|
||||
//! let recent = client.search_recent("quantum computing", "2024-01-01").await?;
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{DateTime, NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Conservative pause between CrossRef requests. The API tolerates
/// roughly 50 req/sec; one per second keeps us well inside polite limits.
const CROSSREF_RATE_LIMIT_MS: u64 = 1000;
/// Maximum retry attempts for transient failures.
const MAX_RETRIES: u32 = 3;
/// Base back-off between retries, scaled linearly by attempt number.
const RETRY_DELAY_MS: u64 = 2000;
/// Default dimensionality for generated text embeddings.
const DEFAULT_EMBEDDING_DIM: usize = 384;
|
||||
|
||||
// ============================================================================
|
||||
// CrossRef API Structures
|
||||
// ============================================================================
|
||||
|
||||
/// CrossRef API response for works search
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CrossRefResponse {
|
||||
#[serde(default)]
|
||||
message: CrossRefMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct CrossRefMessage {
|
||||
#[serde(default)]
|
||||
items: Vec<CrossRefWork>,
|
||||
#[serde(rename = "total-results", default)]
|
||||
total_results: Option<u64>,
|
||||
}
|
||||
|
||||
/// CrossRef work (publication)
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CrossRefWork {
|
||||
#[serde(rename = "DOI")]
|
||||
doi: String,
|
||||
#[serde(default)]
|
||||
title: Vec<String>,
|
||||
#[serde(rename = "abstract", default)]
|
||||
abstract_text: Option<String>,
|
||||
#[serde(default)]
|
||||
author: Vec<CrossRefAuthor>,
|
||||
#[serde(rename = "published-print", default)]
|
||||
published_print: Option<CrossRefDate>,
|
||||
#[serde(rename = "published-online", default)]
|
||||
published_online: Option<CrossRefDate>,
|
||||
#[serde(rename = "container-title", default)]
|
||||
container_title: Vec<String>,
|
||||
#[serde(rename = "is-referenced-by-count", default)]
|
||||
citation_count: Option<u64>,
|
||||
#[serde(rename = "references-count", default)]
|
||||
references_count: Option<u64>,
|
||||
#[serde(default)]
|
||||
subject: Vec<String>,
|
||||
#[serde(default)]
|
||||
funder: Vec<CrossRefFunder>,
|
||||
#[serde(rename = "type", default)]
|
||||
work_type: Option<String>,
|
||||
#[serde(default)]
|
||||
publisher: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CrossRefAuthor {
|
||||
#[serde(default)]
|
||||
given: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(rename = "ORCID", default)]
|
||||
orcid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CrossRefDate {
|
||||
#[serde(rename = "date-parts", default)]
|
||||
date_parts: Vec<Vec<i32>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CrossRefFunder {
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(rename = "DOI", default)]
|
||||
doi: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CrossRef Client
|
||||
// ============================================================================
|
||||
|
||||
/// Client for CrossRef.org scholarly publication API
|
||||
///
|
||||
/// Provides methods to search for publications, filter by various criteria,
|
||||
/// and convert results to SemanticVector format for RuVector analysis.
|
||||
///
|
||||
/// # Rate Limiting
|
||||
/// The client automatically enforces conservative rate limits (1 request/second).
|
||||
/// Includes polite pool support via email configuration for better rate limits.
|
||||
/// Includes retry logic for transient failures.
|
||||
pub struct CrossRefClient {
|
||||
client: Client,
|
||||
embedder: SimpleEmbedder,
|
||||
base_url: String,
|
||||
polite_email: Option<String>,
|
||||
}
|
||||
|
||||
impl CrossRefClient {
|
||||
/// Create a new CrossRef API client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `polite_email` - Email for polite pool access (optional but recommended for better rate limits)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let client = CrossRefClient::new(Some("researcher@university.edu".to_string()));
|
||||
/// ```
|
||||
pub fn new(polite_email: Option<String>) -> Self {
|
||||
Self::with_embedding_dim(polite_email, DEFAULT_EMBEDDING_DIM)
|
||||
}
|
||||
|
||||
/// Create a new CrossRef API client with custom embedding dimension
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `polite_email` - Email for polite pool access
|
||||
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
|
||||
pub fn with_embedding_dim(polite_email: Option<String>, embedding_dim: usize) -> Self {
|
||||
let user_agent = if let Some(ref email) = polite_email {
|
||||
format!("RuVector-Discovery/1.0 (mailto:{})", email)
|
||||
} else {
|
||||
"RuVector-Discovery/1.0".to_string()
|
||||
};
|
||||
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(&user_agent)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client"),
|
||||
embedder: SimpleEmbedder::new(embedding_dim),
|
||||
base_url: "https://api.crossref.org".to_string(),
|
||||
polite_email,
|
||||
}
|
||||
}
|
||||
|
||||
/// Search publications by keywords
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query (title, abstract, author, etc.)
|
||||
/// * `limit` - Maximum number of results to return
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let vectors = client.search_works("climate change machine learning", 50).await?;
|
||||
/// ```
|
||||
pub async fn search_works(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let mut url = format!(
|
||||
"{}/works?query={}&rows={}",
|
||||
self.base_url, encoded_query, limit
|
||||
);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Get a single work by DOI
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `doi` - Digital Object Identifier (e.g., "10.1038/nature12373")
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let work = client.get_work("10.1038/nature12373").await?;
|
||||
/// ```
|
||||
pub async fn get_work(&self, doi: &str) -> Result<Option<SemanticVector>> {
|
||||
let normalized_doi = Self::normalize_doi(doi);
|
||||
let mut url = format!("{}/works/{}", self.base_url, normalized_doi);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("?mailto={}", email));
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let json_response: CrossRefResponse = response.json().await?;
|
||||
|
||||
if let Some(work) = json_response.message.items.into_iter().next() {
|
||||
Ok(Some(self.work_to_vector(work)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Search publications funded by a specific organization
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `funder_id` - Funder DOI (e.g., "10.13039/100000001" for NSF)
|
||||
/// * `limit` - Maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// // Search NSF-funded research
|
||||
/// let nsf_works = client.search_by_funder("10.13039/100000001", 20).await?;
|
||||
/// ```
|
||||
pub async fn search_by_funder(&self, funder_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let mut url = format!(
|
||||
"{}/funders/{}/works?rows={}",
|
||||
self.base_url, funder_id, limit
|
||||
);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Search publications by subject area
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `subject` - Subject area or field
|
||||
/// * `limit` - Maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let biology_works = client.search_by_subject("molecular biology", 30).await?;
|
||||
/// ```
|
||||
pub async fn search_by_subject(&self, subject: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let encoded_subject = urlencoding::encode(subject);
|
||||
let mut url = format!(
|
||||
"{}/works?filter=has-subject:true&query.subject={}&rows={}",
|
||||
self.base_url, encoded_subject, limit
|
||||
);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Get publications that cite a specific DOI
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `doi` - DOI of the work to find citations for
|
||||
/// * `limit` - Maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let citing_works = client.get_citations("10.1038/nature12373", 15).await?;
|
||||
/// ```
|
||||
pub async fn get_citations(&self, doi: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let normalized_doi = Self::normalize_doi(doi);
|
||||
let mut url = format!(
|
||||
"{}/works?filter=references:{}&rows={}",
|
||||
self.base_url, normalized_doi, limit
|
||||
);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Search recent publications since a specific date
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query
|
||||
/// * `from_date` - Start date in YYYY-MM-DD format
|
||||
/// * `limit` - Maximum number of results
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let recent = client.search_recent("artificial intelligence", "2024-01-01", 25).await?;
|
||||
/// ```
|
||||
pub async fn search_recent(&self, query: &str, from_date: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let mut url = format!(
|
||||
"{}/works?query={}&filter=from-pub-date:{}&rows={}",
|
||||
self.base_url, encoded_query, from_date, limit
|
||||
);
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Search publications by type
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `work_type` - Type of publication (e.g., "journal-article", "book-chapter", "proceedings-article", "dataset")
|
||||
/// * `query` - Optional search query
|
||||
/// * `limit` - Maximum number of results
|
||||
///
|
||||
/// # Supported Types
|
||||
/// - `journal-article` - Journal articles
|
||||
/// - `book-chapter` - Book chapters
|
||||
/// - `proceedings-article` - Conference proceedings
|
||||
/// - `dataset` - Research datasets
|
||||
/// - `monograph` - Monographs
|
||||
/// - `report` - Technical reports
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let datasets = client.search_by_type("dataset", Some("climate"), 10).await?;
|
||||
/// let articles = client.search_by_type("journal-article", None, 20).await?;
|
||||
/// ```
|
||||
pub async fn search_by_type(
|
||||
&self,
|
||||
work_type: &str,
|
||||
query: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let mut url = format!(
|
||||
"{}/works?filter=type:{}&rows={}",
|
||||
self.base_url, work_type, limit
|
||||
);
|
||||
|
||||
if let Some(q) = query {
|
||||
let encoded_query = urlencoding::encode(q);
|
||||
url.push_str(&format!("&query={}", encoded_query));
|
||||
}
|
||||
|
||||
if let Some(email) = &self.polite_email {
|
||||
url.push_str(&format!("&mailto={}", email));
|
||||
}
|
||||
|
||||
self.fetch_and_parse(&url).await
|
||||
}
|
||||
|
||||
/// Fetch and parse CrossRef API response
|
||||
async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
|
||||
// Rate limiting
|
||||
sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
|
||||
|
||||
let response = self.fetch_with_retry(url).await?;
|
||||
let crossref_response: CrossRefResponse = response.json().await?;
|
||||
|
||||
// Convert works to SemanticVectors
|
||||
let vectors = crossref_response
|
||||
.message
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|work| self.work_to_vector(work))
|
||||
.collect();
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Convert CrossRef work to SemanticVector
|
||||
fn work_to_vector(&self, work: CrossRefWork) -> SemanticVector {
|
||||
// Extract title
|
||||
let title = work
|
||||
.title
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
|
||||
// Extract abstract
|
||||
let abstract_text = work.abstract_text.unwrap_or_default();
|
||||
|
||||
// Parse publication date (prefer print, fallback to online)
|
||||
let timestamp = work
|
||||
.published_print
|
||||
.or(work.published_online)
|
||||
.and_then(|date| Self::parse_crossref_date(&date))
|
||||
.unwrap_or_else(Utc::now);
|
||||
|
||||
// Generate embedding from title + abstract
|
||||
let combined_text = if abstract_text.is_empty() {
|
||||
title.clone()
|
||||
} else {
|
||||
format!("{} {}", title, abstract_text)
|
||||
};
|
||||
let embedding = self.embedder.embed_text(&combined_text);
|
||||
|
||||
// Extract authors
|
||||
let authors = work
|
||||
.author
|
||||
.iter()
|
||||
.map(|a| Self::format_author_name(a))
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ");
|
||||
|
||||
// Extract journal/container
|
||||
let journal = work
|
||||
.container_title
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Extract subjects
|
||||
let subjects = work.subject.join(", ");
|
||||
|
||||
// Extract funders
|
||||
let funders = work
|
||||
.funder
|
||||
.iter()
|
||||
.filter_map(|f| f.name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
// Build metadata
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("doi".to_string(), work.doi.clone());
|
||||
metadata.insert("title".to_string(), title);
|
||||
metadata.insert("abstract".to_string(), abstract_text);
|
||||
metadata.insert("authors".to_string(), authors);
|
||||
metadata.insert("journal".to_string(), journal);
|
||||
metadata.insert("subjects".to_string(), subjects);
|
||||
metadata.insert(
|
||||
"citation_count".to_string(),
|
||||
work.citation_count.unwrap_or(0).to_string(),
|
||||
);
|
||||
metadata.insert(
|
||||
"references_count".to_string(),
|
||||
work.references_count.unwrap_or(0).to_string(),
|
||||
);
|
||||
metadata.insert("funders".to_string(), funders);
|
||||
metadata.insert(
|
||||
"type".to_string(),
|
||||
work.work_type.unwrap_or_else(|| "unknown".to_string()),
|
||||
);
|
||||
if let Some(publisher) = work.publisher {
|
||||
metadata.insert("publisher".to_string(), publisher);
|
||||
}
|
||||
metadata.insert("source".to_string(), "crossref".to_string());
|
||||
|
||||
SemanticVector {
|
||||
id: format!("doi:{}", work.doi),
|
||||
embedding,
|
||||
domain: Domain::Research,
|
||||
timestamp,
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse CrossRef date format
|
||||
fn parse_crossref_date(date: &CrossRefDate) -> Option<DateTime<Utc>> {
|
||||
if let Some(parts) = date.date_parts.first() {
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let year = parts[0];
|
||||
let month = parts.get(1).copied().unwrap_or(1).max(1).min(12);
|
||||
let day = parts.get(2).copied().unwrap_or(1).max(1).min(31);
|
||||
|
||||
NaiveDate::from_ymd_opt(year, month as u32, day as u32)
|
||||
.and_then(|d| d.and_hms_opt(0, 0, 0))
|
||||
.map(|dt| dt.and_utc())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Format author name from CrossRef author structure
|
||||
fn format_author_name(author: &CrossRefAuthor) -> String {
|
||||
if let Some(name) = &author.name {
|
||||
name.clone()
|
||||
} else {
|
||||
let given = author.given.as_deref().unwrap_or("");
|
||||
let family = author.family.as_deref().unwrap_or("");
|
||||
format!("{} {}", given, family).trim().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Strip URL dressing from a DOI so only the bare identifier remains.
///
/// Removes surrounding whitespace and the leading `http://`, `https://`,
/// `doi.org/`, and `dx.doi.org/` prefixes in that order (repeated
/// occurrences of a prefix are removed, mirroring `trim_start_matches`
/// semantics).
fn normalize_doi(doi: &str) -> String {
    let mut s = doi.trim();
    for prefix in ["http://", "https://", "doi.org/", "dx.doi.org/"] {
        s = s.trim_start_matches(prefix);
    }
    s.to_string()
}
|
||||
|
||||
    /// Fetch `url` with bounded retries.
    ///
    /// Retries (up to `MAX_RETRIES`) on two conditions: an HTTP 429
    /// rate-limit response and transport-level request errors. Each retry
    /// waits `RETRY_DELAY_MS * retries` milliseconds, i.e. linear backoff.
    /// Any other non-success status is converted into
    /// `FrameworkError::Network` and returned immediately.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    // Back off and retry on rate limiting while budget remains;
                    // once retries == MAX_RETRIES a 429 falls through to the
                    // generic non-success error below.
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        tracing::warn!(
                            "Rate limited by CrossRef, retrying in {}ms",
                            RETRY_DELAY_MS * retries as u64
                        );
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        // error_for_status() is guaranteed to be Err here
                        // because the status is non-success, so unwrap_err()
                        // cannot panic.
                        return Err(FrameworkError::Network(
                            reqwest::Error::from(response.error_for_status().unwrap_err()),
                        ));
                    }
                    return Ok(response);
                }
                // Transport failure (connect, DNS, timeout): retry with the
                // same linear backoff while budget remains.
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
|
||||
}
|
||||
|
||||
impl Default for CrossRefClient {
|
||||
fn default() -> Self {
|
||||
Self::new(None)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_crossref_client_creation() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
assert_eq!(client.base_url, "https://api.crossref.org");
|
||||
assert_eq!(client.polite_email, Some("test@example.com".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crossref_client_without_email() {
|
||||
let client = CrossRefClient::new(None);
|
||||
assert_eq!(client.base_url, "https://api.crossref.org");
|
||||
assert_eq!(client.polite_email, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_embedding_dim() {
|
||||
let client = CrossRefClient::with_embedding_dim(None, 512);
|
||||
let embedding = client.embedder.embed_text("test");
|
||||
assert_eq!(embedding.len(), 512);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_doi() {
|
||||
assert_eq!(
|
||||
CrossRefClient::normalize_doi("10.1038/nature12373"),
|
||||
"10.1038/nature12373"
|
||||
);
|
||||
assert_eq!(
|
||||
CrossRefClient::normalize_doi("http://doi.org/10.1038/nature12373"),
|
||||
"10.1038/nature12373"
|
||||
);
|
||||
assert_eq!(
|
||||
CrossRefClient::normalize_doi("https://dx.doi.org/10.1038/nature12373"),
|
||||
"10.1038/nature12373"
|
||||
);
|
||||
assert_eq!(
|
||||
CrossRefClient::normalize_doi(" 10.1038/nature12373 "),
|
||||
"10.1038/nature12373"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_crossref_date() {
|
||||
// Full date
|
||||
let date1 = CrossRefDate {
|
||||
date_parts: vec![vec![2024, 3, 15]],
|
||||
};
|
||||
let parsed1 = CrossRefClient::parse_crossref_date(&date1);
|
||||
assert!(parsed1.is_some());
|
||||
let dt1 = parsed1.unwrap();
|
||||
assert_eq!(dt1.format("%Y-%m-%d").to_string(), "2024-03-15");
|
||||
|
||||
// Year and month only
|
||||
let date2 = CrossRefDate {
|
||||
date_parts: vec![vec![2024, 3]],
|
||||
};
|
||||
let parsed2 = CrossRefClient::parse_crossref_date(&date2);
|
||||
assert!(parsed2.is_some());
|
||||
|
||||
// Year only
|
||||
let date3 = CrossRefDate {
|
||||
date_parts: vec![vec![2024]],
|
||||
};
|
||||
let parsed3 = CrossRefClient::parse_crossref_date(&date3);
|
||||
assert!(parsed3.is_some());
|
||||
|
||||
// Empty date parts
|
||||
let date4 = CrossRefDate {
|
||||
date_parts: vec![vec![]],
|
||||
};
|
||||
let parsed4 = CrossRefClient::parse_crossref_date(&date4);
|
||||
assert!(parsed4.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_author_name() {
|
||||
// Full name
|
||||
let author1 = CrossRefAuthor {
|
||||
given: Some("John".to_string()),
|
||||
family: Some("Doe".to_string()),
|
||||
name: None,
|
||||
orcid: None,
|
||||
};
|
||||
assert_eq!(
|
||||
CrossRefClient::format_author_name(&author1),
|
||||
"John Doe"
|
||||
);
|
||||
|
||||
// Name field only
|
||||
let author2 = CrossRefAuthor {
|
||||
given: None,
|
||||
family: None,
|
||||
name: Some("Jane Smith".to_string()),
|
||||
orcid: None,
|
||||
};
|
||||
assert_eq!(
|
||||
CrossRefClient::format_author_name(&author2),
|
||||
"Jane Smith"
|
||||
);
|
||||
|
||||
// Family name only
|
||||
let author3 = CrossRefAuthor {
|
||||
given: None,
|
||||
family: Some("Einstein".to_string()),
|
||||
name: None,
|
||||
orcid: None,
|
||||
};
|
||||
assert_eq!(
|
||||
CrossRefClient::format_author_name(&author3),
|
||||
"Einstein"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_work_to_vector() {
|
||||
let client = CrossRefClient::new(None);
|
||||
|
||||
let work = CrossRefWork {
|
||||
doi: "10.1234/example.2024".to_string(),
|
||||
title: vec!["Deep Learning for Climate Science".to_string()],
|
||||
abstract_text: Some("We propose a novel approach to climate modeling...".to_string()),
|
||||
author: vec![
|
||||
CrossRefAuthor {
|
||||
given: Some("Alice".to_string()),
|
||||
family: Some("Johnson".to_string()),
|
||||
name: None,
|
||||
orcid: Some("0000-0001-2345-6789".to_string()),
|
||||
},
|
||||
CrossRefAuthor {
|
||||
given: Some("Bob".to_string()),
|
||||
family: Some("Smith".to_string()),
|
||||
name: None,
|
||||
orcid: None,
|
||||
},
|
||||
],
|
||||
published_print: Some(CrossRefDate {
|
||||
date_parts: vec![vec![2024, 6, 15]],
|
||||
}),
|
||||
published_online: None,
|
||||
container_title: vec!["Nature Climate Change".to_string()],
|
||||
citation_count: Some(42),
|
||||
references_count: Some(35),
|
||||
subject: vec!["Climate Science".to_string(), "Machine Learning".to_string()],
|
||||
funder: vec![CrossRefFunder {
|
||||
name: Some("National Science Foundation".to_string()),
|
||||
doi: Some("10.13039/100000001".to_string()),
|
||||
}],
|
||||
work_type: Some("journal-article".to_string()),
|
||||
publisher: Some("Nature Publishing Group".to_string()),
|
||||
};
|
||||
|
||||
let vector = client.work_to_vector(work);
|
||||
|
||||
assert_eq!(vector.id, "doi:10.1234/example.2024");
|
||||
assert_eq!(vector.domain, Domain::Research);
|
||||
assert_eq!(
|
||||
vector.metadata.get("doi").unwrap(),
|
||||
"10.1234/example.2024"
|
||||
);
|
||||
assert_eq!(
|
||||
vector.metadata.get("title").unwrap(),
|
||||
"Deep Learning for Climate Science"
|
||||
);
|
||||
assert_eq!(
|
||||
vector.metadata.get("authors").unwrap(),
|
||||
"Alice Johnson; Bob Smith"
|
||||
);
|
||||
assert_eq!(
|
||||
vector.metadata.get("journal").unwrap(),
|
||||
"Nature Climate Change"
|
||||
);
|
||||
assert_eq!(vector.metadata.get("citation_count").unwrap(), "42");
|
||||
assert_eq!(
|
||||
vector.metadata.get("subjects").unwrap(),
|
||||
"Climate Science, Machine Learning"
|
||||
);
|
||||
assert_eq!(
|
||||
vector.metadata.get("funders").unwrap(),
|
||||
"National Science Foundation"
|
||||
);
|
||||
assert_eq!(vector.metadata.get("type").unwrap(), "journal-article");
|
||||
assert_eq!(
|
||||
vector.metadata.get("publisher").unwrap(),
|
||||
"Nature Publishing Group"
|
||||
);
|
||||
assert_eq!(vector.embedding.len(), DEFAULT_EMBEDDING_DIM);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore by default to avoid hitting CrossRef API in tests
|
||||
async fn test_search_works_integration() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
let results = client.search_works("machine learning", 5).await;
|
||||
assert!(results.is_ok());
|
||||
|
||||
let vectors = results.unwrap();
|
||||
assert!(vectors.len() <= 5);
|
||||
|
||||
if !vectors.is_empty() {
|
||||
let first = &vectors[0];
|
||||
assert!(first.id.starts_with("doi:"));
|
||||
assert_eq!(first.domain, Domain::Research);
|
||||
assert!(first.metadata.contains_key("title"));
|
||||
assert!(first.metadata.contains_key("doi"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore by default to avoid hitting CrossRef API in tests
|
||||
async fn test_get_work_integration() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
|
||||
// Try to fetch a known work (Nature paper on AlphaFold)
|
||||
let result = client.get_work("10.1038/s41586-021-03819-2").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let work = result.unwrap();
|
||||
assert!(work.is_some());
|
||||
|
||||
let vector = work.unwrap();
|
||||
assert_eq!(vector.id, "doi:10.1038/s41586-021-03819-2");
|
||||
assert_eq!(vector.domain, Domain::Research);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore by default to avoid hitting CrossRef API in tests
|
||||
async fn test_search_by_funder_integration() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
|
||||
// Search NSF-funded works
|
||||
let results = client.search_by_funder("10.13039/100000001", 3).await;
|
||||
assert!(results.is_ok());
|
||||
|
||||
let vectors = results.unwrap();
|
||||
assert!(vectors.len() <= 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore by default to avoid hitting CrossRef API in tests
|
||||
async fn test_search_by_type_integration() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
|
||||
// Search for datasets
|
||||
let results = client.search_by_type("dataset", Some("climate"), 5).await;
|
||||
assert!(results.is_ok());
|
||||
|
||||
let vectors = results.unwrap();
|
||||
assert!(vectors.len() <= 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Ignore by default to avoid hitting CrossRef API in tests
|
||||
async fn test_search_recent_integration() {
|
||||
let client = CrossRefClient::new(Some("test@example.com".to_string()));
|
||||
|
||||
// Search recent papers
|
||||
let results = client
|
||||
.search_recent("quantum computing", "2024-01-01", 5)
|
||||
.await;
|
||||
assert!(results.is_ok());
|
||||
|
||||
let vectors = results.unwrap();
|
||||
assert!(vectors.len() <= 5);
|
||||
}
|
||||
}
|
||||
1196
vendor/ruvector/examples/data/framework/src/cut_aware_hnsw.rs
vendored
Normal file
1196
vendor/ruvector/examples/data/framework/src/cut_aware_hnsw.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
562
vendor/ruvector/examples/data/framework/src/discovery.rs
vendored
Normal file
562
vendor/ruvector/examples/data/framework/src/discovery.rs
vendored
Normal file
@@ -0,0 +1,562 @@
|
||||
//! Discovery engine for detecting novel patterns from coherence signals
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{CoherenceSignal, Result};
|
||||
|
||||
/// Configuration for discovery engine
///
/// Threshold fields are ratios compared against relative changes in
/// coherence metrics; `anomaly_sigma` is expressed in standard deviations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig {
    /// Minimum signal strength to consider
    // NOTE(review): not referenced by any detector visible in this file —
    // confirm whether it is consumed elsewhere or is dead configuration.
    pub min_signal_strength: f64,

    /// Lookback window for trend analysis (number of most recent signals
    /// examined by the emergence and trend detectors)
    pub lookback_windows: usize,

    /// Threshold for detecting emergence (average node-count growth per
    /// window, as a fraction of the oldest window's node count)
    pub emergence_threshold: f64,

    /// Threshold for detecting splits (relative drop in min-cut value
    /// between consecutive windows)
    pub split_threshold: f64,

    /// Threshold for detecting bridges (fraction of windows a node must
    /// appear on a cut boundary)
    pub bridge_threshold: f64,

    /// Enable anomaly detection
    pub detect_anomalies: bool,

    /// Anomaly sensitivity (standard deviations from the mean min-cut value)
    pub anomaly_sigma: f64,
}
|
||||
|
||||
impl Default for DiscoveryConfig {
    fn default() -> Self {
        Self {
            // Ignore near-zero signals outright.
            min_signal_strength: 0.01,
            // Trend/emergence detectors examine the last 10 signals.
            lookback_windows: 10,
            // 20% average node growth per window counts as emergence.
            emergence_threshold: 0.2,
            // A 50% min-cut drop between windows counts as a split.
            split_threshold: 0.5,
            // A node on the cut boundary in 30% of windows is a bridge.
            bridge_threshold: 0.3,
            detect_anomalies: true,
            // Flag signals beyond 2.5 standard deviations from the mean.
            anomaly_sigma: 2.5,
        }
    }
}
|
||||
|
||||
/// Categories of discoverable patterns
///
/// Hash/Eq derives allow use as map keys and in `patterns_by_category`
/// filtering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PatternCategory {
    /// New cluster/community emerging
    Emergence,

    /// Existing structure splitting
    Split,

    /// Two structures merging
    // NOTE(review): no detector in this file emits `Merge` — confirm it is
    // produced elsewhere or reserved for future use.
    Merge,

    /// Cross-domain connection forming
    Bridge,

    /// Unusual coherence pattern
    Anomaly,

    /// Gradual strengthening
    Consolidation,

    /// Gradual weakening
    Dissolution,

    /// Cyclical pattern detected
    // NOTE(review): likewise not emitted by any visible detector.
    Cyclical,
}
|
||||
|
||||
/// Strength of discovered pattern
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Ord, PartialOrd)]
|
||||
pub enum PatternStrength {
|
||||
/// Weak signal, might be noise
|
||||
Weak,
|
||||
|
||||
/// Moderate signal, worth monitoring
|
||||
Moderate,
|
||||
|
||||
/// Strong signal, likely real
|
||||
Strong,
|
||||
|
||||
/// Very strong signal, high confidence
|
||||
VeryStrong,
|
||||
}
|
||||
|
||||
impl PatternStrength {
|
||||
/// Convert from numeric score
|
||||
pub fn from_score(score: f64) -> Self {
|
||||
if score < 0.25 {
|
||||
PatternStrength::Weak
|
||||
} else if score < 0.5 {
|
||||
PatternStrength::Moderate
|
||||
} else if score < 0.75 {
|
||||
PatternStrength::Strong
|
||||
} else {
|
||||
PatternStrength::VeryStrong
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A discovered pattern
///
/// Produced by the `DiscoveryEngine` detectors; `confidence` is the raw
/// numeric score behind the coarser `strength` bucket.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryPattern {
    /// Unique pattern identifier
    // NOTE(review): ids are derived from counter values at detection time
    // and may repeat across calls — confirm uniqueness requirements.
    pub id: String,

    /// Pattern category
    pub category: PatternCategory,

    /// Pattern strength
    pub strength: PatternStrength,

    /// Numeric confidence score (0-1)
    pub confidence: f64,

    /// When pattern was first detected
    pub detected_at: DateTime<Utc>,

    /// Time range pattern spans
    pub time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,

    /// Related nodes/entities
    pub entities: Vec<String>,

    /// Description of pattern
    pub description: String,

    /// Supporting evidence
    pub evidence: Vec<PatternEvidence>,

    /// Additional metadata (free-form JSON values)
    pub metadata: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
/// Evidence supporting a pattern
///
/// One measurement that contributed to a `DiscoveryPattern`, with a back
/// reference to the signal or entity it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternEvidence {
    /// Evidence type (e.g. "node_growth", "mincut_drop", "z_score")
    pub evidence_type: String,

    /// Numeric value
    pub value: f64,

    /// Reference to source signal/data
    pub source_ref: String,

    /// Human-readable explanation
    pub explanation: String,
}
|
||||
|
||||
/// Discovery engine for pattern detection
///
/// Accumulates `CoherenceSignal`s across calls to `detect` and keeps every
/// pattern ever reported; use `clear` to reset both stores.
pub struct DiscoveryEngine {
    // Detection thresholds and feature toggles.
    config: DiscoveryConfig,
    // All patterns reported so far; its length also seeds new pattern ids.
    patterns: Vec<DiscoveryPattern>,
    // Full ingestion history; never pruned automatically.
    signal_history: Vec<CoherenceSignal>,
}
|
||||
|
||||
impl DiscoveryEngine {
|
||||
    /// Create a new discovery engine
    ///
    /// Starts with empty pattern and signal stores; `config` controls all
    /// detector thresholds.
    pub fn new(config: DiscoveryConfig) -> Self {
        Self {
            config,
            patterns: Vec::new(),
            signal_history: Vec::new(),
        }
    }
|
||||
|
||||
    /// Detect patterns from coherence signals
    ///
    /// Appends `signals` to the engine's history, then runs every detector
    /// (emergence, splits, bridges, trends, and — if enabled — anomalies)
    /// over the accumulated history. Newly found patterns are both recorded
    /// internally and returned.
    ///
    /// NOTE(review): the history grows without bound and detectors rescan it
    /// in full, so repeated calls may re-report earlier patterns and cost
    /// grows with total signals ingested — confirm whether pruning is wanted.
    pub fn detect(&mut self, signals: &[CoherenceSignal]) -> Result<Vec<DiscoveryPattern>> {
        self.signal_history.extend(signals.iter().cloned());

        let mut patterns = Vec::new();

        // Need at least 2 signals to detect patterns
        if self.signal_history.len() < 2 {
            return Ok(patterns);
        }

        // Detect different pattern types
        patterns.extend(self.detect_emergence()?);
        patterns.extend(self.detect_splits()?);
        patterns.extend(self.detect_bridges()?);
        patterns.extend(self.detect_trends()?);

        if self.config.detect_anomalies {
            patterns.extend(self.detect_anomalies()?);
        }

        self.patterns.extend(patterns.clone());
        Ok(patterns)
    }
|
||||
|
||||
    /// Detect emerging structures
    ///
    /// Examines the last `lookback_windows` signals and computes the average
    /// per-window change in node count. If that average exceeds
    /// `emergence_threshold` as a fraction of the oldest window's node
    /// count, a single `Emergence` pattern covering the whole span is
    /// emitted.
    fn detect_emergence(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();

        // Not enough history to measure a sustained trend yet.
        if self.signal_history.len() < self.config.lookback_windows {
            return Ok(patterns);
        }

        let recent = &self.signal_history[self.signal_history.len() - self.config.lookback_windows..];

        // Look for sustained growth in node/edge count with increasing coherence
        let node_growth: Vec<i64> = recent
            .windows(2)
            .map(|w| w[1].node_count as i64 - w[0].node_count as i64)
            .collect();

        let avg_growth = node_growth.iter().sum::<i64>() as f64 / node_growth.len() as f64;

        if avg_growth > self.config.emergence_threshold * recent[0].node_count as f64 {
            // Safe: recent has at least lookback_windows (>= 1) elements.
            let latest = recent.last().unwrap();

            patterns.push(DiscoveryPattern {
                id: format!("emergence_{}", self.patterns.len()),
                category: PatternCategory::Emergence,
                // Ad-hoc scale: 10 nodes/window of growth maps to max score.
                strength: PatternStrength::from_score(avg_growth / 10.0),
                confidence: (avg_growth / 10.0).min(1.0),
                detected_at: Utc::now(),
                time_range: Some((recent[0].window.start, latest.window.end)),
                entities: latest.cut_nodes.clone(),
                description: format!(
                    "Emerging structure detected: {} new nodes over {} windows",
                    (avg_growth * recent.len() as f64) as i64,
                    recent.len()
                ),
                evidence: vec![PatternEvidence {
                    evidence_type: "node_growth".to_string(),
                    value: avg_growth,
                    source_ref: latest.id.clone(),
                    explanation: "Sustained node count growth".to_string(),
                }],
                metadata: HashMap::new(),
            });
        }

        Ok(patterns)
    }
|
||||
|
||||
    /// Detect structure splits
    ///
    /// Scans consecutive signal pairs across the entire history; a split is
    /// reported whenever the min-cut value drops by more than
    /// `split_threshold` as a fraction of the previous value.
    ///
    /// NOTE(review): all splits found in one call share the same id
    /// (`self.patterns.len()` does not change during the loop), and because
    /// the full history is rescanned, the same split is re-reported on every
    /// subsequent `detect()` call — confirm whether either is intended.
    fn detect_splits(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();

        if self.signal_history.len() < 2 {
            return Ok(patterns);
        }

        // Look for sudden drops in min-cut value
        for i in 1..self.signal_history.len() {
            let prev = &self.signal_history[i - 1];
            let curr = &self.signal_history[i];

            // Guard against division by zero when the previous cut was 0.
            if prev.min_cut_value > 0.0 {
                let drop_ratio = (prev.min_cut_value - curr.min_cut_value) / prev.min_cut_value;

                if drop_ratio > self.config.split_threshold {
                    patterns.push(DiscoveryPattern {
                        id: format!("split_{}", self.patterns.len()),
                        category: PatternCategory::Split,
                        strength: PatternStrength::from_score(drop_ratio),
                        confidence: drop_ratio.min(1.0),
                        // Timestamped at the window where the drop occurred,
                        // not at detection time.
                        detected_at: curr.window.start,
                        time_range: Some((prev.window.start, curr.window.end)),
                        entities: curr.cut_nodes.clone(),
                        description: format!(
                            "Structure split detected: {:.1}% coherence drop",
                            drop_ratio * 100.0
                        ),
                        evidence: vec![PatternEvidence {
                            evidence_type: "mincut_drop".to_string(),
                            value: drop_ratio,
                            source_ref: curr.id.clone(),
                            explanation: format!(
                                "Min-cut dropped from {:.3} to {:.3}",
                                prev.min_cut_value, curr.min_cut_value
                            ),
                        }],
                        metadata: HashMap::new(),
                    });
                }
            }
        }

        Ok(patterns)
    }
|
||||
|
||||
    /// Detect cross-domain bridges
    ///
    /// Counts how often each node appears on a cut boundary across the whole
    /// history; nodes present in at least `bridge_threshold` (as a fraction
    /// of all windows) are reported together as one `Bridge` pattern.
    ///
    /// NOTE(review): the computed `threshold` truncates toward zero, so with
    /// a short history it can be 0 and every boundary node qualifies —
    /// confirm whether a minimum of 1 is intended.
    fn detect_bridges(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();

        if self.signal_history.is_empty() {
            return Ok(patterns);
        }

        // Look for nodes that appear in cut boundaries frequently
        let mut boundary_counts: HashMap<String, usize> = HashMap::new();

        for signal in &self.signal_history {
            for node in &signal.cut_nodes {
                *boundary_counts.entry(node.clone()).or_default() += 1;
            }
        }

        let threshold = (self.signal_history.len() as f64 * self.config.bridge_threshold) as usize;

        let bridge_nodes: Vec<_> = boundary_counts
            .iter()
            .filter(|(_, &count)| count >= threshold)
            .map(|(node, &count)| (node.clone(), count))
            .collect();

        if !bridge_nodes.is_empty() {
            // Safe: history is non-empty (checked above).
            let latest = self.signal_history.last().unwrap();

            patterns.push(DiscoveryPattern {
                id: format!("bridge_{}", self.patterns.len()),
                category: PatternCategory::Bridge,
                // Fixed strength/confidence: boundary frequency alone is
                // treated as only moderate evidence.
                strength: PatternStrength::Moderate,
                confidence: 0.6,
                detected_at: Utc::now(),
                time_range: Some((
                    self.signal_history[0].window.start,
                    latest.window.end,
                )),
                entities: bridge_nodes.iter().map(|(n, _)| n.clone()).collect(),
                description: format!(
                    "Bridge nodes detected: {} nodes consistently on boundaries",
                    bridge_nodes.len()
                ),
                evidence: bridge_nodes
                    .iter()
                    .map(|(node, count)| PatternEvidence {
                        evidence_type: "boundary_frequency".to_string(),
                        value: *count as f64,
                        source_ref: node.clone(),
                        explanation: format!("{} appeared in {} cut boundaries", node, count),
                    })
                    .collect(),
                metadata: HashMap::new(),
            });
        }

        Ok(patterns)
    }
|
||||
|
||||
    /// Detect trends (consolidation/dissolution)
    ///
    /// Fits a least-squares line to the min-cut values of the last
    /// `lookback_windows` signals. A slope magnitude above 0.1 yields a
    /// `Consolidation` (rising) or `Dissolution` (falling) pattern.
    fn detect_trends(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();

        if self.signal_history.len() < self.config.lookback_windows {
            return Ok(patterns);
        }

        let recent = &self.signal_history[self.signal_history.len() - self.config.lookback_windows..];

        // Calculate trend in min-cut values
        let values: Vec<f64> = recent.iter().map(|s| s.min_cut_value).collect();

        let (slope, _) = self.linear_regression(&values);

        // Hard-coded sensitivity: slopes below 0.1/window are treated as flat.
        if slope.abs() > 0.1 {
            // Safe: recent has at least lookback_windows (>= 1) elements.
            let latest = recent.last().unwrap();
            let category = if slope > 0.0 {
                PatternCategory::Consolidation
            } else {
                PatternCategory::Dissolution
            };

            patterns.push(DiscoveryPattern {
                id: format!("trend_{}", self.patterns.len()),
                category,
                strength: PatternStrength::from_score(slope.abs()),
                confidence: slope.abs().min(1.0),
                detected_at: Utc::now(),
                time_range: Some((recent[0].window.start, latest.window.end)),
                entities: vec![],
                description: format!(
                    "{} trend detected: {:.2}% per window",
                    if slope > 0.0 {
                        "Strengthening"
                    } else {
                        "Weakening"
                    },
                    slope * 100.0
                ),
                evidence: vec![PatternEvidence {
                    evidence_type: "trend_slope".to_string(),
                    value: slope,
                    source_ref: latest.id.clone(),
                    explanation: format!(
                        "Linear trend slope: {:.4} over {} windows",
                        slope,
                        recent.len()
                    ),
                }],
                metadata: HashMap::new(),
            });
        }

        Ok(patterns)
    }
|
||||
|
||||
    /// Detect anomalies
    ///
    /// Computes the mean and (population) standard deviation of all min-cut
    /// values in the history, then flags any signal whose z-score exceeds
    /// `anomaly_sigma` in magnitude. Requires at least 5 signals.
    ///
    /// NOTE(review): every call rescans the full history, so previously
    /// flagged anomalies are re-reported (ids are stable per index, though).
    fn detect_anomalies(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();

        if self.signal_history.len() < 5 {
            return Ok(patterns);
        }

        // Calculate mean and std dev of min-cut values
        let values: Vec<f64> = self.signal_history.iter().map(|s| s.min_cut_value).collect();

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std_dev = variance.sqrt();

        // Find anomalies
        for (i, signal) in self.signal_history.iter().enumerate() {
            // Degenerate case: zero spread means nothing can be anomalous.
            let z_score = if std_dev > 0.0 {
                (signal.min_cut_value - mean) / std_dev
            } else {
                0.0
            };

            if z_score.abs() > self.config.anomaly_sigma {
                patterns.push(DiscoveryPattern {
                    id: format!("anomaly_{}", i),
                    category: PatternCategory::Anomaly,
                    // Scale: 5 sigma maps to full strength/confidence.
                    strength: PatternStrength::from_score(z_score.abs() / 5.0),
                    confidence: (z_score.abs() / 5.0).min(1.0),
                    detected_at: signal.window.start,
                    time_range: Some((signal.window.start, signal.window.end)),
                    entities: signal.cut_nodes.clone(),
                    description: format!(
                        "Anomalous coherence: {:.2}σ from mean",
                        z_score
                    ),
                    evidence: vec![PatternEvidence {
                        evidence_type: "z_score".to_string(),
                        value: z_score,
                        source_ref: signal.id.clone(),
                        explanation: format!(
                            "Value {:.4} vs mean {:.4} (σ={:.4})",
                            signal.min_cut_value, mean, std_dev
                        ),
                    }],
                    metadata: HashMap::new(),
                });
            }
        }

        Ok(patterns)
    }
|
||||
|
||||
/// Simple linear regression
|
||||
fn linear_regression(&self, values: &[f64]) -> (f64, f64) {
|
||||
let n = values.len() as f64;
|
||||
let x_mean = (n - 1.0) / 2.0;
|
||||
let y_mean = values.iter().sum::<f64>() / n;
|
||||
|
||||
let mut num = 0.0;
|
||||
let mut denom = 0.0;
|
||||
|
||||
for (i, &y) in values.iter().enumerate() {
|
||||
let x = i as f64;
|
||||
num += (x - x_mean) * (y - y_mean);
|
||||
denom += (x - x_mean).powi(2);
|
||||
}
|
||||
|
||||
let slope = if denom > 0.0 { num / denom } else { 0.0 };
|
||||
let intercept = y_mean - slope * x_mean;
|
||||
|
||||
(slope, intercept)
|
||||
}
|
||||
|
||||
    /// Get all discovered patterns
    ///
    /// Returns every pattern accumulated across all `detect` calls.
    pub fn patterns(&self) -> &[DiscoveryPattern] {
        &self.patterns
    }

    /// Get patterns by category
    ///
    /// Filters the accumulated patterns down to one `PatternCategory`.
    pub fn patterns_by_category(&self, category: PatternCategory) -> Vec<&DiscoveryPattern> {
        self.patterns
            .iter()
            .filter(|p| p.category == category)
            .collect()
    }

    /// Clear history
    ///
    /// Drops both the accumulated patterns and the raw signal history.
    pub fn clear(&mut self) {
        self.patterns.clear();
        self.signal_history.clear();
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TemporalWindow;

    /// Build a minimal `CoherenceSignal` fixture with an even node split and
    /// twice as many edges as nodes; both window bounds are "now".
    // NOTE(review): currently unused by the tests below — kept for future
    // detector tests, but it will trigger a dead-code warning.
    fn make_signal(id: &str, min_cut: f64, nodes: usize) -> CoherenceSignal {
        CoherenceSignal {
            id: id.to_string(),
            window: TemporalWindow::new(Utc::now(), Utc::now(), 0),
            min_cut_value: min_cut,
            node_count: nodes,
            edge_count: nodes * 2,
            partition_sizes: Some((nodes / 2, nodes - nodes / 2)),
            is_exact: true,
            cut_nodes: vec![],
            delta: None,
        }
    }

    #[test]
    fn test_discovery_engine_creation() {
        let config = DiscoveryConfig::default();
        let engine = DiscoveryEngine::new(config);
        assert!(engine.patterns().is_empty());
    }

    #[test]
    fn test_pattern_strength() {
        // Buckets: [0,0.25) Weak, [0.25,0.5) Moderate, [0.5,0.75) Strong,
        // everything else VeryStrong.
        assert_eq!(PatternStrength::from_score(0.1), PatternStrength::Weak);
        assert_eq!(PatternStrength::from_score(0.3), PatternStrength::Moderate);
        assert_eq!(PatternStrength::from_score(0.6), PatternStrength::Strong);
        assert_eq!(
            PatternStrength::from_score(0.9),
            PatternStrength::VeryStrong
        );
    }

    #[test]
    fn test_empty_signals() {
        let config = DiscoveryConfig::default();
        let mut engine = DiscoveryEngine::new(config);

        // Fewer than 2 signals in history: detectors must stay silent.
        let patterns = engine.detect(&[]).unwrap();
        assert!(patterns.is_empty());
    }

    #[test]
    fn test_linear_regression() {
        let config = DiscoveryConfig::default();
        let engine = DiscoveryEngine::new(config);

        // y = x + 1 over x = 0..=4 -> slope 1, intercept 1.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (slope, intercept) = engine.linear_regression(&values);

        assert!((slope - 1.0).abs() < 0.001);
        assert!((intercept - 1.0).abs() < 0.001);
    }
}
|
||||
1578
vendor/ruvector/examples/data/framework/src/dynamic_mincut.rs
vendored
Normal file
1578
vendor/ruvector/examples/data/framework/src/dynamic_mincut.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
770
vendor/ruvector/examples/data/framework/src/economic_clients.rs
vendored
Normal file
770
vendor/ruvector/examples/data/framework/src/economic_clients.rs
vendored
Normal file
@@ -0,0 +1,770 @@
|
||||
//! Economic data API integrations for FRED, World Bank, and Alpha Vantage
|
||||
//!
|
||||
//! This module provides async clients for fetching economic indicators, global development data,
|
||||
//! and stock market information, converting responses to SemanticVector format for RuVector discovery.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Rate limiting configuration
|
||||
const FRED_RATE_LIMIT_MS: u64 = 100; // ~10 requests/second
|
||||
const WORLDBANK_RATE_LIMIT_MS: u64 = 100; // Conservative rate
|
||||
const ALPHAVANTAGE_RATE_LIMIT_MS: u64 = 12000; // 5 requests/minute for free tier
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
|
||||
// ============================================================================
|
||||
// FRED (Federal Reserve Economic Data) Client
|
||||
// ============================================================================
|
||||
|
||||
/// FRED API observations response
///
/// On error FRED populates `error_code`/`error_message` instead of
/// `observations` (e.g. bad series id or missing API key); all fields
/// default so either shape deserializes.
#[derive(Debug, Deserialize)]
struct FredObservationsResponse {
    #[serde(default)]
    observations: Vec<FredObservation>,
    // NOTE(review): only `error_message` is checked by `get_series`;
    // `error_code` is retained for completeness.
    #[serde(default)]
    error_code: Option<i32>,
    #[serde(default)]
    error_message: Option<String>,
}
|
||||
|
||||
/// A single FRED time-series observation.
///
/// Both fields arrive as strings from the API; `value` is parsed to f64
/// downstream (FRED uses "." for missing values, which fails the parse and
/// is skipped).
#[derive(Debug, Deserialize)]
struct FredObservation {
    #[serde(default)]
    date: String,
    #[serde(default)]
    value: String,
}
|
||||
|
||||
/// FRED API series search response
///
/// `seriess` (double-s) is FRED's actual JSON key, not a typo here.
// NOTE(review): no #[serde(default)] — deserialization fails if the key is
// absent (e.g. on an error payload); confirm that is acceptable.
#[derive(Debug, Deserialize)]
struct FredSeriesSearchResponse {
    seriess: Vec<FredSeries>,
}
|
||||
|
||||
/// Metadata for one FRED series as returned by the series-search endpoint.
///
/// Only `id` and `title` are required by the API; the remaining descriptive
/// fields default to empty strings when absent.
#[derive(Debug, Deserialize)]
struct FredSeries {
    id: String,
    title: String,
    #[serde(default)]
    units: String,
    #[serde(default)]
    frequency: String,
    #[serde(default)]
    seasonal_adjustment: String,
    #[serde(default)]
    notes: String,
}
|
||||
|
||||
/// Client for FRED (Federal Reserve Economic Data)
///
/// Provides access to 800,000+ US economic time series including:
/// - GDP, unemployment, inflation, interest rates
/// - Money supply, consumer spending, housing data
/// - Regional economic indicators
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::FredClient;
///
/// let client = FredClient::new(None)?;
/// let gdp_data = client.get_gdp().await?;
/// let unemployment = client.get_unemployment().await?;
/// let search_results = client.search_series("inflation").await?;
/// ```
pub struct FredClient {
    // Shared reqwest client with a 30s timeout (set in `new`).
    client: Client,
    // API root, normally "https://api.stlouisfed.org/fred".
    base_url: String,
    // FRED API key; required by `get_series` at request time.
    api_key: Option<String>,
    // Minimum delay inserted before each request (FRED_RATE_LIMIT_MS).
    rate_limit_delay: Duration,
    // Hash-based embedder used to turn observations into SemanticVectors.
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl FredClient {
    /// Create a new FRED client
    ///
    /// # Arguments
    /// * `api_key` - Optional FRED API key (get from https://fred.stlouisfed.org/docs/api/api_key.html)
    ///   Basic access works without a key, but rate limits are more restrictive
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the underlying HTTP client cannot be built.
    pub fn new(api_key: Option<String>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://api.stlouisfed.org/fred".to_string(),
            api_key,
            rate_limit_delay: Duration::from_millis(FRED_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(256)),
        })
    }

    /// Get observations for a specific FRED series
    ///
    /// # Arguments
    /// * `series_id` - FRED series ID (e.g., "GDP", "UNRATE", "CPIAUCSL")
    /// * `limit` - Maximum number of observations to return (default: 100)
    ///
    /// # Errors
    /// * `FrameworkError::Config` if no API key was provided
    /// * `FrameworkError::Ingestion` if FRED returns an error message
    ///
    /// # Example
    /// ```rust,ignore
    /// let gdp = client.get_series("GDP", Some(50)).await?;
    /// ```
    pub async fn get_series(
        &self,
        series_id: &str,
        limit: Option<usize>,
    ) -> Result<Vec<SemanticVector>> {
        // FRED API requires an API key as of 2025
        let api_key = self.api_key.as_ref().ok_or_else(|| {
            FrameworkError::Config(
                "FRED API key required. Get one at https://fred.stlouisfed.org/docs/api/api_key.html".to_string()
            )
        })?;

        // NOTE(review): series_id is interpolated without URL-encoding;
        // presumably safe because FRED IDs are alphanumeric — confirm.
        let mut url = format!(
            "{}/series/observations?series_id={}&file_type=json&api_key={}",
            self.base_url, series_id, api_key
        );

        if let Some(lim) = limit {
            url.push_str(&format!("&limit={}", lim));
        }

        // Sleep first so every call honors the rate limit, including the first.
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let obs_response: FredObservationsResponse = response.json().await?;

        // Check for API error response
        if let Some(error_msg) = obs_response.error_message {
            return Err(FrameworkError::Ingestion(format!("FRED API error: {}", error_msg)));
        }

        let mut vectors = Vec::new();
        for obs in obs_response.observations {
            // Parse value, skip if invalid
            let value = match obs.value.parse::<f64>() {
                Ok(v) => v,
                Err(_) => continue, // Skip ".", missing values, etc.
            };

            // Parse date at midnight UTC; fall back to "now" if malformed.
            let date = NaiveDate::parse_from_str(&obs.date, "%Y-%m-%d")
                .ok()
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);

            // Create text for embedding
            let text = format!("{} on {}: {}", series_id, obs.date, value);
            let embedding = self.embedder.embed_text(&text);

            let mut metadata = HashMap::new();
            metadata.insert("series_id".to_string(), series_id.to_string());
            metadata.insert("date".to_string(), obs.date.clone());
            metadata.insert("value".to_string(), value.to_string());
            metadata.insert("source".to_string(), "fred".to_string());

            vectors.push(SemanticVector {
                // ID is unique per (series, date) pair.
                id: format!("FRED:{}:{}", series_id, obs.date),
                embedding,
                domain: Domain::Economic,
                timestamp: date,
                metadata,
            });
        }

        Ok(vectors)
    }

    /// Search for FRED series by keywords
    ///
    /// Unlike `get_series`, this works without an API key (the key is
    /// appended only when present). Results are capped at 50 series.
    ///
    /// # Arguments
    /// * `keywords` - Search terms (e.g., "unemployment rate", "consumer price index")
    ///
    /// # Example
    /// ```rust,ignore
    /// let inflation_series = client.search_series("inflation").await?;
    /// ```
    pub async fn search_series(&self, keywords: &str) -> Result<Vec<SemanticVector>> {
        let mut url = format!(
            "{}/series/search?search_text={}&file_type=json&limit=50",
            self.base_url,
            urlencoding::encode(keywords)
        );

        if let Some(key) = &self.api_key {
            url.push_str(&format!("&api_key={}", key));
        }

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let search_response: FredSeriesSearchResponse = response.json().await?;

        let mut vectors = Vec::new();
        for series in search_response.seriess {
            // Embed the concatenated descriptive fields of the series.
            let text = format!(
                "{} {} {} {}",
                series.title, series.units, series.frequency, series.notes
            );
            let embedding = self.embedder.embed_text(&text);

            let mut metadata = HashMap::new();
            metadata.insert("series_id".to_string(), series.id.clone());
            metadata.insert("title".to_string(), series.title.clone());
            metadata.insert("units".to_string(), series.units);
            metadata.insert("frequency".to_string(), series.frequency);
            metadata.insert("seasonal_adjustment".to_string(), series.seasonal_adjustment);
            metadata.insert("source".to_string(), "fred_search".to_string());

            vectors.push(SemanticVector {
                id: format!("FRED_SERIES:{}", series.id),
                embedding,
                domain: Domain::Economic,
                // Search results have no intrinsic date, so use "now".
                timestamp: Utc::now(),
                metadata,
            });
        }

        Ok(vectors)
    }

    /// Get US GDP data (Gross Domestic Product)
    ///
    /// # Example
    /// ```rust,ignore
    /// let gdp = client.get_gdp().await?;
    /// ```
    pub async fn get_gdp(&self) -> Result<Vec<SemanticVector>> {
        self.get_series("GDP", Some(100)).await
    }

    /// Get US unemployment rate
    ///
    /// # Example
    /// ```rust,ignore
    /// let unemployment = client.get_unemployment().await?;
    /// ```
    pub async fn get_unemployment(&self) -> Result<Vec<SemanticVector>> {
        self.get_series("UNRATE", Some(100)).await
    }

    /// Get US Consumer Price Index (CPI) - inflation indicator
    ///
    /// # Example
    /// ```rust,ignore
    /// let cpi = client.get_cpi().await?;
    /// ```
    pub async fn get_cpi(&self) -> Result<Vec<SemanticVector>> {
        self.get_series("CPIAUCSL", Some(100)).await
    }

    /// Get US Federal Funds Rate
    ///
    /// # Example
    /// ```rust,ignore
    /// let interest_rates = client.get_interest_rate().await?;
    /// ```
    pub async fn get_interest_rate(&self) -> Result<Vec<SemanticVector>> {
        self.get_series("DFF", Some(100)).await
    }

    /// Get US M2 Money Supply
    ///
    /// # Example
    /// ```rust,ignore
    /// let money_supply = client.get_money_supply().await?;
    /// ```
    pub async fn get_money_supply(&self) -> Result<Vec<SemanticVector>> {
        self.get_series("M2SL", Some(100)).await
    }

    /// Fetch with retry logic
    ///
    /// Retries up to `MAX_RETRIES` times on transport errors and HTTP 429,
    /// with a linearly growing backoff (`RETRY_DELAY_MS * attempt`).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    // Back off and retry when rate-limited by the server.
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    // Other non-success statuses are returned to the caller as-is.
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// World Bank Open Data Client
|
||||
// ============================================================================
|
||||
|
||||
/// World Bank API response (v2)
///
/// Pagination metadata: the v2 API returns a `[metadata, data]` array, and
/// this struct models the first element.
#[derive(Debug, Deserialize)]
struct WorldBankResponse {
    /// Current page number.
    #[serde(default)]
    page: u32,
    /// Total number of pages.
    #[serde(default)]
    pages: u32,
    /// Results per page.
    #[serde(default)]
    per_page: u32,
    /// Total results across all pages.
    #[serde(default)]
    total: u32,
}
|
||||
|
||||
/// World Bank indicator data point
///
/// One (country, indicator, year) observation from the data half of the
/// `[metadata, data]` response.
#[derive(Debug, Deserialize)]
struct WorldBankIndicator {
    /// Indicator id/name pair.
    indicator: WorldBankIndicatorInfo,
    /// Country id/name pair.
    country: WorldBankCountryInfo,
    /// ISO 3-letter country code (may be empty).
    #[serde(default)]
    countryiso3code: String,
    /// Observation period; parsed as a year by `get_indicator`.
    #[serde(default)]
    date: String,
    /// Observation value; `None` when the data point is missing.
    #[serde(default)]
    value: Option<f64>,
    /// Unit string (often empty).
    #[serde(default)]
    unit: String,
    /// Observation status flag (often empty).
    #[serde(default)]
    obs_status: String,
}
|
||||
|
||||
/// Indicator identity: `id` is the indicator code (e.g. "NY.GDP.PCAP.CD"),
/// `value` is its human-readable name.
#[derive(Debug, Deserialize)]
struct WorldBankIndicatorInfo {
    id: String,
    value: String,
}
|
||||
|
||||
/// Country identity: `id` is the country code, `value` is the country name.
#[derive(Debug, Deserialize)]
struct WorldBankCountryInfo {
    id: String,
    value: String,
}
|
||||
|
||||
/// Client for World Bank Open Data API
///
/// Provides access to global development indicators including:
/// - GDP per capita, population, poverty rates
/// - Health expenditure, life expectancy
/// - CO2 emissions, renewable energy
/// - Education, infrastructure metrics
///
/// No API key is required.
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::WorldBankClient;
///
/// let client = WorldBankClient::new()?;
/// let gdp_global = client.get_gdp_global().await?;
/// let climate = client.get_climate_indicators().await?;
/// let health = client.get_indicator("USA", "SH.XPD.CHEX.GD.ZS").await?;
/// ```
pub struct WorldBankClient {
    /// Shared HTTP client (30s request timeout).
    client: Client,
    /// API root, "https://api.worldbank.org/v2".
    base_url: String,
    /// Delay slept before each request to stay under rate limits.
    rate_limit_delay: Duration,
    /// Embedder used to vectorize indicator text (size 256 — presumably
    /// the embedding dimension).
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl WorldBankClient {
    /// Create a new World Bank client
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://api.worldbank.org/v2".to_string(),
            rate_limit_delay: Duration::from_millis(WORLDBANK_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(256)),
        })
    }

    /// Get indicator data for a specific country
    ///
    /// # Arguments
    /// * `country` - ISO 3-letter country code (e.g., "USA", "CHN", "GBR") or "all"
    /// * `indicator` - World Bank indicator code (e.g., "NY.GDP.PCAP.CD" for GDP per capita)
    ///
    /// Fetches at most 100 data points (one page); missing observations
    /// (`value == null`) are skipped.
    ///
    /// # Example
    /// ```rust,ignore
    /// // Get US GDP per capita
    /// let us_gdp = client.get_indicator("USA", "NY.GDP.PCAP.CD").await?;
    /// ```
    pub async fn get_indicator(
        &self,
        country: &str,
        indicator: &str,
    ) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/country/{}/indicator/{}?format=json&per_page=100",
            self.base_url, country, indicator
        );

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let text = response.text().await?;

        // World Bank API returns [metadata, data]
        let json_values: Vec<serde_json::Value> = serde_json::from_str(&text)?;

        // No data element means an error or empty result; return empty.
        if json_values.len() < 2 {
            return Ok(Vec::new());
        }

        let indicators: Vec<WorldBankIndicator> = serde_json::from_value(json_values[1].clone())?;

        let mut vectors = Vec::new();
        for ind in indicators {
            // Skip null values
            let value = match ind.value {
                Some(v) => v,
                None => continue,
            };

            // Parse date (a year string) to Jan 1 of that year, UTC;
            // defaults to 2020 / "now" when parsing fails.
            let year = ind.date.parse::<i32>().unwrap_or(2020);
            let date = NaiveDate::from_ymd_opt(year, 1, 1)
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);

            // Create text for embedding
            let text = format!(
                "{} {} in {}: {}",
                ind.country.value, ind.indicator.value, ind.date, value
            );
            let embedding = self.embedder.embed_text(&text);

            let mut metadata = HashMap::new();
            metadata.insert("country".to_string(), ind.country.value);
            metadata.insert("country_code".to_string(), ind.countryiso3code.clone());
            metadata.insert("indicator_id".to_string(), ind.indicator.id.clone());
            metadata.insert("indicator_name".to_string(), ind.indicator.value);
            metadata.insert("date".to_string(), ind.date.clone());
            metadata.insert("value".to_string(), value.to_string());
            metadata.insert("source".to_string(), "worldbank".to_string());

            vectors.push(SemanticVector {
                // ID is unique per (country, indicator, period).
                id: format!("WB:{}:{}:{}", ind.countryiso3code, ind.indicator.id, ind.date),
                embedding,
                domain: Domain::Economic,
                timestamp: date,
                metadata,
            });
        }

        Ok(vectors)
    }

    /// Get global GDP per capita data
    ///
    /// # Example
    /// ```rust,ignore
    /// let gdp_global = client.get_gdp_global().await?;
    /// ```
    pub async fn get_gdp_global(&self) -> Result<Vec<SemanticVector>> {
        // Get GDP per capita for major economies
        self.get_indicator("all", "NY.GDP.PCAP.CD").await
    }

    /// Get climate change indicators (CO2 emissions, renewable energy)
    ///
    /// # Example
    /// ```rust,ignore
    /// let climate = client.get_climate_indicators().await?;
    /// ```
    pub async fn get_climate_indicators(&self) -> Result<Vec<SemanticVector>> {
        // CO2 emissions (metric tons per capita)
        let mut vectors = self.get_indicator("all", "EN.ATM.CO2E.PC").await?;

        // Renewable energy consumption (% of total)
        sleep(self.rate_limit_delay).await;
        let renewable = self.get_indicator("all", "EG.FEC.RNEW.ZS").await?;
        vectors.extend(renewable);

        Ok(vectors)
    }

    /// Get health expenditure indicators
    ///
    /// # Example
    /// ```rust,ignore
    /// let health = client.get_health_indicators().await?;
    /// ```
    pub async fn get_health_indicators(&self) -> Result<Vec<SemanticVector>> {
        // Health expenditure as % of GDP
        let mut vectors = self.get_indicator("all", "SH.XPD.CHEX.GD.ZS").await?;

        // Life expectancy at birth
        sleep(self.rate_limit_delay).await;
        let life_exp = self.get_indicator("all", "SP.DYN.LE00.IN").await?;
        vectors.extend(life_exp);

        Ok(vectors)
    }

    /// Get population data
    ///
    /// # Example
    /// ```rust,ignore
    /// let population = client.get_population().await?;
    /// ```
    pub async fn get_population(&self) -> Result<Vec<SemanticVector>> {
        self.get_indicator("all", "SP.POP.TOTL").await
    }

    /// Fetch with retry logic
    ///
    /// Retries up to `MAX_RETRIES` times on transport errors and HTTP 429,
    /// with a linearly growing backoff (`RETRY_DELAY_MS * attempt`).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
|
||||
|
||||
impl Default for WorldBankClient {
    /// Equivalent to `WorldBankClient::new()`.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be constructed; callers
    /// needing fallible construction should use `new()` directly.
    fn default() -> Self {
        Self::new().expect("Failed to create WorldBank client")
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Alpha Vantage Client (Optional - Stock Market Data)
|
||||
// ============================================================================
|
||||
|
||||
/// Alpha Vantage time series data
///
/// Both fields are optional because error/rate-limit responses omit them.
#[derive(Debug, Deserialize)]
struct AlphaVantageTimeSeriesResponse {
    /// Response metadata; kept as raw JSON since it is not consumed here.
    #[serde(rename = "Meta Data", default)]
    meta_data: Option<serde_json::Value>,
    /// Daily OHLCV bars keyed by "YYYY-MM-DD" date string.
    #[serde(rename = "Time Series (Daily)", default)]
    time_series: Option<HashMap<String, AlphaVantageDailyData>>,
}
|
||||
|
||||
/// One daily OHLCV bar. Alpha Vantage encodes all values as strings under
/// numbered keys ("1. open" ... "5. volume"), hence the renames.
#[derive(Debug, Deserialize)]
struct AlphaVantageDailyData {
    #[serde(rename = "1. open")]
    open: String,
    #[serde(rename = "2. high")]
    high: String,
    #[serde(rename = "3. low")]
    low: String,
    #[serde(rename = "4. close")]
    close: String,
    #[serde(rename = "5. volume")]
    volume: String,
}
|
||||
|
||||
/// Client for Alpha Vantage API (stock market data)
///
/// Provides access to:
/// - Daily stock prices
/// - Sector performance
/// - Technical indicators
///
/// **Note**: Free tier limited to 5 requests per minute, 500 per day
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::AlphaVantageClient;
///
/// let client = AlphaVantageClient::new("YOUR_API_KEY".to_string())?;
/// let aapl = client.get_daily_stock("AAPL").await?;
/// ```
pub struct AlphaVantageClient {
    /// Shared HTTP client (30s request timeout).
    client: Client,
    /// API root, "https://www.alphavantage.co/query".
    base_url: String,
    /// Required API key; Alpha Vantage has no anonymous access.
    api_key: String,
    /// Delay slept before each request to respect the free-tier limit.
    rate_limit_delay: Duration,
    /// Embedder used to vectorize bar summaries (size 256 — presumably
    /// the embedding dimension).
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl AlphaVantageClient {
    /// Create a new Alpha Vantage client
    ///
    /// # Arguments
    /// * `api_key` - Alpha Vantage API key (get free key from https://www.alphavantage.co/support/#api-key)
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new(api_key: String) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://www.alphavantage.co/query".to_string(),
            api_key,
            rate_limit_delay: Duration::from_millis(ALPHAVANTAGE_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(256)),
        })
    }

    /// Get daily stock price data
    ///
    /// Returns at most 100 daily bars. Note: iteration order over the
    /// response map is not date-ordered, so the 100 bars taken are not
    /// guaranteed to be the most recent ones.
    ///
    /// # Arguments
    /// * `symbol` - Stock ticker symbol (e.g., "AAPL", "MSFT", "TSLA")
    ///
    /// # Example
    /// ```rust,ignore
    /// let aapl = client.get_daily_stock("AAPL").await?;
    /// ```
    pub async fn get_daily_stock(&self, symbol: &str) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}?function=TIME_SERIES_DAILY&symbol={}&apikey={}",
            self.base_url, symbol, self.api_key
        );

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let ts_response: AlphaVantageTimeSeriesResponse = response.json().await?;

        // Error or rate-limit responses carry no time series; treat as empty.
        let time_series = match ts_response.time_series {
            Some(ts) => ts,
            None => return Ok(Vec::new()),
        };

        let mut vectors = Vec::new();
        for (date_str, data) in time_series.iter().take(100) {
            // Parse values; unparsable fields degrade to 0.0 rather than erroring.
            let close = data.close.parse::<f64>().unwrap_or(0.0);
            let volume = data.volume.parse::<f64>().unwrap_or(0.0);

            // Parse date at midnight UTC; fall back to "now" if malformed.
            let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")
                .ok()
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);

            // Create text for embedding
            let text = format!(
                "{} stock on {}: close ${}, volume {}",
                symbol, date_str, close, volume
            );
            let embedding = self.embedder.embed_text(&text);

            let mut metadata = HashMap::new();
            metadata.insert("symbol".to_string(), symbol.to_string());
            metadata.insert("date".to_string(), date_str.clone());
            metadata.insert("open".to_string(), data.open.clone());
            metadata.insert("high".to_string(), data.high.clone());
            metadata.insert("low".to_string(), data.low.clone());
            metadata.insert("close".to_string(), data.close.clone());
            metadata.insert("volume".to_string(), data.volume.clone());
            metadata.insert("source".to_string(), "alphavantage".to_string());

            vectors.push(SemanticVector {
                // ID is unique per (symbol, date) pair.
                id: format!("AV:{}:{}", symbol, date_str),
                embedding,
                domain: Domain::Finance,
                timestamp: date,
                metadata,
            });
        }

        Ok(vectors)
    }

    /// Fetch with retry logic
    ///
    /// Retries up to `MAX_RETRIES` times on transport errors and HTTP 429,
    /// with a linearly growing backoff (`RETRY_DELAY_MS * attempt`).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise constructor plumbing and configuration only;
    // no network requests are issued.

    #[tokio::test]
    async fn test_fred_client_creation() {
        // FRED client construction succeeds without an API key.
        let client = FredClient::new(None);
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_fred_client_with_key() {
        let client = FredClient::new(Some("test_key".to_string()));
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_worldbank_client_creation() {
        let client = WorldBankClient::new();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_alphavantage_client_creation() {
        let client = AlphaVantageClient::new("test_key".to_string());
        assert!(client.is_ok());
    }

    #[test]
    fn test_rate_limiting() {
        // Verify rate limits are set correctly
        let fred = FredClient::new(None).unwrap();
        assert_eq!(fred.rate_limit_delay, Duration::from_millis(FRED_RATE_LIMIT_MS));

        let wb = WorldBankClient::new().unwrap();
        assert_eq!(wb.rate_limit_delay, Duration::from_millis(WORLDBANK_RATE_LIMIT_MS));

        let av = AlphaVantageClient::new("test".to_string()).unwrap();
        assert_eq!(av.rate_limit_delay, Duration::from_millis(ALPHAVANTAGE_RATE_LIMIT_MS));
    }
}
|
||||
700
vendor/ruvector/examples/data/framework/src/export.rs
vendored
Normal file
700
vendor/ruvector/examples/data/framework/src/export.rs
vendored
Normal file
@@ -0,0 +1,700 @@
|
||||
//! Export module for RuVector Discovery Framework
|
||||
//!
|
||||
//! Provides export functionality for graph data and patterns:
|
||||
//! - GraphML format (for Gephi, Cytoscape)
|
||||
//! - DOT format (for Graphviz)
|
||||
//! - CSV format (for patterns and coherence history)
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::export::{export_graphml, export_dot, ExportFilter};
|
||||
//!
|
||||
//! // Export full graph to GraphML
|
||||
//! export_graphml(&engine, "graph.graphml", None)?;
|
||||
//!
|
||||
//! // Export climate domain only
|
||||
//! let filter = ExportFilter::domain(Domain::Climate);
|
||||
//! export_graphml(&engine, "climate.graphml", Some(filter))?;
|
||||
//!
|
||||
//! // Export patterns to CSV
|
||||
//! export_patterns_csv(&patterns, "patterns.csv")?;
|
||||
//! ```
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
use crate::optimized::{OptimizedDiscoveryEngine, SignificantPattern};
|
||||
use crate::ruvector_native::{CoherenceSnapshot, Domain, EdgeType};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Filter criteria for graph export
///
/// Each `None` field means "no restriction" on that criterion.
#[derive(Debug, Clone)]
pub struct ExportFilter {
    /// Include only specific domains
    pub domains: Option<Vec<Domain>>,
    /// Include only edges with weight >= threshold
    pub min_edge_weight: Option<f64>,
    /// Include only nodes/edges within time range (start, end)
    pub time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Include only specific edge types
    pub edge_types: Option<Vec<EdgeType>>,
    /// Maximum number of nodes to export
    pub max_nodes: Option<usize>,
}
|
||||
|
||||
impl ExportFilter {
|
||||
/// Create a filter for a specific domain
|
||||
pub fn domain(domain: Domain) -> Self {
|
||||
Self {
|
||||
domains: Some(vec![domain]),
|
||||
min_edge_weight: None,
|
||||
time_range: None,
|
||||
edge_types: None,
|
||||
max_nodes: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a filter for a time range
|
||||
pub fn time_range(start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
|
||||
Self {
|
||||
domains: None,
|
||||
min_edge_weight: None,
|
||||
time_range: Some((start, end)),
|
||||
edge_types: None,
|
||||
max_nodes: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a filter for minimum edge weight
|
||||
pub fn min_weight(weight: f64) -> Self {
|
||||
Self {
|
||||
domains: None,
|
||||
min_edge_weight: Some(weight),
|
||||
time_range: None,
|
||||
edge_types: None,
|
||||
max_nodes: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Combine with another filter (AND logic)
|
||||
pub fn and(mut self, other: ExportFilter) -> Self {
|
||||
if let Some(d) = other.domains {
|
||||
self.domains = Some(d);
|
||||
}
|
||||
if let Some(w) = other.min_edge_weight {
|
||||
self.min_edge_weight = Some(w);
|
||||
}
|
||||
if let Some(t) = other.time_range {
|
||||
self.time_range = Some(t);
|
||||
}
|
||||
if let Some(e) = other.edge_types {
|
||||
self.edge_types = Some(e);
|
||||
}
|
||||
if let Some(n) = other.max_nodes {
|
||||
self.max_nodes = Some(n);
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Export graph to GraphML format (for Gephi, Cytoscape, etc.)
///
/// # Arguments
/// * `engine` - The discovery engine containing the graph
/// * `path` - Output file path
/// * `filter` - Optional filter criteria (currently unused — see note below)
///
/// # GraphML Format
/// GraphML is an XML-based format for graphs. It includes:
/// - Node attributes (domain, weight, coherence)
/// - Edge attributes (weight, type, timestamp)
/// - Full graph structure
///
/// # Errors
/// Returns `FrameworkError::Config` when the output file cannot be created;
/// write failures propagate via `?` from the `writeln!` calls.
///
/// # Examples
///
/// ```rust,ignore
/// export_graphml(&engine, "output/graph.graphml", None)?;
/// ```
pub fn export_graphml(
    engine: &OptimizedDiscoveryEngine,
    path: impl AsRef<Path>,
    _filter: Option<ExportFilter>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    // Buffer writes: the export is many small writeln! calls.
    let mut writer = BufWriter::new(file);

    // GraphML header
    writeln!(writer, r#"<?xml version="1.0" encoding="UTF-8"?>"#)?;
    writeln!(
        writer,
        r#"<graphml xmlns="http://graphml.graphdrawing.org/xmlns""#
    )?;
    writeln!(
        writer,
        r#"         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance""#
    )?;
    writeln!(
        writer,
        r#"         xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns"#
    )?;
    writeln!(
        writer,
        r#"         http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">"#
    )?;

    // Define node attributes
    writeln!(
        writer,
        r#"  <key id="domain" for="node" attr.name="domain" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="external_id" for="node" attr.name="external_id" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="weight" for="node" attr.name="weight" attr.type="double"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="timestamp" for="node" attr.name="timestamp" attr.type="string"/>"#
    )?;

    // Define edge attributes
    writeln!(
        writer,
        r#"  <key id="edge_weight" for="edge" attr.name="weight" attr.type="double"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="edge_type" for="edge" attr.name="type" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="edge_timestamp" for="edge" attr.name="timestamp" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="cross_domain" for="edge" attr.name="cross_domain" attr.type="boolean"/>"#
    )?;

    // Graph header
    writeln!(
        writer,
        r#"  <graph id="discovery" edgedefault="undirected">"#
    )?;

    // Access engine internals via public methods
    let stats = engine.stats();

    // Get nodes - we'll need to access the engine's internal state
    // Since OptimizedDiscoveryEngine doesn't expose nodes/edges directly,
    // we'll need to work with what's available through the stats
    // For now, let's document this limitation and provide a note

    // NOTE: This is a simplified implementation that shows the structure
    // In production, OptimizedDiscoveryEngine would need to expose:
    // - nodes() -> &HashMap<u32, GraphNode>
    // - edges() -> &[GraphEdge]
    // - get_node(id) -> Option<&GraphNode>
    // Until then, only summary statistics are emitted as XML comments and
    // the `_filter` argument is accepted but not applied.

    // Export nodes (example structure - requires engine API extension)
    writeln!(writer, r#"    <!-- {} nodes in graph -->"#, stats.total_nodes)?;
    writeln!(writer, r#"    <!-- {} edges in graph -->"#, stats.total_edges)?;
    writeln!(
        writer,
        r#"    <!-- Cross-domain edges: {} -->"#,
        stats.cross_domain_edges
    )?;

    // Close graph and graphml
    writeln!(writer, "  </graph>")?;
    writeln!(writer, "</graphml>")?;

    // Flush explicitly so write errors surface here instead of being
    // swallowed by Drop.
    writer.flush()?;

    Ok(())
}
|
||||
|
||||
/// Export graph to DOT format (for Graphviz)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `engine` - The discovery engine containing the graph
|
||||
/// * `path` - Output file path
|
||||
/// * `filter` - Optional filter criteria
|
||||
///
|
||||
/// # DOT Format
|
||||
/// DOT is a text-based graph description language used by Graphviz.
|
||||
/// The exported file can be rendered using:
|
||||
/// ```bash
|
||||
/// dot -Tpng graph.dot -o graph.png
|
||||
/// neato -Tsvg graph.dot -o graph.svg
|
||||
/// ```
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// export_dot(&engine, "output/graph.dot", None)?;
|
||||
/// ```
|
||||
pub fn export_dot(
|
||||
engine: &OptimizedDiscoveryEngine,
|
||||
path: impl AsRef<Path>,
|
||||
_filter: Option<ExportFilter>,
|
||||
) -> Result<()> {
|
||||
let file = File::create(path.as_ref())
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
let stats = engine.stats();
|
||||
|
||||
// DOT header
|
||||
writeln!(writer, "graph discovery {{")?;
|
||||
writeln!(writer, " layout=neato;")?;
|
||||
writeln!(writer, " overlap=false;")?;
|
||||
writeln!(writer, " splines=true;")?;
|
||||
writeln!(writer, "")?;
|
||||
|
||||
// Graph properties
|
||||
writeln!(
|
||||
writer,
|
||||
" // Graph statistics: {} nodes, {} edges",
|
||||
stats.total_nodes, stats.total_edges
|
||||
)?;
|
||||
writeln!(
|
||||
writer,
|
||||
" // Cross-domain edges: {}",
|
||||
stats.cross_domain_edges
|
||||
)?;
|
||||
writeln!(writer, "")?;
|
||||
|
||||
// Domain colors
|
||||
writeln!(writer, " // Domain colors")?;
|
||||
writeln!(
|
||||
writer,
|
||||
r#" node [style=filled, fontname="Arial", fontsize=10];"#
|
||||
)?;
|
||||
writeln!(writer, "")?;
|
||||
|
||||
// Export domain counts as comments
|
||||
for (domain, count) in &stats.domain_counts {
|
||||
let color = domain_color(*domain);
|
||||
writeln!(
|
||||
writer,
|
||||
" // {:?} domain: {} nodes [color={}]",
|
||||
domain, count, color
|
||||
)?;
|
||||
}
|
||||
writeln!(writer, "")?;
|
||||
|
||||
// NOTE: Similar to GraphML, this requires engine API extension
|
||||
// to expose nodes and edges for iteration
|
||||
|
||||
// Close graph
|
||||
writeln!(writer, "}}")?;
|
||||
|
||||
writer.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export patterns to CSV format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `patterns` - List of significant patterns to export
|
||||
/// * `path` - Output file path
|
||||
///
|
||||
/// # CSV Format
|
||||
/// The CSV file contains the following columns:
|
||||
/// - id: Pattern ID
|
||||
/// - pattern_type: Type of pattern (consolidation, coherence_break, etc.)
|
||||
/// - confidence: Confidence score (0-1)
|
||||
/// - p_value: Statistical significance p-value
|
||||
/// - effect_size: Effect size (Cohen's d)
|
||||
/// - is_significant: Boolean indicating statistical significance
|
||||
/// - detected_at: ISO 8601 timestamp
|
||||
/// - description: Human-readable description
|
||||
/// - affected_nodes_count: Number of affected nodes
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let patterns = engine.detect_patterns_with_significance();
|
||||
/// export_patterns_csv(&patterns, "output/patterns.csv")?;
|
||||
/// ```
|
||||
pub fn export_patterns_csv(
|
||||
patterns: &[SignificantPattern],
|
||||
path: impl AsRef<Path>,
|
||||
) -> Result<()> {
|
||||
let file = File::create(path.as_ref())
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
// CSV header
|
||||
writeln!(
|
||||
writer,
|
||||
"id,pattern_type,confidence,p_value,effect_size,ci_lower,ci_upper,is_significant,detected_at,description,affected_nodes_count,evidence_count"
|
||||
)?;
|
||||
|
||||
// Export each pattern
|
||||
for pattern in patterns {
|
||||
let p = &pattern.pattern;
|
||||
writeln!(
|
||||
writer,
|
||||
"{},{:?},{},{},{},{},{},{},{},\"{}\",{},{}",
|
||||
csv_escape(&p.id),
|
||||
p.pattern_type,
|
||||
p.confidence,
|
||||
pattern.p_value,
|
||||
pattern.effect_size,
|
||||
pattern.confidence_interval.0,
|
||||
pattern.confidence_interval.1,
|
||||
pattern.is_significant,
|
||||
p.detected_at.to_rfc3339(),
|
||||
csv_escape(&p.description),
|
||||
p.affected_nodes.len(),
|
||||
p.evidence.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export coherence history to CSV format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `history` - Coherence history from the discovery engine
|
||||
/// * `path` - Output file path
|
||||
///
|
||||
/// # CSV Format
|
||||
/// The CSV file contains the following columns:
|
||||
/// - timestamp: ISO 8601 timestamp
|
||||
/// - mincut_value: Minimum cut value (coherence measure)
|
||||
/// - node_count: Number of nodes in graph
|
||||
/// - edge_count: Number of edges in graph
|
||||
/// - avg_edge_weight: Average edge weight
|
||||
/// - partition_size_a: Size of partition A
|
||||
/// - partition_size_b: Size of partition B
|
||||
/// - boundary_nodes_count: Number of nodes on the cut boundary
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// export_coherence_csv(&engine.coherence_history(), "output/coherence.csv")?;
|
||||
/// ```
|
||||
pub fn export_coherence_csv(
|
||||
history: &[(DateTime<Utc>, f64, CoherenceSnapshot)],
|
||||
path: impl AsRef<Path>,
|
||||
) -> Result<()> {
|
||||
let file = File::create(path.as_ref())
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
// CSV header
|
||||
writeln!(
|
||||
writer,
|
||||
"timestamp,mincut_value,node_count,edge_count,avg_edge_weight,partition_size_a,partition_size_b,boundary_nodes_count"
|
||||
)?;
|
||||
|
||||
// Export each snapshot
|
||||
for (timestamp, mincut_value, snapshot) in history {
|
||||
writeln!(
|
||||
writer,
|
||||
"{},{},{},{},{},{},{},{}",
|
||||
timestamp.to_rfc3339(),
|
||||
mincut_value,
|
||||
snapshot.node_count,
|
||||
snapshot.edge_count,
|
||||
snapshot.avg_edge_weight,
|
||||
snapshot.partition_sizes.0,
|
||||
snapshot.partition_sizes.1,
|
||||
snapshot.boundary_nodes.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export patterns with evidence to detailed CSV
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `patterns` - List of significant patterns with evidence
|
||||
/// * `path` - Output file path
|
||||
///
|
||||
/// # CSV Format
|
||||
/// The CSV file contains one row per evidence item:
|
||||
/// - pattern_id: Pattern identifier
|
||||
/// - pattern_type: Type of pattern
|
||||
/// - evidence_type: Type of evidence
|
||||
/// - evidence_value: Numeric value
|
||||
/// - evidence_description: Human-readable description
|
||||
/// - detected_at: ISO 8601 timestamp
|
||||
///
|
||||
pub fn export_patterns_with_evidence_csv(
|
||||
patterns: &[SignificantPattern],
|
||||
path: impl AsRef<Path>,
|
||||
) -> Result<()> {
|
||||
let file = File::create(path.as_ref())
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
// CSV header
|
||||
writeln!(
|
||||
writer,
|
||||
"pattern_id,pattern_type,evidence_type,evidence_value,evidence_description,detected_at"
|
||||
)?;
|
||||
|
||||
// Export each pattern's evidence
|
||||
for pattern in patterns {
|
||||
let p = &pattern.pattern;
|
||||
for evidence in &p.evidence {
|
||||
writeln!(
|
||||
writer,
|
||||
"{},{:?},{},{},\"{}\",{}",
|
||||
csv_escape(&p.id),
|
||||
p.pattern_type,
|
||||
csv_escape(&evidence.evidence_type),
|
||||
evidence.value,
|
||||
csv_escape(&evidence.description),
|
||||
p.detected_at.to_rfc3339()
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export all data to a directory
|
||||
///
|
||||
/// Creates a directory and exports:
|
||||
/// - graph.graphml - Full graph in GraphML format
|
||||
/// - graph.dot - Full graph in DOT format
|
||||
/// - patterns.csv - All patterns
|
||||
/// - patterns_evidence.csv - Patterns with detailed evidence
|
||||
/// - coherence.csv - Coherence history over time
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `engine` - The discovery engine
|
||||
/// * `patterns` - Detected patterns
|
||||
/// * `history` - Coherence history
|
||||
/// * `output_dir` - Directory to create and write files
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// export_all(&engine, &patterns, &history, "output/discovery_results")?;
|
||||
/// ```
|
||||
pub fn export_all(
|
||||
engine: &OptimizedDiscoveryEngine,
|
||||
patterns: &[SignificantPattern],
|
||||
history: &[(DateTime<Utc>, f64, CoherenceSnapshot)],
|
||||
output_dir: impl AsRef<Path>,
|
||||
) -> Result<()> {
|
||||
let dir = output_dir.as_ref();
|
||||
|
||||
// Create directory
|
||||
std::fs::create_dir_all(dir)
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create directory: {}", e)))?;
|
||||
|
||||
// Export all formats
|
||||
export_graphml(engine, dir.join("graph.graphml"), None)?;
|
||||
export_dot(engine, dir.join("graph.dot"), None)?;
|
||||
export_patterns_csv(patterns, dir.join("patterns.csv"))?;
|
||||
export_patterns_with_evidence_csv(patterns, dir.join("patterns_evidence.csv"))?;
|
||||
export_coherence_csv(history, dir.join("coherence.csv"))?;
|
||||
|
||||
// Write README
|
||||
let readme = dir.join("README.md");
|
||||
let readme_file = File::create(readme)
|
||||
.map_err(|e| FrameworkError::Config(format!("Failed to create README: {}", e)))?;
|
||||
let mut readme_writer = BufWriter::new(readme_file);
|
||||
|
||||
writeln!(readme_writer, "# RuVector Discovery Export")?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"Exported: {}",
|
||||
Utc::now().to_rfc3339()
|
||||
)?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "## Files")?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- `graph.graphml` - Full graph in GraphML format (import into Gephi)"
|
||||
)?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- `graph.dot` - Full graph in DOT format (render with Graphviz)"
|
||||
)?;
|
||||
writeln!(readme_writer, "- `patterns.csv` - Discovered patterns")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- `patterns_evidence.csv` - Patterns with detailed evidence"
|
||||
)?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- `coherence.csv` - Coherence history over time"
|
||||
)?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "## Visualization")?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "### Gephi (GraphML)")?;
|
||||
writeln!(readme_writer, "1. Open Gephi")?;
|
||||
writeln!(readme_writer, "2. File → Open → graph.graphml")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"3. Layout → Force Atlas 2 or Fruchterman Reingold"
|
||||
)?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"4. Color nodes by 'domain' attribute"
|
||||
)?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "### Graphviz (DOT)")?;
|
||||
writeln!(readme_writer, "```bash")?;
|
||||
writeln!(readme_writer, "# PNG output")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"dot -Tpng graph.dot -o graph.png"
|
||||
)?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "# SVG output (vector, scalable)")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"neato -Tsvg graph.dot -o graph.svg"
|
||||
)?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "# Interactive SVG")?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"fdp -Tsvg graph.dot -o graph_interactive.svg"
|
||||
)?;
|
||||
writeln!(readme_writer, "```")?;
|
||||
writeln!(readme_writer, "")?;
|
||||
writeln!(readme_writer, "## Statistics")?;
|
||||
writeln!(readme_writer, "")?;
|
||||
let stats = engine.stats();
|
||||
writeln!(readme_writer, "- Nodes: {}", stats.total_nodes)?;
|
||||
writeln!(readme_writer, "- Edges: {}", stats.total_edges)?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- Cross-domain edges: {}",
|
||||
stats.cross_domain_edges
|
||||
)?;
|
||||
writeln!(readme_writer, "- Patterns detected: {}", patterns.len())?;
|
||||
writeln!(
|
||||
readme_writer,
|
||||
"- Coherence snapshots: {}",
|
||||
history.len()
|
||||
)?;
|
||||
|
||||
readme_writer.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
/// Escape a string for embedding in a CSV field (RFC 4180).
///
/// Fields containing a double quote, comma, newline, or carriage return are
/// wrapped in double quotes with embedded quotes doubled; all other strings
/// are returned unchanged.
fn csv_escape(s: &str) -> String {
    // RFC 4180 also requires quoting fields containing '\r' (CRLF line
    // breaks), which the previous version missed.
    let needs_quoting = s.chars().any(|c| matches!(c, '"' | ',' | '\n' | '\r'));
    if needs_quoting {
        format!("\"{}\"", s.replace('"', "\"\""))
    } else {
        s.to_string()
    }
}
|
||||
|
||||
/// Get color for domain (for DOT export)
///
/// Maps each `Domain` variant to a Graphviz X11 color name used as the node
/// fill color. The `match` is exhaustive with no catch-all, so adding a new
/// `Domain` variant forces a color to be chosen here at compile time.
fn domain_color(domain: Domain) -> &'static str {
    match domain {
        Domain::Climate => "lightblue",
        Domain::Finance => "lightgreen",
        Domain::Research => "lightyellow",
        Domain::Medical => "lightpink",
        Domain::Economic => "lavender",
        Domain::Genomics => "palegreen",
        Domain::Physics => "lightsteelblue",
        Domain::Seismic => "sandybrown",
        Domain::Ocean => "aquamarine",
        Domain::Space => "plum",
        Domain::Transportation => "peachpuff",
        Domain::Geospatial => "lightgoldenrodyellow",
        // Government/cross-domain get neutral/alert tones to stand out
        Domain::Government => "lightgray",
        Domain::CrossDomain => "lightcoral",
    }
}
|
||||
|
||||
/// Get node shape for domain (for DOT export)
///
/// Maps each `Domain` variant to a Graphviz polygon shape name, giving each
/// domain a distinct silhouette in rendered graphs. Exhaustive match: new
/// variants must pick a shape at compile time.
///
/// NOTE(review): "star" requires a reasonably recent Graphviz (>= 2.32);
/// confirm the minimum Graphviz version this project targets.
fn domain_shape(domain: Domain) -> &'static str {
    match domain {
        Domain::Climate => "circle",
        Domain::Finance => "box",
        Domain::Research => "diamond",
        Domain::Medical => "ellipse",
        Domain::Economic => "octagon",
        Domain::Genomics => "pentagon",
        Domain::Physics => "triangle",
        Domain::Seismic => "invtriangle",
        Domain::Ocean => "trapezium",
        Domain::Space => "star",
        Domain::Transportation => "house",
        Domain::Geospatial => "invhouse",
        Domain::Government => "folder",
        Domain::CrossDomain => "hexagon",
    }
}
|
||||
|
||||
/// Format edge type for export
///
/// Converts an `EdgeType` to its snake_case string label used in GraphML/DOT
/// attribute values. Exhaustive match keeps labels in sync with the enum.
fn edge_type_label(edge_type: EdgeType) -> &'static str {
    match edge_type {
        EdgeType::Correlation => "correlation",
        EdgeType::Similarity => "similarity",
        EdgeType::Citation => "citation",
        EdgeType::Causal => "causal",
        EdgeType::CrossDomain => "cross_domain",
    }
}
|
||||
|
||||
/// Allow `?` on I/O operations inside the export functions by funneling
/// `std::io::Error` into the framework's `Config` error variant.
impl From<std::io::Error> for FrameworkError {
    fn from(err: std::io::Error) -> Self {
        FrameworkError::Config(format!("I/O error: {}", err))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain strings pass through unchanged; commas and quotes trigger
    /// RFC 4180-style quoting with doubled embedded quotes.
    #[test]
    fn test_csv_escape() {
        let cases = [
            ("simple", "simple"),
            ("with,comma", "\"with,comma\""),
            ("with\"quote", "\"with\"\"quote\""),
        ];
        for (input, expected) in cases {
            assert_eq!(csv_escape(input), expected);
        }
    }

    /// Spot-check the domain → DOT fill-color mapping.
    #[test]
    fn test_domain_color() {
        assert_eq!(domain_color(Domain::Climate), "lightblue");
        assert_eq!(domain_color(Domain::Finance), "lightgreen");
    }

    /// Filters can be built per-domain and combined with `and`.
    #[test]
    fn test_export_filter() {
        let base = ExportFilter::domain(Domain::Climate);
        assert!(base.domains.is_some());

        let merged = base.and(ExportFilter::min_weight(0.5));
        assert_eq!(merged.min_edge_weight, Some(0.5));
    }
}
|
||||
1517
vendor/ruvector/examples/data/framework/src/finance_clients.rs
vendored
Normal file
1517
vendor/ruvector/examples/data/framework/src/finance_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
536
vendor/ruvector/examples/data/framework/src/forecasting.rs
vendored
Normal file
536
vendor/ruvector/examples/data/framework/src/forecasting.rs
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
use chrono::{DateTime, Utc, Duration};
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Trend direction for coherence values
///
/// Returned by `CoherenceForecaster::get_trend` and carried on each
/// `Forecast`. Derives `Eq`/`Hash` in addition to `PartialEq` (a fieldless
/// enum has total equality, and `Hash` lets it key maps/sets).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trend {
    /// Smoothed trend is above the stability threshold.
    Rising,
    /// Smoothed trend is below the negative stability threshold.
    Falling,
    /// Smoothed trend is within the stability threshold of zero.
    Stable,
}
|
||||
|
||||
/// Forecast result with confidence intervals and anomaly detection
///
/// Produced by `CoherenceForecaster::forecast`; one instance per forecast
/// horizon step.
#[derive(Debug, Clone)]
pub struct Forecast {
    /// When the forecast applies (last observation time + h × sampling interval)
    pub timestamp: DateTime<Utc>,
    /// Point forecast from Holt's linear trend method
    pub predicted_value: f64,
    /// Lower bound of the ~95% prediction interval (widens with horizon)
    pub confidence_low: f64,
    /// Upper bound of the ~95% prediction interval
    pub confidence_high: f64,
    /// Trend direction at the time the forecast was generated
    pub trend: Trend,
    /// Probability in [0, 1] that this value is anomalous (z-score / CUSUM based)
    pub anomaly_probability: f64,
}
|
||||
|
||||
/// Coherence forecaster using exponential smoothing methods
///
/// Implements Holt's linear (double exponential) smoothing over a bounded
/// sliding window of (timestamp, value) observations, plus two-sided CUSUM
/// statistics for regime-change detection.
pub struct CoherenceForecaster {
    /// Sliding window of observations, oldest first; capped at `window`
    history: VecDeque<(DateTime<Utc>, f64)>,
    alpha: f64, // Level smoothing parameter (clamped to [0, 1])
    beta: f64, // Trend smoothing parameter (clamped to [0, 1])
    window: usize, // Maximum history size
    /// Smoothed level; None until the first observation arrives
    level: Option<f64>,
    /// Smoothed trend; None until the first observation arrives
    trend: Option<f64>,
    cusum_pos: f64, // Positive CUSUM for regime change detection
    cusum_neg: f64, // Negative CUSUM for regime change detection
}
|
||||
|
||||
impl CoherenceForecaster {
|
||||
/// Create a new forecaster with smoothing parameters
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `alpha` - Level smoothing parameter (0.0 to 1.0). Higher = more weight on recent values
|
||||
/// * `window` - Maximum number of historical observations to keep
|
||||
pub fn new(alpha: f64, window: usize) -> Self {
|
||||
Self {
|
||||
history: VecDeque::with_capacity(window),
|
||||
alpha: alpha.clamp(0.0, 1.0),
|
||||
beta: 0.1, // Default trend smoothing
|
||||
window,
|
||||
level: None,
|
||||
trend: None,
|
||||
cusum_pos: 0.0,
|
||||
cusum_neg: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a forecaster with custom trend smoothing parameter
|
||||
pub fn with_beta(mut self, beta: f64) -> Self {
|
||||
self.beta = beta.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a new observation to the forecaster
|
||||
pub fn add_observation(&mut self, timestamp: DateTime<Utc>, value: f64) {
|
||||
// Add to history
|
||||
self.history.push_back((timestamp, value));
|
||||
if self.history.len() > self.window {
|
||||
self.history.pop_front();
|
||||
}
|
||||
|
||||
// Update smoothed level and trend (Holt's method)
|
||||
match (self.level, self.trend) {
|
||||
(None, None) => {
|
||||
// Initialize with first observation
|
||||
self.level = Some(value);
|
||||
self.trend = Some(0.0);
|
||||
}
|
||||
(Some(prev_level), Some(prev_trend)) => {
|
||||
// Update level: L_t = α * Y_t + (1 - α) * (L_{t-1} + T_{t-1})
|
||||
let new_level = self.alpha * value + (1.0 - self.alpha) * (prev_level + prev_trend);
|
||||
|
||||
// Update trend: T_t = β * (L_t - L_{t-1}) + (1 - β) * T_{t-1}
|
||||
let new_trend = self.beta * (new_level - prev_level) + (1.0 - self.beta) * prev_trend;
|
||||
|
||||
self.level = Some(new_level);
|
||||
self.trend = Some(new_trend);
|
||||
|
||||
// Update CUSUM for regime change detection
|
||||
self.update_cusum(value, prev_level);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update CUSUM statistics for regime change detection
|
||||
fn update_cusum(&mut self, value: f64, expected: f64) {
|
||||
let mean = self.get_mean();
|
||||
let std = self.get_std();
|
||||
|
||||
if std > 0.0 {
|
||||
let threshold = 0.5 * std;
|
||||
let deviation = value - mean;
|
||||
|
||||
// Positive CUSUM (detects upward shifts)
|
||||
self.cusum_pos = (self.cusum_pos + deviation - threshold).max(0.0);
|
||||
|
||||
// Negative CUSUM (detects downward shifts)
|
||||
self.cusum_neg = (self.cusum_neg - deviation - threshold).max(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate forecasts for future time steps
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `steps` - Number of future time steps to forecast
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of forecast results with confidence intervals
|
||||
pub fn forecast(&self, steps: usize) -> Vec<Forecast> {
|
||||
if self.history.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let level = self.level.unwrap_or(0.0);
|
||||
let trend = self.trend.unwrap_or(0.0);
|
||||
let std_error = self.get_prediction_error_std();
|
||||
|
||||
// Get time delta from last two observations
|
||||
let time_delta = if self.history.len() >= 2 {
|
||||
let (t1, _) = self.history[self.history.len() - 1];
|
||||
let (t0, _) = self.history[self.history.len() - 2];
|
||||
t1.signed_duration_since(t0)
|
||||
} else {
|
||||
Duration::hours(1) // Default 1 hour
|
||||
};
|
||||
|
||||
let last_timestamp = self.history.back().unwrap().0;
|
||||
let current_trend = self.get_trend();
|
||||
|
||||
let mut forecasts = Vec::with_capacity(steps);
|
||||
|
||||
for h in 1..=steps {
|
||||
// Holt's linear trend forecast: F_{t+h} = L_t + h * T_t
|
||||
let forecast_value = level + (h as f64) * trend;
|
||||
|
||||
// Prediction interval widens with horizon (sqrt(h))
|
||||
let interval_width = 1.96 * std_error * (h as f64).sqrt();
|
||||
|
||||
// Calculate anomaly probability based on deviation and CUSUM
|
||||
let anomaly_prob = self.calculate_anomaly_probability(forecast_value);
|
||||
|
||||
forecasts.push(Forecast {
|
||||
timestamp: last_timestamp + time_delta * h as i32,
|
||||
predicted_value: forecast_value,
|
||||
confidence_low: forecast_value - interval_width,
|
||||
confidence_high: forecast_value + interval_width,
|
||||
trend: current_trend,
|
||||
anomaly_probability: anomaly_prob,
|
||||
});
|
||||
}
|
||||
|
||||
forecasts
|
||||
}
|
||||
|
||||
/// Detect probability of regime change using CUSUM statistics
|
||||
///
|
||||
/// # Returns
|
||||
/// Probability between 0.0 and 1.0 that a regime change is occurring
|
||||
pub fn detect_regime_change_probability(&self) -> f64 {
|
||||
if self.history.len() < 10 {
|
||||
return 0.0; // Not enough data
|
||||
}
|
||||
|
||||
let std = self.get_std();
|
||||
if std == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// CUSUM threshold (typically 4-5 standard deviations)
|
||||
let threshold = 4.0 * std;
|
||||
|
||||
// Combine positive and negative CUSUM
|
||||
let max_cusum = self.cusum_pos.max(self.cusum_neg);
|
||||
|
||||
// Convert to probability using sigmoid
|
||||
let probability = 1.0 / (1.0 + (-0.5 * (max_cusum - threshold)).exp());
|
||||
|
||||
probability.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Get current trend direction
|
||||
pub fn get_trend(&self) -> Trend {
|
||||
let trend_value = self.trend.unwrap_or(0.0);
|
||||
let std = self.get_std();
|
||||
|
||||
// Use a fraction of std as threshold for "stable"
|
||||
let threshold = 0.1 * std;
|
||||
|
||||
if trend_value > threshold {
|
||||
Trend::Rising
|
||||
} else if trend_value < -threshold {
|
||||
Trend::Falling
|
||||
} else {
|
||||
Trend::Stable
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate mean of historical values
|
||||
fn get_mean(&self) -> f64 {
|
||||
if self.history.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let sum: f64 = self.history.iter().map(|(_, v)| v).sum();
|
||||
sum / self.history.len() as f64
|
||||
}
|
||||
|
||||
/// Calculate standard deviation of historical values
|
||||
fn get_std(&self) -> f64 {
|
||||
if self.history.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = self.get_mean();
|
||||
let variance: f64 = self.history
|
||||
.iter()
|
||||
.map(|(_, v)| (v - mean).powi(2))
|
||||
.sum::<f64>() / (self.history.len() - 1) as f64;
|
||||
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
/// Calculate standard error of predictions
|
||||
fn get_prediction_error_std(&self) -> f64 {
|
||||
if self.history.len() < 3 {
|
||||
return self.get_std();
|
||||
}
|
||||
|
||||
// Calculate residuals from one-step-ahead forecasts
|
||||
let mut errors = Vec::new();
|
||||
|
||||
for i in 2..self.history.len() {
|
||||
let (_, actual) = self.history[i];
|
||||
|
||||
// Simple exponential smoothing forecast using previous data
|
||||
let prev_values: Vec<f64> = self.history.iter()
|
||||
.take(i)
|
||||
.map(|(_, v)| *v)
|
||||
.collect();
|
||||
|
||||
if let Some(predicted) = self.simple_forecast(&prev_values, 1) {
|
||||
errors.push(actual - predicted);
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
return self.get_std();
|
||||
}
|
||||
|
||||
// Root mean squared error
|
||||
let mse: f64 = errors.iter().map(|e| e.powi(2)).sum::<f64>() / errors.len() as f64;
|
||||
mse.sqrt()
|
||||
}
|
||||
|
||||
/// Simple exponential smoothing forecast (for error calculation)
|
||||
fn simple_forecast(&self, values: &[f64], steps: usize) -> Option<f64> {
|
||||
if values.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut level = values[0];
|
||||
for &value in &values[1..] {
|
||||
level = self.alpha * value + (1.0 - self.alpha) * level;
|
||||
}
|
||||
|
||||
// For SES, forecast is constant
|
||||
Some(level)
|
||||
}
|
||||
|
||||
/// Calculate anomaly probability for a forecasted value
|
||||
fn calculate_anomaly_probability(&self, forecast_value: f64) -> f64 {
|
||||
let mean = self.get_mean();
|
||||
let std = self.get_std();
|
||||
|
||||
if std == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Z-score of the forecast
|
||||
let z_score = ((forecast_value - mean) / std).abs();
|
||||
|
||||
// Combine with regime change probability
|
||||
let regime_prob = self.detect_regime_change_probability();
|
||||
|
||||
// Anomaly if z-score > 2 (95% confidence) or regime change detected
|
||||
let z_anomaly_prob = if z_score > 2.0 {
|
||||
1.0 / (1.0 + (-(z_score - 2.0)).exp())
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Combine probabilities (max gives more sensitivity)
|
||||
z_anomaly_prob.max(regime_prob)
|
||||
}
|
||||
|
||||
/// Get the number of observations in history
|
||||
pub fn len(&self) -> usize {
|
||||
self.history.len()
|
||||
}
|
||||
|
||||
/// Check if forecaster has no observations
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.history.is_empty()
|
||||
}
|
||||
|
||||
/// Get the smoothed level value
|
||||
pub fn get_level(&self) -> Option<f64> {
|
||||
self.level
|
||||
}
|
||||
|
||||
/// Get the smoothed trend value
|
||||
pub fn get_trend_value(&self) -> Option<f64> {
|
||||
self.trend
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-domain correlation forecaster
///
/// Holds one independently configured `CoherenceForecaster` per named domain
/// and offers cross-domain correlation and synchronized regime-change queries.
pub struct CrossDomainForecaster {
    /// (domain name, forecaster) pairs; insertion order is preserved
    forecasters: Vec<(String, CoherenceForecaster)>,
}
|
||||
|
||||
impl CrossDomainForecaster {
|
||||
/// Create a new cross-domain forecaster
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
forecasters: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a domain with its own forecaster
|
||||
pub fn add_domain(&mut self, domain: String, forecaster: CoherenceForecaster) {
|
||||
self.forecasters.push((domain, forecaster));
|
||||
}
|
||||
|
||||
/// Calculate correlation between domains
|
||||
pub fn calculate_correlation(&self, domain1: &str, domain2: &str) -> Option<f64> {
|
||||
let (_, f1) = self.forecasters.iter().find(|(d, _)| d == domain1)?;
|
||||
let (_, f2) = self.forecasters.iter().find(|(d, _)| d == domain2)?;
|
||||
|
||||
if f1.is_empty() || f2.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Calculate Pearson correlation coefficient
|
||||
let min_len = f1.history.len().min(f2.history.len());
|
||||
if min_len < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let values1: Vec<f64> = f1.history.iter().rev().take(min_len).map(|(_, v)| *v).collect();
|
||||
let values2: Vec<f64> = f2.history.iter().rev().take(min_len).map(|(_, v)| *v).collect();
|
||||
|
||||
let mean1 = values1.iter().sum::<f64>() / min_len as f64;
|
||||
let mean2 = values2.iter().sum::<f64>() / min_len as f64;
|
||||
|
||||
let mut numerator = 0.0;
|
||||
let mut sum_sq1 = 0.0;
|
||||
let mut sum_sq2 = 0.0;
|
||||
|
||||
for i in 0..min_len {
|
||||
let diff1 = values1[i] - mean1;
|
||||
let diff2 = values2[i] - mean2;
|
||||
numerator += diff1 * diff2;
|
||||
sum_sq1 += diff1 * diff1;
|
||||
sum_sq2 += diff2 * diff2;
|
||||
}
|
||||
|
||||
let denominator = (sum_sq1 * sum_sq2).sqrt();
|
||||
if denominator == 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(numerator / denominator)
|
||||
}
|
||||
|
||||
/// Forecast all domains and return combined results
|
||||
pub fn forecast_all(&self, steps: usize) -> Vec<(String, Vec<Forecast>)> {
|
||||
self.forecasters
|
||||
.iter()
|
||||
.map(|(domain, forecaster)| {
|
||||
(domain.clone(), forecaster.forecast(steps))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Detect synchronized regime changes across domains
|
||||
pub fn detect_synchronized_regime_changes(&self) -> Vec<(String, f64)> {
|
||||
self.forecasters
|
||||
.iter()
|
||||
.map(|(domain, forecaster)| {
|
||||
(domain.clone(), forecaster.detect_regime_change_probability())
|
||||
})
|
||||
.filter(|(_, prob)| *prob > 0.5)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CrossDomainForecaster {
    /// Equivalent to [`CrossDomainForecaster::new`]: starts with no domains.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh forecaster holds no observations.
    #[test]
    fn test_forecaster_creation() {
        let fc = CoherenceForecaster::new(0.3, 100);
        assert!(fc.is_empty());
        assert_eq!(fc.len(), 0);
    }

    /// The first observation seeds both history and the smoothed level.
    #[test]
    fn test_add_observation() {
        let mut fc = CoherenceForecaster::new(0.3, 100);
        fc.add_observation(Utc::now(), 0.5);
        assert_eq!(fc.len(), 1);
        assert!(fc.get_level().is_some());
    }

    /// A strictly increasing series must be classified as Rising.
    #[test]
    fn test_trend_detection() {
        let mut fc = CoherenceForecaster::new(0.3, 100);
        let start = Utc::now();

        for step in 0..10 {
            fc.add_observation(start + Duration::hours(step), 0.5 + (step as f64) * 0.1);
        }

        assert_eq!(fc.get_trend(), Trend::Rising);
    }

    /// Forecasts are produced for every requested step, lie in the future,
    /// and bracket the point prediction with their confidence interval.
    #[test]
    fn test_forecast_generation() {
        let mut fc = CoherenceForecaster::new(0.3, 100);
        let start = Utc::now();

        for step in 0..10 {
            fc.add_observation(start + Duration::hours(step), 0.5 + (step as f64) * 0.05);
        }

        let forecasts = fc.forecast(5);
        assert_eq!(forecasts.len(), 5);

        for f in forecasts {
            assert!(f.timestamp > start + Duration::hours(9));
            assert!(f.confidence_low < f.predicted_value);
            assert!(f.confidence_high > f.predicted_value);
        }
    }

    /// A sudden level shift raises the regime-change probability above the
    /// stable-series baseline.
    #[test]
    fn test_regime_change_detection() {
        let mut fc = CoherenceForecaster::new(0.3, 100);
        let start = Utc::now();

        // Stable series → low probability
        for step in 0..20 {
            fc.add_observation(start + Duration::hours(step), 0.5);
        }
        let baseline = fc.detect_regime_change_probability();
        assert!(baseline < 0.3);

        // Sudden level shift → probability increases
        for step in 20..25 {
            fc.add_observation(start + Duration::hours(step), 0.9);
        }
        assert!(fc.detect_regime_change_probability() > baseline);
    }

    /// Two series differing only by a constant offset are near-perfectly
    /// correlated.
    #[test]
    fn test_cross_domain_correlation() {
        let mut cross = CrossDomainForecaster::new();
        let mut a = CoherenceForecaster::new(0.3, 100);
        let mut b = CoherenceForecaster::new(0.3, 100);
        let start = Utc::now();

        for step in 0..20 {
            let value = 0.5 + (step as f64) * 0.01;
            a.add_observation(start + Duration::hours(step), value);
            b.add_observation(start + Duration::hours(step), value + 0.1);
        }

        cross.add_domain("domain1".to_string(), a);
        cross.add_domain("domain2".to_string(), b);

        let correlation = cross.calculate_correlation("domain1", "domain2");
        assert!(correlation.is_some());

        let corr_value = correlation.unwrap();
        assert!(corr_value > 0.9, "Correlation was {}", corr_value);
    }

    /// Only the most recent `window` observations are retained.
    #[test]
    fn test_window_size_limit() {
        let mut fc = CoherenceForecaster::new(0.3, 10);
        let start = Utc::now();

        for step in 0..20 {
            fc.add_observation(start + Duration::hours(step), 0.5);
        }

        assert_eq!(fc.len(), 10);
    }
}
|
||||
1228
vendor/ruvector/examples/data/framework/src/genomics_clients.rs
vendored
Normal file
1228
vendor/ruvector/examples/data/framework/src/genomics_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1249
vendor/ruvector/examples/data/framework/src/geospatial_clients.rs
vendored
Normal file
1249
vendor/ruvector/examples/data/framework/src/geospatial_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
2354
vendor/ruvector/examples/data/framework/src/government_clients.rs
vendored
Normal file
2354
vendor/ruvector/examples/data/framework/src/government_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
757
vendor/ruvector/examples/data/framework/src/hnsw.rs
vendored
Normal file
757
vendor/ruvector/examples/data/framework/src/hnsw.rs
vendored
Normal file
@@ -0,0 +1,757 @@
|
||||
//! HNSW (Hierarchical Navigable Small World) Index
|
||||
//!
|
||||
//! Production-quality implementation of the HNSW algorithm for approximate
|
||||
//! nearest neighbor search in high-dimensional vector spaces.
|
||||
//!
|
||||
//! ## Algorithm Overview
|
||||
//!
|
||||
//! HNSW builds a multi-layer graph structure where:
|
||||
//! - Layer 0 contains all vectors
|
||||
//! - Higher layers contain progressively fewer vectors (exponentially decaying)
|
||||
//! - Each layer is a navigable small world graph with bounded degree
|
||||
//! - Search proceeds from top layer down, greedy navigating to nearest neighbors
|
||||
//!
|
||||
//! ## Performance Characteristics
|
||||
//!
|
||||
//! - **Search**: O(log n) approximate nearest neighbor queries
|
||||
//! - **Insert**: O(log n) amortized insertion time
|
||||
//! - **Space**: O(n * M) where M is max connections per layer
|
||||
//! - **Accuracy**: Configurable via ef_construction and ef_search parameters
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Malkov, Y. A., & Yashunin, D. A. (2018). "Efficient and robust approximate
|
||||
//! nearest neighbor search using Hierarchical Navigable Small World graphs"
|
||||
//! IEEE Transactions on Pattern Analysis and Machine Intelligence.
|
||||
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::{BinaryHeap, HashSet};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::ruvector_native::SemanticVector;
|
||||
use crate::FrameworkError;
|
||||
|
||||
/// HNSW index configuration.
///
/// Controls the recall/speed/memory trade-offs of the index. The parameter
/// names follow Malkov & Yashunin (2018): `m`, `ef_construction`, `ef_search`
/// and `ml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Maximum number of bi-directional links per node per layer (M).
    /// Higher values improve recall but increase memory and search time.
    /// Typical range: 8-64, default: 16.
    pub m: usize,

    /// Maximum connections for layer 0 (typically M * 2).
    pub m_max_0: usize,

    /// Size of dynamic candidate list during construction (ef_construction).
    /// Higher values improve graph quality but slow construction.
    /// Typical range: 100-500, default: 200.
    pub ef_construction: usize,

    /// Size of dynamic candidate list during search (ef_search).
    /// Higher values improve recall but slow search.
    /// Typical range: 50-200, default: 50.
    pub ef_search: usize,

    /// Layer generation probability parameter (ml).
    /// Used as a multiplier on `-ln(uniform)` when assigning a node's level,
    /// so smaller `ml` yields flatter graphs. Default: 1.0 / ln(m).
    pub ml: f64,

    /// Vector dimension — every inserted vector and every query must have
    /// exactly this length (enforced by `insert` and `search_knn`).
    pub dimension: usize,

    /// Distance metric used for all comparisons in the index.
    pub metric: DistanceMetric,
}
|
||||
|
||||
impl Default for HnswConfig {
|
||||
fn default() -> Self {
|
||||
let m = 16;
|
||||
Self {
|
||||
m,
|
||||
m_max_0: m * 2,
|
||||
ef_construction: 200,
|
||||
ef_search: 50,
|
||||
ml: 1.0 / (m as f64).ln(),
|
||||
dimension: 128,
|
||||
metric: DistanceMetric::Cosine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Distance metrics supported by HNSW.
///
/// All variants produce a value where *lower means more similar*, which is
/// what the index's heaps and sort orders assume.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine similarity (converted to angular distance).
    /// Distance = arccos(similarity) / π.
    /// Range: [0, 1] where 0 = identical direction, 1 = opposite direction.
    Cosine,

    /// Euclidean (L2) distance.
    Euclidean,

    /// Manhattan (L1) distance.
    Manhattan,
}
|
||||
|
||||
/// A node in the HNSW graph.
///
/// Nodes are addressed by their index into `HnswIndex::nodes`; that index is
/// what `connections` stores, so nodes are never removed (removal would
/// invalidate the IDs).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HnswNode {
    /// Vector data (length == `HnswConfig::dimension`).
    vector: Vec<f32>,

    /// External identifier carried over from the source `SemanticVector`.
    external_id: String,

    /// Timestamp when the vector was added.
    timestamp: DateTime<Utc>,

    /// Maximum layer this node appears in (node exists on layers 0..=level).
    level: usize,

    /// Connections per layer: connections[layer] = list of neighbor node IDs.
    /// Layer 0 can have up to m_max_0 connections, others up to m.
    connections: Vec<Vec<usize>>,
}
|
||||
|
||||
/// Search result with distance and metadata.
///
/// Returned by `search_knn` / `search_threshold`, sorted by ascending
/// `distance`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswSearchResult {
    /// Node ID in the index (stable for the lifetime of the index).
    pub node_id: usize,

    /// External identifier of the matched vector.
    pub external_id: String,

    /// Distance to query vector (lower is more similar).
    pub distance: f32,

    /// Cosine similarity score; populated only when the index uses
    /// `DistanceMetric::Cosine`, `None` otherwise.
    pub similarity: Option<f32>,

    /// Timestamp when the vector was added.
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Statistics about the HNSW index structure.
///
/// Produced by `HnswIndex::stats`; useful for diagnosing graph quality
/// (degree distribution, layer occupancy) and capacity planning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswStats {
    /// Total number of nodes.
    pub node_count: usize,

    /// Number of layers in the graph (max_layer + 1).
    pub layer_count: usize,

    /// Number of nodes present on each layer (index = layer).
    pub nodes_per_layer: Vec<usize>,

    /// Average out-degree per node on each layer.
    pub avg_connections_per_layer: Vec<f64>,

    /// Total directed edges in the graph (sum over all layers).
    pub total_edges: usize,

    /// Entry point node ID (`None` only for an empty index).
    pub entry_point: Option<usize>,

    /// Rough memory usage estimate in bytes (vector data + fixed per-node
    /// overhead + connection lists; not an exact accounting).
    pub estimated_memory_bytes: usize,
}
|
||||
|
||||
/// HNSW index for approximate nearest neighbor search.
///
/// Thread-safe implementation using Arc<RwLock<>> for concurrent reads.
/// NOTE(review): only the RNG is behind a lock here — `insert` takes
/// `&mut self`, so external synchronization is required for concurrent
/// writers; confirm intended usage with callers.
pub struct HnswIndex {
    /// Configuration (immutable after construction).
    config: HnswConfig,

    /// All nodes in the index; a node's position in this Vec is its node ID.
    nodes: Vec<HnswNode>,

    /// Entry point for search (node with the highest layer), `None` while empty.
    entry_point: Option<usize>,

    /// Maximum layer currently in use.
    max_layer: usize,

    /// Random number generator for layer assignment (locked so `random_level`
    /// can run from `&self`).
    rng: Arc<RwLock<rand::rngs::StdRng>>,
}
|
||||
|
||||
impl HnswIndex {
    /// Create a new HNSW index with default configuration.
    pub fn new() -> Self {
        Self::with_config(HnswConfig::default())
    }

    /// Create a new HNSW index with custom configuration.
    pub fn with_config(config: HnswConfig) -> Self {
        use rand::SeedableRng;
        Self {
            config,
            nodes: Vec::new(),
            entry_point: None,
            max_layer: 0,
            // Entropy-seeded RNG: layer assignment is non-deterministic
            // across runs.
            rng: Arc::new(RwLock::new(rand::rngs::StdRng::from_entropy())),
        }
    }

    /// Insert a vector into the index.
    ///
    /// ## Arguments
    ///
    /// - `vector`: The SemanticVector to insert (consumed; its embedding and
    ///   id become the node's data).
    ///
    /// ## Returns
    ///
    /// The assigned node ID.
    ///
    /// ## Errors
    ///
    /// `FrameworkError::Config` if the embedding length does not match
    /// `config.dimension`.
    pub fn insert(&mut self, vector: SemanticVector) -> Result<usize, FrameworkError> {
        if vector.embedding.len() != self.config.dimension {
            return Err(FrameworkError::Config(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                vector.embedding.len()
            )));
        }

        // The new node's ID is its position in `nodes` (assigned before push).
        let node_id = self.nodes.len();
        let level = self.random_level();

        // Create new node with one (initially empty) adjacency list per layer.
        let mut new_node = HnswNode {
            vector: vector.embedding,
            external_id: vector.id,
            timestamp: vector.timestamp,
            level,
            connections: vec![Vec::new(); level + 1],
        };

        // Insert into graph
        if self.entry_point.is_none() {
            // First node - becomes entry point; no edges to build.
            self.nodes.push(new_node);
            self.entry_point = Some(node_id);
            self.max_layer = level;
            return Ok(node_id);
        }

        // Search for nearest neighbors at insertion point.
        let entry_point = self.entry_point.unwrap();
        let mut current_nearest = vec![entry_point];

        // Phase 1: greedy descent from the top layer down to level+1,
        // keeping only the single nearest node per layer (ef = 1).
        for lc in (level + 1..=self.max_layer).rev() {
            current_nearest = self.search_layer(&new_node.vector, &current_nearest, 1, lc);
        }

        // Phase 2: from the node's own level down to 0, gather candidates
        // with ef_construction and wire up the new node's outgoing edges.
        for lc in (0..=level).rev() {
            let candidates = self.search_layer(&new_node.vector, &current_nearest, self.config.ef_construction, lc);

            // Select M neighbors (layer 0 allows the doubled limit).
            let m = if lc == 0 { self.config.m_max_0 } else { self.config.m };
            let neighbors = self.select_neighbors(&new_node.vector, candidates, m);

            // Add bidirectional links — forward direction first; the reverse
            // direction is added after the node is pushed (below).
            for &neighbor_id in &neighbors {
                // Add link from new node to neighbor
                new_node.connections[lc].push(neighbor_id);
            }

            // Seed the next (lower) layer's search with this layer's neighbors.
            current_nearest = neighbors.clone();
        }

        self.nodes.push(new_node);

        // Phase 3: add reverse links and prune neighbors that now exceed
        // their per-layer degree cap.
        for lc in 0..=level {
            let neighbors: Vec<usize> = self.nodes[node_id].connections[lc].clone();
            for neighbor_id in neighbors {
                // Only add reverse link if neighbor has this layer
                if lc < self.nodes[neighbor_id].connections.len() {
                    self.nodes[neighbor_id].connections[lc].push(node_id);

                    // Prune if exceeded max connections — re-select the best
                    // m_max neighbors from the enlarged list (the new node may
                    // itself be pruned away here, which is standard HNSW).
                    let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m };
                    if self.nodes[neighbor_id].connections[lc].len() > m_max {
                        let neighbor_vec = self.nodes[neighbor_id].vector.clone();
                        let candidates = self.nodes[neighbor_id].connections[lc].clone();
                        let pruned = self.select_neighbors(&neighbor_vec, candidates, m_max);
                        self.nodes[neighbor_id].connections[lc] = pruned;
                    }
                }
            }
        }

        // Update entry point if new node is at higher layer.
        if level > self.max_layer {
            self.max_layer = level;
            self.entry_point = Some(node_id);
        }

        Ok(node_id)
    }

    /// Insert a batch of vectors.
    ///
    /// Currently a sequential loop over `insert`; stops at the first error
    /// (vectors after a failing one are not inserted).
    pub fn insert_batch(&mut self, vectors: Vec<SemanticVector>) -> Result<Vec<usize>, FrameworkError> {
        let mut ids = Vec::with_capacity(vectors.len());
        for vector in vectors {
            ids.push(self.insert(vector)?);
        }
        Ok(ids)
    }

    /// Search for k nearest neighbors.
    ///
    /// ## Arguments
    ///
    /// - `query`: Query vector (must match index dimension)
    /// - `k`: Number of neighbors to return
    ///
    /// ## Returns
    ///
    /// Up to k nearest neighbors, sorted by distance (ascending). An empty
    /// index yields an empty Vec.
    ///
    /// ## Errors
    ///
    /// `FrameworkError::Config` on query dimension mismatch.
    pub fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<HnswSearchResult>, FrameworkError> {
        if query.len() != self.config.dimension {
            return Err(FrameworkError::Config(format!(
                "Query dimension mismatch: expected {}, got {}",
                self.config.dimension,
                query.len()
            )));
        }

        if self.entry_point.is_none() {
            return Ok(Vec::new());
        }

        let entry_point = self.entry_point.unwrap();
        let mut current_nearest = vec![entry_point];

        // Greedy descent from top layer down to layer 1 (ef = 1 per layer).
        for lc in (1..=self.max_layer).rev() {
            current_nearest = self.search_layer(query, &current_nearest, 1, lc);
        }

        // Final beam search on layer 0; ef must be at least k to be able to
        // return k results.
        let ef = self.config.ef_search.max(k);
        let candidates = self.search_layer(query, &current_nearest, ef, 0);

        // Convert node IDs to results (candidates are already distance-sorted).
        let results: Vec<HnswSearchResult> = candidates
            .iter()
            .take(k)
            .map(|&node_id| {
                let node = &self.nodes[node_id];
                let distance = self.distance(query, &node.vector);
                let similarity = if self.config.metric == DistanceMetric::Cosine {
                    Some(self.cosine_similarity(query, &node.vector))
                } else {
                    None
                };

                HnswSearchResult {
                    node_id,
                    external_id: node.external_id.clone(),
                    distance,
                    similarity,
                    timestamp: node.timestamp,
                }
            })
            .collect();

        Ok(results)
    }

    /// Search for all neighbors within a distance threshold.
    ///
    /// ## Arguments
    ///
    /// - `query`: Query vector
    /// - `threshold`: Maximum distance (exclusive)
    /// - `max_results`: Maximum number of results to return (None for unlimited)
    ///
    /// ## Returns
    ///
    /// All found neighbors within threshold, sorted by distance.
    /// NOTE(review): implemented as a kNN search with k = max(max_results, 100)
    /// followed by filtering, so with `max_results = None` it can return at
    /// most 1000 hits — "unlimited" is approximate; confirm acceptable.
    pub fn search_threshold(
        &self,
        query: &[f32],
        threshold: f32,
        max_results: Option<usize>,
    ) -> Result<Vec<HnswSearchResult>, FrameworkError> {
        // Search with large k first
        let k = max_results.unwrap_or(1000).max(100);
        let mut results = self.search_knn(query, k)?;

        // Filter by threshold (strict: distance must be < threshold)
        results.retain(|r| r.distance < threshold);

        // Limit results
        if let Some(max) = max_results {
            results.truncate(max);
        }

        Ok(results)
    }

    /// Get statistics about the index structure (O(total edges)).
    pub fn stats(&self) -> HnswStats {
        let node_count = self.nodes.len();
        let layer_count = self.max_layer + 1;

        let mut nodes_per_layer = vec![0; layer_count];
        let mut connections_per_layer = vec![0; layer_count];

        // Tally layer occupancy and out-degree in one pass over all nodes.
        for node in &self.nodes {
            for layer in 0..=node.level {
                nodes_per_layer[layer] += 1;
                connections_per_layer[layer] += node.connections[layer].len();
            }
        }

        let avg_connections_per_layer: Vec<f64> = connections_per_layer
            .iter()
            .zip(&nodes_per_layer)
            .map(|(conn, nodes)| {
                if *nodes > 0 {
                    *conn as f64 / *nodes as f64
                } else {
                    0.0
                }
            })
            .collect();

        let total_edges: usize = connections_per_layer.iter().sum();

        // Estimate memory: each node stores vector + metadata + connections.
        // This is a coarse upper-bound-style estimate, not exact accounting.
        let estimated_memory_bytes = node_count
            * (self.config.dimension * 4 // vector (f32)
                + 100 // metadata overhead
                + self.config.m * 8 * layer_count); // connections (usize)

        HnswStats {
            node_count,
            layer_count,
            nodes_per_layer,
            avg_connections_per_layer,
            total_edges,
            entry_point: self.entry_point,
            estimated_memory_bytes,
        }
    }

    // ===== Private helper methods =====

    /// Beam search within a single layer.
    ///
    /// `candidates` is a min-heap (via `Reverse`) of nodes to expand;
    /// `nearest` is a max-heap capped at `ef` holding the current best set.
    /// Returns node IDs sorted by ascending distance to `query`.
    ///
    /// NOTE(review): the early `break` fires as soon as the closest unexpanded
    /// candidate is farther than the current worst of `nearest`, even when
    /// `nearest` holds fewer than `ef` entries — the reference algorithm only
    /// terminates once the result set is full, so recall may be slightly
    /// lower here; confirm whether intentional.
    fn search_layer(&self, query: &[f32], entry_points: &[usize], ef: usize, layer: usize) -> Vec<usize> {
        let mut visited = HashSet::new();
        let mut candidates = BinaryHeap::new();
        let mut nearest = BinaryHeap::new();

        // Seed both heaps with the entry points.
        for &ep in entry_points {
            let dist = self.distance(query, &self.nodes[ep].vector);
            candidates.push((Reverse(OrderedFloat(dist)), ep));
            nearest.push((OrderedFloat(dist), ep));
            visited.insert(ep);
        }

        while let Some((Reverse(OrderedFloat(dist)), current)) = candidates.pop() {
            // Check if we should continue searching: stop once the nearest
            // remaining candidate cannot improve the result set.
            if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
                if dist > max_dist {
                    break;
                }
            }

            // Explore neighbors (guard: node must actually exist on `layer`).
            if current < self.nodes.len() && layer <= self.nodes[current].level {
                for &neighbor in &self.nodes[current].connections[layer] {
                    // `insert` returns true only for unvisited nodes.
                    if visited.insert(neighbor) {
                        let neighbor_dist = self.distance(query, &self.nodes[neighbor].vector);

                        if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
                            // Accept if it beats the current worst, or the
                            // result set is not yet full.
                            if neighbor_dist < max_dist || nearest.len() < ef {
                                candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
                                nearest.push((OrderedFloat(neighbor_dist), neighbor));

                                // Evict the worst to keep at most ef results.
                                if nearest.len() > ef {
                                    nearest.pop();
                                }
                            }
                        } else {
                            // Result set empty (defensive; cannot happen after
                            // seeding unless entry_points was empty).
                            candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
                            nearest.push((OrderedFloat(neighbor_dist), neighbor));
                        }
                    }
                }
            }
        }

        // Extract node IDs sorted by distance (ascending).
        let mut sorted_nearest: Vec<_> = nearest.into_iter().collect();
        sorted_nearest.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        sorted_nearest.into_iter().map(|(_, id)| id).collect()
    }

    /// Select M neighbors from candidates.
    ///
    /// Uses the simple "keep the M nearest" rule rather than the paper's
    /// diversity-aware heuristic (Algorithm 4).
    fn select_neighbors(&self, base: &[f32], candidates: Vec<usize>, m: usize) -> Vec<usize> {
        if candidates.len() <= m {
            return candidates;
        }

        // Simple heuristic: keep nearest M by distance
        let mut with_distances: Vec<_> = candidates
            .into_iter()
            .map(|id| {
                let dist = self.distance(base, &self.nodes[id].vector);
                (OrderedFloat(dist), id)
            })
            .collect();

        with_distances.sort_by_key(|(dist, _)| *dist);
        with_distances.into_iter().take(m).map(|(_, id)| id).collect()
    }

    /// Compute distance between two vectors under the configured metric.
    /// Lower is always "more similar".
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.config.metric {
            DistanceMetric::Cosine => {
                let similarity = self.cosine_similarity(a, b);
                // Convert to angular distance: arccos(sim) / π ∈ [0, 1]
                // (clamped first so float noise cannot push acos out of domain).
                similarity.max(-1.0).min(1.0).acos() / std::f32::consts::PI
            }
            DistanceMetric::Euclidean => {
                a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f32>()
                    .sqrt()
            }
            DistanceMetric::Manhattan => {
                a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).abs())
                    .sum()
            }
        }
    }

    /// Compute cosine similarity between two vectors, clamped to [-1, 1].
    /// Returns 0.0 if either vector has zero norm (treated as orthogonal).
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        (dot / (norm_a * norm_b)).max(-1.0).min(1.0)
    }

    /// Randomly assign a layer to a new node: floor(-ln(U) * ml), giving an
    /// exponentially decaying layer distribution.
    /// NOTE(review): the level is not capped, so an extreme draw can create
    /// a very tall (mostly empty) layer stack — harmless but worth confirming.
    fn random_level(&self) -> usize {
        let mut rng = self.rng.write().unwrap();
        let uniform: f64 = rng.gen();
        (-uniform.ln() * self.config.ml).floor() as usize
    }

    /// Get the underlying vector for a node (None if the ID is out of range).
    pub fn get_vector(&self, node_id: usize) -> Option<&Vec<f32>> {
        self.nodes.get(node_id).map(|n| &n.vector)
    }

    /// Get the external ID for a node (None if the ID is out of range).
    pub fn get_external_id(&self, node_id: usize) -> Option<&str> {
        self.nodes.get(node_id).map(|n| n.external_id.as_str())
    }

    /// Get total number of nodes in the index.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Check if index is empty.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}
|
||||
|
||||
impl Default for HnswIndex {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for f32 that implements Ord for use in BinaryHeap
/// (the search heaps need a total order, which bare `f32` lacks).
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct OrderedFloat(f32);
|
||||
|
||||
// Marker impl so OrderedFloat can be used where `Eq` is required.
// NOTE(review): the derived `PartialEq` on f32 has NaN != NaN, so `Eq`'s
// reflexivity is violated for NaN payloads — acceptable only if distances
// are never NaN; confirm against `distance()`'s inputs.
impl Eq for OrderedFloat {}
|
||||
|
||||
impl Ord for OrderedFloat {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.0.partial_cmp(&other.0).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::ruvector_native::Domain;

    /// Build a minimal `SemanticVector` for tests; the domain/metadata values
    /// are arbitrary since the index only uses id, embedding and timestamp.
    fn create_test_vector(id: &str, embedding: Vec<f32>) -> SemanticVector {
        SemanticVector {
            id: id.to_string(),
            embedding,
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Inserted vectors are retrievable and ranked by similarity to the query.
    #[test]
    fn test_hnsw_basic_insert_search() {
        let config = HnswConfig {
            dimension: 3,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        // Insert vectors: v1 and v3 point in nearly the same direction,
        // v2 is orthogonal to both.
        let v1 = create_test_vector("v1", vec![1.0, 0.0, 0.0]);
        let v2 = create_test_vector("v2", vec![0.0, 1.0, 0.0]);
        let v3 = create_test_vector("v3", vec![0.9, 0.1, 0.0]);

        index.insert(v1).unwrap();
        index.insert(v2).unwrap();
        index.insert(v3).unwrap();

        assert_eq!(index.len(), 3);

        // Search for nearest to v1
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search_knn(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].external_id, "v1"); // Exact match
        assert_eq!(results[1].external_id, "v3"); // Close match
    }

    /// `insert_batch` assigns one ID per input vector.
    #[test]
    fn test_hnsw_batch_insert() {
        let config = HnswConfig {
            dimension: 2,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        let vectors = vec![
            create_test_vector("v1", vec![1.0, 0.0]),
            create_test_vector("v2", vec![0.0, 1.0]),
            create_test_vector("v3", vec![1.0, 1.0]),
        ];

        let ids = index.insert_batch(vectors).unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(index.len(), 3);
    }

    /// `search_threshold` returns only hits strictly under the cutoff.
    #[test]
    fn test_hnsw_threshold_search() {
        let config = HnswConfig {
            dimension: 2,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        // Insert vectors at different angular distances from the query axis.
        index.insert(create_test_vector("close", vec![1.0, 0.1])).unwrap();
        index.insert(create_test_vector("medium", vec![0.7, 0.7])).unwrap();
        index.insert(create_test_vector("far", vec![0.0, 1.0])).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search_threshold(&query, 0.3, None).unwrap();

        // Should find only close vectors
        assert!(results.len() >= 1);
        assert!(results.iter().all(|r| r.distance < 0.3));
    }

    /// Cosine metric ranks identical < orthogonal < opposite directions.
    #[test]
    fn test_hnsw_cosine_similarity() {
        let config = HnswConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        let v1 = create_test_vector("identical", vec![1.0, 0.0, 0.0]);
        let v2 = create_test_vector("orthogonal", vec![0.0, 1.0, 0.0]);
        let v3 = create_test_vector("opposite", vec![-1.0, 0.0, 0.0]);

        index.insert(v1).unwrap();
        index.insert(v2).unwrap();
        index.insert(v3).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = index.search_knn(&query, 3).unwrap();

        // Identical should be closest (angular distance ~0)
        assert_eq!(results[0].external_id, "identical");
        assert!(results[0].distance < 0.01);

        // Opposite should be farthest (angular distance ~1)
        assert_eq!(results[2].external_id, "opposite");
    }

    /// `stats` reflects node count, layer occupancy and edge totals.
    #[test]
    fn test_hnsw_stats() {
        let config = HnswConfig {
            dimension: 2,
            m: 4,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        for i in 0..10 {
            let vec = create_test_vector(&format!("v{}", i), vec![i as f32, i as f32]);
            index.insert(vec).unwrap();
        }

        let stats = index.stats();
        assert_eq!(stats.node_count, 10);
        assert!(stats.layer_count > 0);
        assert_eq!(stats.nodes_per_layer[0], 10); // All nodes in layer 0
        assert!(stats.total_edges > 0);
    }

    /// Inserting a vector of the wrong dimension is rejected with an error.
    #[test]
    fn test_dimension_mismatch() {
        let config = HnswConfig {
            dimension: 3,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);

        let bad_vector = create_test_vector("bad", vec![1.0, 2.0]); // Wrong dimension
        let result = index.insert(bad_vector);
        assert!(result.is_err());
    }

    /// Searching an empty index succeeds and returns no results.
    #[test]
    fn test_empty_index_search() {
        let index = HnswIndex::new();
        let query = vec![1.0; 128];
        let results = index.search_knn(&query, 5).unwrap();
        assert!(results.is_empty());
    }
}
|
||||
342
vendor/ruvector/examples/data/framework/src/ingester.rs
vendored
Normal file
342
vendor/ruvector/examples/data/framework/src/ingester.rs
vendored
Normal file
@@ -0,0 +1,342 @@
|
||||
//! Data ingestion pipeline for streaming data into RuVector
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{DataRecord, DataSource, FrameworkError, Result};
|
||||
|
||||
/// Configuration for data ingestion.
///
/// Shared by `DataIngester::ingest_all` and `DataIngester::stream_records`
/// (though `stream_records` currently applies only `batch_size` and
/// `deduplicate` — see the impl).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionConfig {
    /// Batch size requested from the source per fetch.
    pub batch_size: usize,

    /// Maximum concurrent fetches.
    /// NOTE(review): not referenced by the visible ingester code — confirm
    /// whether it is consumed elsewhere or is aspirational.
    pub max_concurrent: usize,

    /// Retry count on failure (retries in addition to the first attempt).
    pub retry_count: u32,

    /// Base delay between retries (ms); doubled on each successive retry.
    pub retry_delay_ms: u64,

    /// Enable deduplication by record ID across the ingester's lifetime.
    pub deduplicate: bool,

    /// Rate limit (requests per second, 0 = unlimited).
    pub rate_limit: u32,
}
|
||||
|
||||
impl Default for IngestionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
batch_size: 1000,
|
||||
max_concurrent: 4,
|
||||
retry_count: 3,
|
||||
retry_delay_ms: 1000,
|
||||
deduplicate: true,
|
||||
rate_limit: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a specific data source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
    /// Source identifier.
    pub source_id: String,

    /// API base URL.
    pub base_url: String,

    /// API key (if required by the source).
    pub api_key: Option<String>,

    /// Additional HTTP headers to send with each request.
    pub headers: HashMap<String, String>,

    /// Custom source-specific parameters.
    pub params: HashMap<String, String>,
}
|
||||
|
||||
/// Statistics for ingestion process.
///
/// Snapshot type returned by `DataIngester::stats`; counters accumulate
/// across all ingestion calls until `reset_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IngestionStats {
    /// Total records fetched (after deduplication).
    pub records_fetched: u64,

    /// Batches processed.
    pub batches_processed: u64,

    /// Retries performed.
    pub retries: u64,

    /// Errors encountered (after retries were exhausted).
    pub errors: u64,

    /// Duplicates skipped by the deduplication filter.
    pub duplicates_skipped: u64,

    /// Bytes downloaded.
    /// NOTE(review): never incremented in the visible code — confirm whether
    /// a source-level hook is supposed to feed this.
    pub bytes_downloaded: u64,

    /// Average batch fetch time (ms).
    /// NOTE(review): also not updated in the visible code — confirm.
    pub avg_batch_time_ms: f64,
}
|
||||
|
||||
/// Data ingestion pipeline.
///
/// Wraps a `DataSource` with retry, rate limiting, deduplication and
/// statistics. The interior `Arc<RwLock<..>>` state lets streaming tasks
/// spawned by `stream_records` share stats/dedup state with the ingester.
pub struct DataIngester {
    // Ingestion behavior knobs (immutable after construction).
    config: IngestionConfig,
    // Shared counters; cloned into spawned streaming tasks.
    stats: Arc<std::sync::RwLock<IngestionStats>>,
    // Record IDs already seen, for cross-batch deduplication.
    seen_ids: Arc<std::sync::RwLock<std::collections::HashSet<String>>>,
}
|
||||
|
||||
impl DataIngester {
|
||||
/// Create a new data ingester
|
||||
pub fn new(config: IngestionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
stats: Arc::new(std::sync::RwLock::new(IngestionStats::default())),
|
||||
seen_ids: Arc::new(std::sync::RwLock::new(std::collections::HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Ingest all data from a source
|
||||
pub async fn ingest_all<S: DataSource>(&self, source: &S) -> Result<Vec<DataRecord>> {
|
||||
let mut all_records = Vec::new();
|
||||
let mut cursor: Option<String> = None;
|
||||
|
||||
loop {
|
||||
let (batch, next_cursor) = self
|
||||
.fetch_with_retry(source, cursor.clone(), self.config.batch_size)
|
||||
.await?;
|
||||
|
||||
if batch.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Deduplicate if enabled
|
||||
let records = if self.config.deduplicate {
|
||||
self.deduplicate_batch(batch)
|
||||
} else {
|
||||
batch
|
||||
};
|
||||
|
||||
all_records.extend(records);
|
||||
|
||||
{
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.batches_processed += 1;
|
||||
}
|
||||
|
||||
cursor = next_cursor;
|
||||
if cursor.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if self.config.rate_limit > 0 {
|
||||
let delay = 1000 / self.config.rate_limit as u64;
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_records)
|
||||
}
|
||||
|
||||
/// Stream records with backpressure
|
||||
pub async fn stream_records<S: DataSource + 'static>(
|
||||
&self,
|
||||
source: Arc<S>,
|
||||
buffer_size: usize,
|
||||
) -> Result<mpsc::Receiver<DataRecord>> {
|
||||
let (tx, rx) = mpsc::channel(buffer_size);
|
||||
let config = self.config.clone();
|
||||
let stats = self.stats.clone();
|
||||
let seen_ids = self.seen_ids.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut cursor: Option<String> = None;
|
||||
|
||||
loop {
|
||||
match source
|
||||
.fetch_batch(cursor.clone(), config.batch_size)
|
||||
.await
|
||||
{
|
||||
Ok((batch, next_cursor)) => {
|
||||
if batch.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
for record in batch {
|
||||
// Deduplicate
|
||||
if config.deduplicate {
|
||||
let mut ids = seen_ids.write().unwrap();
|
||||
if ids.contains(&record.id) {
|
||||
continue;
|
||||
}
|
||||
ids.insert(record.id.clone());
|
||||
}
|
||||
|
||||
if tx.send(record).await.is_err() {
|
||||
return; // Receiver dropped
|
||||
}
|
||||
|
||||
let mut s = stats.write().unwrap();
|
||||
s.records_fetched += 1;
|
||||
}
|
||||
|
||||
cursor = next_cursor;
|
||||
if cursor.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let mut s = stats.write().unwrap();
|
||||
s.errors += 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
/// Fetch a batch with retry logic
|
||||
async fn fetch_with_retry<S: DataSource>(
|
||||
&self,
|
||||
source: &S,
|
||||
cursor: Option<String>,
|
||||
batch_size: usize,
|
||||
) -> Result<(Vec<DataRecord>, Option<String>)> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..=self.config.retry_count {
|
||||
if attempt > 0 {
|
||||
let delay = self.config.retry_delay_ms * (1 << (attempt - 1));
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
|
||||
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.retries += 1;
|
||||
}
|
||||
|
||||
match source.fetch_batch(cursor.clone(), batch_size).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.errors += 1;
|
||||
|
||||
Err(last_error.unwrap_or_else(|| FrameworkError::Ingestion("Unknown error".to_string())))
|
||||
}
|
||||
|
||||
/// Deduplicate a batch of records
|
||||
fn deduplicate_batch(&self, batch: Vec<DataRecord>) -> Vec<DataRecord> {
|
||||
let mut unique = Vec::with_capacity(batch.len());
|
||||
let mut seen = self.seen_ids.write().unwrap();
|
||||
|
||||
for record in batch {
|
||||
if !seen.contains(&record.id) {
|
||||
seen.insert(record.id.clone());
|
||||
unique.push(record);
|
||||
} else {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.duplicates_skipped += 1;
|
||||
}
|
||||
}
|
||||
|
||||
unique
|
||||
}
|
||||
|
||||
/// Get current ingestion statistics
|
||||
pub fn stats(&self) -> IngestionStats {
|
||||
self.stats.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Reset statistics
|
||||
pub fn reset_stats(&self) {
|
||||
*self.stats.write().unwrap() = IngestionStats::default();
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for transforming records during ingestion
|
||||
#[async_trait]
|
||||
pub trait RecordTransformer: Send + Sync {
|
||||
/// Transform a record
|
||||
async fn transform(&self, record: DataRecord) -> Result<DataRecord>;
|
||||
|
||||
/// Filter records (return false to skip)
|
||||
fn filter(&self, record: &DataRecord) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Identity transformer (no-op): passes every record through unchanged and
/// inherits the default `filter` (accept everything).
pub struct IdentityTransformer;

#[async_trait]
impl RecordTransformer for IdentityTransformer {
    /// Returns the record unmodified.
    async fn transform(&self, record: DataRecord) -> Result<DataRecord> {
        Ok(record)
    }
}
|
||||
|
||||
/// Batched ingestion with transformations.
///
/// Wraps a [`DataIngester`] and runs a [`RecordTransformer`] over everything
/// it produces (filter first, then transform).
pub struct BatchIngester<T: RecordTransformer> {
    // Underlying ingester that performs the raw fetch/dedup work.
    ingester: DataIngester,
    // Transformer applied to each surviving record.
    transformer: T,
}
|
||||
|
||||
impl<T: RecordTransformer> BatchIngester<T> {
|
||||
/// Create a new batch ingester with transformer
|
||||
pub fn new(config: IngestionConfig, transformer: T) -> Self {
|
||||
Self {
|
||||
ingester: DataIngester::new(config),
|
||||
transformer,
|
||||
}
|
||||
}
|
||||
|
||||
/// Ingest and transform all records
|
||||
pub async fn ingest_all<S: DataSource>(&self, source: &S) -> Result<Vec<DataRecord>> {
|
||||
let raw_records = self.ingester.ingest_all(source).await?;
|
||||
|
||||
let mut transformed = Vec::with_capacity(raw_records.len());
|
||||
for record in raw_records {
|
||||
if self.transformer.filter(&record) {
|
||||
let t = self.transformer.transform(record).await?;
|
||||
transformed.push(t);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(transformed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check the documented defaults: batch_size = 1000, dedup enabled.
    #[test]
    fn test_default_config() {
        let config = IngestionConfig::default();
        assert_eq!(config.batch_size, 1000);
        assert!(config.deduplicate);
    }

    // A freshly constructed ingester must start with zeroed statistics.
    #[test]
    fn test_ingester_creation() {
        let config = IngestionConfig::default();
        let ingester = DataIngester::new(config);
        let stats = ingester.stats();
        assert_eq!(stats.records_fetched, 0);
    }
}
|
||||
470
vendor/ruvector/examples/data/framework/src/lib.rs
vendored
Normal file
470
vendor/ruvector/examples/data/framework/src/lib.rs
vendored
Normal file
@@ -0,0 +1,470 @@
|
||||
//! # RuVector Data Discovery Framework
|
||||
//!
|
||||
//! Core traits and types for building dataset integrations with RuVector's
|
||||
//! vector memory, graph structures, and dynamic minimum cut algorithms.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The framework provides three core abstractions:
|
||||
//!
|
||||
//! 1. **DataIngester**: Streaming data ingestion with batched graph/vector updates
|
||||
//! 2. **CoherenceEngine**: Real-time coherence signal computation using min-cut
|
||||
//! 3. **DiscoveryEngine**: Pattern detection for emerging structures and anomalies
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::{
|
||||
//! DataIngester, CoherenceEngine, DiscoveryEngine,
|
||||
//! IngestionConfig, CoherenceConfig, DiscoveryConfig,
|
||||
//! };
|
||||
//!
|
||||
//! // Configure the discovery pipeline
|
||||
//! let ingester = DataIngester::new(ingestion_config);
|
||||
//! let coherence = CoherenceEngine::new(coherence_config);
|
||||
//! let discovery = DiscoveryEngine::new(discovery_config);
|
||||
//!
|
||||
//! // Stream data and detect patterns
|
||||
//! let stream = ingester.stream_from_source(source).await?;
|
||||
//! let signals = coherence.compute_signals(stream).await?;
|
||||
//! let patterns = discovery.detect_patterns(signals).await?;
|
||||
//! ```
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
|
||||
pub mod academic_clients;
|
||||
pub mod api_clients;
|
||||
pub mod arxiv_client;
|
||||
pub mod biorxiv_client;
|
||||
pub mod coherence;
|
||||
pub mod crossref_client;
|
||||
pub mod discovery;
|
||||
pub mod dynamic_mincut;
|
||||
pub mod economic_clients;
|
||||
pub mod export;
|
||||
pub mod finance_clients;
|
||||
pub mod forecasting;
|
||||
pub mod genomics_clients;
|
||||
pub mod geospatial_clients;
|
||||
pub mod government_clients;
|
||||
pub mod hnsw;
|
||||
pub mod cut_aware_hnsw;
|
||||
pub mod ingester;
|
||||
pub mod mcp_server;
|
||||
pub mod medical_clients;
|
||||
pub mod ml_clients;
|
||||
pub mod news_clients;
|
||||
pub mod optimized;
|
||||
pub mod patent_clients;
|
||||
pub mod persistence;
|
||||
pub mod physics_clients;
|
||||
pub mod realtime;
|
||||
pub mod ruvector_native;
|
||||
pub mod semantic_scholar;
|
||||
pub mod space_clients;
|
||||
pub mod streaming;
|
||||
pub mod transportation_clients;
|
||||
pub mod utils;
|
||||
pub mod visualization;
|
||||
pub mod wiki_clients;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
// Re-exports
|
||||
pub use academic_clients::{CoreClient, EricClient, UnpaywallClient};
|
||||
pub use api_clients::{EdgarClient, Embedder, NoaaClient, OpenAlexClient, SimpleEmbedder};
|
||||
#[cfg(feature = "onnx-embeddings")]
|
||||
pub use api_clients::OnnxEmbedder;
|
||||
#[cfg(feature = "onnx-embeddings")]
|
||||
pub use ruvector_onnx_embeddings::{PretrainedModel, EmbedderConfig, PoolingStrategy};
|
||||
pub use arxiv_client::ArxivClient;
|
||||
pub use biorxiv_client::{BiorxivClient, MedrxivClient};
|
||||
pub use crossref_client::CrossRefClient;
|
||||
pub use economic_clients::{AlphaVantageClient, FredClient, WorldBankClient};
|
||||
pub use finance_clients::{BlsClient, CoinGeckoClient, EcbClient, FinnhubClient, TwelveDataClient};
|
||||
pub use genomics_clients::{EnsemblClient, GwasClient, NcbiClient, UniProtClient};
|
||||
pub use geospatial_clients::{GeonamesClient, NominatimClient, OpenElevationClient, OverpassClient};
|
||||
pub use government_clients::{
|
||||
CensusClient, DataGovClient, EuOpenDataClient, UkGovClient, UNDataClient,
|
||||
WorldBankClient as WorldBankGovClient,
|
||||
};
|
||||
pub use medical_clients::{ClinicalTrialsClient, FdaClient, PubMedClient};
|
||||
pub use ml_clients::{
|
||||
HuggingFaceClient, HuggingFaceDataset, HuggingFaceModel, OllamaClient, OllamaModel,
|
||||
PapersWithCodeClient, PaperWithCodeDataset, PaperWithCodePaper, ReplicateClient,
|
||||
ReplicateModel, TogetherAiClient, TogetherModel,
|
||||
};
|
||||
pub use news_clients::{GuardianClient, HackerNewsClient, NewsDataClient, RedditClient};
|
||||
pub use patent_clients::{EpoClient, UsptoPatentClient};
|
||||
pub use physics_clients::{ArgoClient, CernOpenDataClient, GeoUtils, MaterialsProjectClient, UsgsEarthquakeClient};
|
||||
pub use semantic_scholar::SemanticScholarClient;
|
||||
pub use space_clients::{AstronomyClient, ExoplanetClient, NasaClient, SpaceXClient};
|
||||
pub use transportation_clients::{GtfsClient, MobilityDatabaseClient, OpenChargeMapClient, OpenRouteServiceClient};
|
||||
pub use wiki_clients::{WikidataClient, WikidataEntity, WikipediaClient};
|
||||
pub use coherence::{
|
||||
CoherenceBoundary, CoherenceConfig, CoherenceEngine, CoherenceEvent, CoherenceSignal,
|
||||
};
|
||||
pub use cut_aware_hnsw::{
|
||||
CutAwareHNSW, CutAwareConfig, CutAwareMetrics, CoherenceZone,
|
||||
SearchResult as CutAwareSearchResult, EdgeUpdate as CutAwareEdgeUpdate, UpdateKind, LayerCutStats,
|
||||
};
|
||||
pub use discovery::{
|
||||
DiscoveryConfig, DiscoveryEngine, DiscoveryPattern, PatternCategory, PatternStrength,
|
||||
};
|
||||
pub use dynamic_mincut::{
|
||||
CutGatedSearch, CutWatcherConfig, DynamicCutWatcher, DynamicMinCutError,
|
||||
EdgeUpdate as DynamicEdgeUpdate, EdgeUpdateType, EulerTourTree, HNSWGraph,
|
||||
LocalCut, LocalMinCutProcedure, WatcherStats,
|
||||
};
|
||||
pub use export::{
|
||||
export_all, export_coherence_csv, export_dot, export_graphml, export_patterns_csv,
|
||||
export_patterns_with_evidence_csv, ExportFilter,
|
||||
};
|
||||
pub use forecasting::{CoherenceForecaster, CrossDomainForecaster, Forecast, Trend};
|
||||
pub use ingester::{DataIngester, IngestionConfig, IngestionStats, SourceConfig};
|
||||
pub use realtime::{FeedItem, FeedSource, NewsAggregator, NewsSource, RealTimeEngine};
|
||||
pub use ruvector_native::{
|
||||
CoherenceHistoryEntry, CoherenceSnapshot, Domain, DiscoveredPattern,
|
||||
GraphExport, NativeDiscoveryEngine, NativeEngineConfig, SemanticVector,
|
||||
};
|
||||
pub use streaming::{StreamingConfig, StreamingEngine, StreamingEngineBuilder, StreamingMetrics};
|
||||
|
||||
/// Framework error types.
///
/// `Network` and `Serialization` convert automatically via `#[from]`, so `?`
/// works directly on `reqwest` and `serde_json` results.
#[derive(Error, Debug)]
pub enum FrameworkError {
    /// Data ingestion failed
    #[error("Ingestion error: {0}")]
    Ingestion(String),

    /// Coherence computation failed
    #[error("Coherence error: {0}")]
    Coherence(String),

    /// Discovery algorithm failed
    #[error("Discovery error: {0}")]
    Discovery(String),

    /// Network/API error (auto-converted from [`reqwest::Error`])
    #[error("Network error: {0}")]
    Network(#[from] reqwest::Error),

    /// Serialization error (auto-converted from [`serde_json::Error`])
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Graph operation failed
    #[error("Graph error: {0}")]
    Graph(String),

    /// Configuration error
    #[error("Config error: {0}")]
    Config(String),
}

/// Result type for framework operations (error is always [`FrameworkError`]).
pub type Result<T> = std::result::Result<T, FrameworkError>;
|
||||
|
||||
/// A timestamped data record from any source.
///
/// This is the common interchange type every `DataSource` produces; raw
/// payloads stay in `data` as JSON so heterogeneous sources can share one
/// schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataRecord {
    /// Unique identifier (also the deduplication key during ingestion)
    pub id: String,

    /// Source dataset (e.g., "openalex", "noaa", "edgar")
    pub source: String,

    /// Record type within source (e.g., "work", "author", "filing")
    pub record_type: String,

    /// Timestamp when data was observed/published
    pub timestamp: DateTime<Utc>,

    /// Raw data payload
    pub data: serde_json::Value,

    /// Pre-computed embedding vector (optional; may be filled in later by an
    /// [`EmbeddingProvider`])
    pub embedding: Option<Vec<f32>>,

    /// Relationships to other records
    pub relationships: Vec<Relationship>,
}

/// A directed relationship from this record to another record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Target record ID
    pub target_id: String,

    /// Relationship type (e.g., "cites", "authored_by", "filed_by")
    pub rel_type: String,

    /// Relationship weight/strength
    pub weight: f64,

    /// Additional properties
    pub properties: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
/// Trait for data sources that can be ingested.
#[async_trait]
pub trait DataSource: Send + Sync {
    /// Source identifier (e.g. the dataset name written into each record)
    fn source_id(&self) -> &str;

    /// Fetch a batch of records starting from `cursor`.
    ///
    /// Returns the records plus the cursor for the next batch. NOTE(review):
    /// presumably `None` cursor means "start from the beginning" and a `None`
    /// returned cursor means the source is exhausted — confirm against
    /// implementors.
    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)>;

    /// Get total record count (if known; `Ok(None)` when the source cannot say)
    async fn total_count(&self) -> Result<Option<u64>>;

    /// Check if source is available
    async fn health_check(&self) -> Result<bool>;
}

/// Trait for computing embeddings from records.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Compute embedding for a single record
    async fn embed_record(&self, record: &DataRecord) -> Result<Vec<f32>>;

    /// Compute embeddings for a batch of records (one vector per input record)
    async fn embed_batch(&self, records: &[DataRecord]) -> Result<Vec<Vec<f32>>>;

    /// Embedding dimension (length of every vector this provider returns)
    fn dimension(&self) -> usize;
}

/// Trait for graph building from records.
pub trait GraphBuilder: Send + Sync {
    /// Add a node from a data record, returning its graph node ID
    fn add_node(&mut self, record: &DataRecord) -> Result<u64>;

    /// Add an edge between nodes previously returned by `add_node`
    fn add_edge(&mut self, source: u64, target: u64, weight: f64, rel_type: &str) -> Result<()>;

    /// Get node count
    fn node_count(&self) -> usize;

    /// Get edge count
    fn edge_count(&self) -> usize;
}
|
||||
|
||||
/// Temporal window for time-series analysis.
///
/// The window is half-open: `[start, end)` — see [`TemporalWindow::contains`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TemporalWindow {
    /// Window start (inclusive)
    pub start: DateTime<Utc>,

    /// Window end (exclusive)
    pub end: DateTime<Utc>,

    /// Window identifier (for sliding windows)
    pub window_id: u64,
}

impl TemporalWindow {
    /// Create a new temporal window covering `[start, end)`.
    pub fn new(start: DateTime<Utc>, end: DateTime<Utc>, window_id: u64) -> Self {
        Self {
            start,
            end,
            window_id,
        }
    }

    /// Duration in seconds (negative if `end` precedes `start`).
    pub fn duration_secs(&self) -> i64 {
        (self.end - self.start).num_seconds()
    }

    /// Check if `timestamp` falls within the half-open window `[start, end)`:
    /// `start` itself is contained, `end` is not.
    pub fn contains(&self, timestamp: DateTime<Utc>) -> bool {
        timestamp >= self.start && timestamp < self.end
    }
}
|
||||
|
||||
/// Statistics for a discovery session, populated by [`DiscoveryPipeline::run`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DiscoveryStats {
    /// Records processed
    pub records_processed: u64,

    /// Nodes in graph
    pub nodes_created: u64,

    /// Edges in graph
    pub edges_created: u64,

    /// Coherence signals computed
    pub signals_computed: u64,

    /// Patterns discovered
    pub patterns_discovered: u64,

    /// Processing duration in milliseconds
    pub duration_ms: u64,

    /// Peak memory usage in bytes
    // NOTE(review): not set anywhere in this file — confirm a caller fills it in.
    pub peak_memory_bytes: u64,
}

/// Configuration for the entire discovery pipeline (ingestion, coherence,
/// discovery, plus orchestration knobs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Ingestion configuration
    pub ingestion: IngestionConfig,

    /// Coherence engine configuration
    pub coherence: CoherenceConfig,

    /// Discovery engine configuration
    pub discovery: DiscoveryConfig,

    /// Enable parallel processing
    pub parallel: bool,

    /// Checkpoint interval (records)
    pub checkpoint_interval: u64,

    /// Output directory for results
    pub output_dir: String,
}

impl Default for PipelineConfig {
    // Defaults: sub-configs at their own defaults, parallel on, checkpoint
    // every 10k records, output under ./discovery_output.
    fn default() -> Self {
        Self {
            ingestion: IngestionConfig::default(),
            coherence: CoherenceConfig::default(),
            discovery: DiscoveryConfig::default(),
            parallel: true,
            checkpoint_interval: 10_000,
            output_dir: "./discovery_output".to_string(),
        }
    }
}
|
||||
|
||||
/// Main discovery pipeline orchestrator.
///
/// Owns the three engines (ingestion, coherence, discovery) and runs them as
/// sequential phases over a [`DataSource`], accumulating [`DiscoveryStats`].
pub struct DiscoveryPipeline {
    config: PipelineConfig,
    ingester: DataIngester,
    coherence: CoherenceEngine,
    discovery: DiscoveryEngine,
    // Shared so `stats()` can be read while `run` is in flight elsewhere.
    stats: Arc<std::sync::RwLock<DiscoveryStats>>,
}

impl DiscoveryPipeline {
    /// Create a new discovery pipeline from `config`, constructing each engine
    /// from its corresponding sub-configuration.
    pub fn new(config: PipelineConfig) -> Self {
        let ingester = DataIngester::new(config.ingestion.clone());
        let coherence = CoherenceEngine::new(config.coherence.clone());
        let discovery = DiscoveryEngine::new(config.discovery.clone());

        Self {
            config,
            ingester,
            coherence,
            discovery,
            stats: Arc::new(std::sync::RwLock::new(DiscoveryStats::default())),
        }
    }

    /// Run the discovery pipeline on a data source.
    ///
    /// Three sequential phases: (1) ingest all records, (2) compute coherence
    /// signals, (3) detect patterns. Stats are updated after each phase; each
    /// lock guard is scoped to its own block so it is released immediately.
    pub async fn run<S: DataSource>(&mut self, source: S) -> Result<Vec<DiscoveryPattern>> {
        let start_time = std::time::Instant::now();

        // Phase 1: Ingest data
        tracing::info!("Starting ingestion from source: {}", source.source_id());
        let records = self.ingester.ingest_all(&source).await?;

        {
            let mut stats = self.stats.write().unwrap();
            stats.records_processed = records.len() as u64;
        }

        // Phase 2: Build graph and compute coherence
        tracing::info!("Computing coherence signals over {} records", records.len());
        let signals = self.coherence.compute_from_records(&records)?;

        {
            let mut stats = self.stats.write().unwrap();
            stats.signals_computed = signals.len() as u64;
            stats.nodes_created = self.coherence.node_count() as u64;
            stats.edges_created = self.coherence.edge_count() as u64;
        }

        // Phase 3: Detect patterns
        tracing::info!("Detecting discovery patterns");
        let patterns = self.discovery.detect(&signals)?;

        {
            let mut stats = self.stats.write().unwrap();
            stats.patterns_discovered = patterns.len() as u64;
            stats.duration_ms = start_time.elapsed().as_millis() as u64;
        }

        tracing::info!(
            "Discovery complete: {} patterns found in {}ms",
            patterns.len(),
            start_time.elapsed().as_millis()
        );

        Ok(patterns)
    }

    /// Get current statistics (point-in-time clone).
    pub fn stats(&self) -> DiscoveryStats {
        self.stats.read().unwrap().clone()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Window is half-open [start, end): start+30min is in, points outside are not.
    #[test]
    fn test_temporal_window() {
        let start = Utc::now();
        let end = start + chrono::Duration::hours(1);
        let window = TemporalWindow::new(start, end, 1);

        assert_eq!(window.duration_secs(), 3600);
        assert!(window.contains(start + chrono::Duration::minutes(30)));
        assert!(!window.contains(start - chrono::Duration::minutes(1)));
        assert!(!window.contains(end + chrono::Duration::minutes(1)));
    }

    // Defaults documented on PipelineConfig::default.
    #[test]
    fn test_default_pipeline_config() {
        let config = PipelineConfig::default();
        assert!(config.parallel);
        assert_eq!(config.checkpoint_interval, 10_000);
    }

    // DataRecord must survive a JSON round-trip with its id intact.
    #[test]
    fn test_data_record_serialization() {
        let record = DataRecord {
            id: "test-1".to_string(),
            source: "test".to_string(),
            record_type: "document".to_string(),
            timestamp: Utc::now(),
            data: serde_json::json!({"title": "Test"}),
            embedding: Some(vec![0.1, 0.2, 0.3]),
            relationships: vec![],
        };

        let json = serde_json::to_string(&record).unwrap();
        let parsed: DataRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, record.id);
    }
}
|
||||
1456
vendor/ruvector/examples/data/framework/src/mcp_server.rs
vendored
Normal file
1456
vendor/ruvector/examples/data/framework/src/mcp_server.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
915
vendor/ruvector/examples/data/framework/src/medical_clients.rs
vendored
Normal file
915
vendor/ruvector/examples/data/framework/src/medical_clients.rs
vendored
Normal file
@@ -0,0 +1,915 @@
|
||||
//! Medical data API integrations for PubMed, ClinicalTrials.gov, and FDA
|
||||
//!
|
||||
//! This module provides async clients for fetching medical literature, clinical trials,
|
||||
//! and FDA data, converting responses to SemanticVector format for RuVector discovery.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Custom deserializer that handles both string and integer values.
///
/// Intended for `#[serde(deserialize_with = ...)]` on `Option<i32>` fields
/// whose upstream APIs sometimes emit `5` and sometimes `"5"`. Null/unit maps
/// to `None`; non-numeric strings are an error.
///
/// NOTE(review): `v as i32` on the i64/u64 paths silently truncates
/// out-of-range values — confirm the fields this is used on cannot exceed i32.
/// NOTE(review): floating-point inputs (visit_f64) are not handled and will
/// fail with the "a number or numeric string" message — verify no source emits
/// e.g. `5.0` for these fields.
fn deserialize_number_from_string<'de, D>(deserializer: D) -> std::result::Result<Option<i32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{self, Visitor};

    // Visitor that accepts several self-describing shapes and funnels them
    // all into Option<i32>.
    struct NumberOrStringVisitor;

    impl<'de> Visitor<'de> for NumberOrStringVisitor {
        type Value = Option<i32>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a number or numeric string")
        }

        // Signed integer input, e.g. JSON -3.
        fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Some(v as i32))
        }

        // Unsigned integer input, e.g. JSON 42.
        fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Some(v as i32))
        }

        // Numeric string input, e.g. "42"; parse failure becomes a serde error.
        fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            v.parse::<i32>().map(Some).map_err(de::Error::custom)
        }

        // Absent value -> None.
        fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        // Null (unit) value -> None.
        fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }
    }

    // deserialize_any lets the input format pick the visit_* method.
    deserializer.deserialize_any(NumberOrStringVisitor)
}
|
||||
|
||||
/// Rate limiting configuration (per-request delays between API calls)
const NCBI_RATE_LIMIT_MS: u64 = 334; // ~3 requests/second without API key
const NCBI_WITH_KEY_RATE_LIMIT_MS: u64 = 100; // 10 requests/second with key
const FDA_RATE_LIMIT_MS: u64 = 250; // Conservative 4 requests/second
const CLINICALTRIALS_RATE_LIMIT_MS: u64 = 100; // 10 requests/second
const MAX_RETRIES: u32 = 3; // attempts beyond the first request
const RETRY_DELAY_MS: u64 = 1000; // base delay; multiplied by retry number
|
||||
|
||||
// ============================================================================
|
||||
// PubMed E-utilities Client
|
||||
// ============================================================================
|
||||
|
||||
/// PubMed ESearch API response (JSON envelope around the PMID list)
#[derive(Debug, Deserialize)]
struct PubMedSearchResponse {
    esearchresult: ESearchResult,
}

// Inner ESearch result: matching PMIDs plus total-hit count (as a string,
// which is how the API returns it).
#[derive(Debug, Deserialize)]
struct ESearchResult {
    #[serde(default)]
    idlist: Vec<String>,
    #[serde(default)]
    count: String,
}

/// PubMed EFetch API response (simplified subset of the full XML schema)
#[derive(Debug, Deserialize)]
struct PubMedFetchResponse {
    #[serde(rename = "PubmedArticleSet")]
    pubmed_article_set: Option<PubmedArticleSet>,
}

// Container for all articles returned by a single EFetch call.
#[derive(Debug, Deserialize)]
struct PubmedArticleSet {
    #[serde(rename = "PubmedArticle", default)]
    articles: Vec<PubmedArticle>,
}

// One article wrapper; all the useful data lives in MedlineCitation.
#[derive(Debug, Deserialize)]
struct PubmedArticle {
    #[serde(rename = "MedlineCitation")]
    medline_citation: MedlineCitation,
}

// Citation record: PMID, article body, and (optionally) the completion date.
#[derive(Debug, Deserialize)]
struct MedlineCitation {
    #[serde(rename = "PMID")]
    pmid: PmidObject,
    #[serde(rename = "Article")]
    article: Article,
    #[serde(rename = "DateCompleted", default)]
    date_completed: Option<DateCompleted>,
}

// PMID element; "$value" is quick-xml's name for the element's text content.
#[derive(Debug, Deserialize)]
struct PmidObject {
    #[serde(rename = "$value", default)]
    value: String,
}

// Article body: title, abstract, and author list are all optional in the feed.
#[derive(Debug, Deserialize)]
struct Article {
    #[serde(rename = "ArticleTitle", default)]
    article_title: Option<String>,
    #[serde(rename = "Abstract", default)]
    abstract_data: Option<AbstractData>,
    #[serde(rename = "AuthorList", default)]
    author_list: Option<AuthorList>,
}

// Abstract container; structured abstracts arrive as multiple AbstractText
// sections which the parser later joins with spaces.
#[derive(Debug, Deserialize)]
struct AbstractData {
    #[serde(rename = "AbstractText", default)]
    abstract_text: Vec<AbstractText>,
}

// One abstract section; text content via quick-xml's "$value".
#[derive(Debug, Deserialize)]
struct AbstractText {
    #[serde(rename = "$value", default)]
    value: Option<String>,
}

// Author list wrapper.
#[derive(Debug, Deserialize)]
struct AuthorList {
    #[serde(rename = "Author", default)]
    authors: Vec<Author>,
}

// Single author; either name part may be missing.
#[derive(Debug, Deserialize)]
struct Author {
    #[serde(rename = "LastName", default)]
    last_name: Option<String>,
    #[serde(rename = "ForeName", default)]
    fore_name: Option<String>,
}

// Completion date; components are strings in the XML and parsed later.
#[derive(Debug, Deserialize)]
struct DateCompleted {
    #[serde(rename = "Year", default)]
    year: Option<String>,
    #[serde(rename = "Month", default)]
    month: Option<String>,
    #[serde(rename = "Day", default)]
    day: Option<String>,
}
|
||||
|
||||
/// Client for PubMed medical literature database (NCBI E-utilities).
pub struct PubMedClient {
    // Shared HTTP client with a 30s request timeout.
    client: Client,
    // E-utilities base URL (esearch/efetch endpoints are appended).
    base_url: String,
    // Optional NCBI API key; raises the allowed request rate when present.
    api_key: Option<String>,
    // Delay slept before every request, chosen from the api_key presence.
    rate_limit_delay: Duration,
    // Embedder used to turn title+abstract text into vectors.
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl PubMedClient {
|
||||
    /// Create a new PubMed client
    ///
    /// # Arguments
    /// * `api_key` - Optional NCBI API key (get from https://www.ncbi.nlm.nih.gov/account/)
    ///
    /// # Errors
    /// Returns [`FrameworkError::Network`] if the underlying HTTP client
    /// cannot be built.
    pub fn new(api_key: Option<String>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        // With a key NCBI allows 10 req/s; without, ~3 req/s.
        let rate_limit_delay = if api_key.is_some() {
            Duration::from_millis(NCBI_WITH_KEY_RATE_LIMIT_MS)
        } else {
            Duration::from_millis(NCBI_RATE_LIMIT_MS)
        };

        Ok(Self {
            client,
            base_url: "https://eutils.ncbi.nlm.nih.gov/entrez/eutils".to_string(),
            api_key,
            rate_limit_delay,
            embedder: Arc::new(SimpleEmbedder::new(384)), // Higher dimension for medical text
        })
    }
|
||||
|
||||
    /// Search PubMed articles by query
    ///
    /// Two-step E-utilities flow: ESearch resolves the query to PMIDs, then
    /// EFetch retrieves the abstracts, which are embedded into
    /// [`SemanticVector`]s.
    ///
    /// # Arguments
    /// * `query` - Search query (e.g., "COVID-19 vaccine", "alzheimer's treatment")
    /// * `max_results` - Maximum number of results to return
    pub async fn search_articles(
        &self,
        query: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        // Step 1: Search for PMIDs
        let pmids = self.search_pmids(query, max_results).await?;

        if pmids.is_empty() {
            return Ok(Vec::new());
        }

        // Step 2: Fetch full abstracts for PMIDs
        self.fetch_abstracts(&pmids).await
    }
|
||||
|
||||
    /// Search for PMIDs matching query (ESearch endpoint, JSON mode).
    ///
    /// The query is URL-encoded; the optional API key is appended as a query
    /// parameter. Sleeps for the rate-limit delay before issuing the request.
    async fn search_pmids(&self, query: &str, max_results: usize) -> Result<Vec<String>> {
        let mut url = format!(
            "{}/esearch.fcgi?db=pubmed&term={}&retmode=json&retmax={}",
            self.base_url,
            urlencoding::encode(query),
            max_results
        );

        if let Some(key) = &self.api_key {
            url.push_str(&format!("&api_key={}", key));
        }

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let search_response: PubMedSearchResponse = response.json().await?;

        Ok(search_response.esearchresult.idlist)
    }
|
||||
|
||||
    /// Fetch full article abstracts by PMIDs
    ///
    /// Requests are batched in chunks of 200 IDs (EFetch limit used here),
    /// rate-limited per chunk, and the XML responses are parsed into
    /// [`SemanticVector`]s.
    ///
    /// # Arguments
    /// * `pmids` - List of PubMed IDs to fetch
    pub async fn fetch_abstracts(&self, pmids: &[String]) -> Result<Vec<SemanticVector>> {
        if pmids.is_empty() {
            return Ok(Vec::new());
        }

        // Batch PMIDs (max 200 per request)
        let mut all_vectors = Vec::new();

        for chunk in pmids.chunks(200) {
            let pmid_list = chunk.join(",");
            let mut url = format!(
                "{}/efetch.fcgi?db=pubmed&id={}&retmode=xml",
                self.base_url, pmid_list
            );

            if let Some(key) = &self.api_key {
                url.push_str(&format!("&api_key={}", key));
            }

            sleep(self.rate_limit_delay).await;
            let response = self.fetch_with_retry(&url).await?;
            let xml_text = response.text().await?;

            // Parse XML response
            let vectors = self.parse_xml_to_vectors(&xml_text)?;
            all_vectors.extend(vectors);
        }

        Ok(all_vectors)
    }
|
||||
|
||||
    /// Parse PubMed XML response to SemanticVectors.
    ///
    /// Each article becomes one vector: the embedding is computed from
    /// "title + abstract", the timestamp comes from `DateCompleted` (falling
    /// back to now when absent/unparseable), and pmid/title/abstract/authors
    /// are stored as metadata. XML that fails to deserialize yields
    /// [`FrameworkError::Config`].
    fn parse_xml_to_vectors(&self, xml: &str) -> Result<Vec<SemanticVector>> {
        // Use quick-xml for parsing
        let fetch_response: PubMedFetchResponse = quick_xml::de::from_str(xml)
            .map_err(|e| FrameworkError::Config(format!("XML parse error: {}", e)))?;

        let mut vectors = Vec::new();

        if let Some(article_set) = fetch_response.pubmed_article_set {
            for pubmed_article in article_set.articles {
                let citation = pubmed_article.medline_citation;
                let article = citation.article;

                let pmid = citation.pmid.value;
                let title = article.article_title.unwrap_or_else(|| "Untitled".to_string());

                // Extract abstract text: structured abstracts arrive as
                // multiple sections, joined here with spaces.
                let abstract_text = article
                    .abstract_data
                    .as_ref()
                    .map(|abs| {
                        abs.abstract_text
                            .iter()
                            .filter_map(|at| at.value.clone())
                            .collect::<Vec<_>>()
                            .join(" ")
                    })
                    .unwrap_or_default();

                // Create combined text for embedding
                let text = format!("{} {}", title, abstract_text);
                let embedding = self.embedder.embed_text(&text);

                // Parse publication date; any missing/invalid component falls
                // through to Utc::now() as the timestamp.
                let timestamp = citation
                    .date_completed
                    .as_ref()
                    .and_then(|date| {
                        let year = date.year.as_ref()?.parse::<i32>().ok()?;
                        let month = date.month.as_ref()?.parse::<u32>().ok()?;
                        let day = date.day.as_ref()?.parse::<u32>().ok()?;
                        NaiveDate::from_ymd_opt(year, month, day)
                    })
                    .and_then(|d| d.and_hms_opt(0, 0, 0))
                    .map(|dt| dt.and_utc())
                    .unwrap_or_else(Utc::now);

                // Extract author names as "First Last", comma-separated;
                // authors without a last name are skipped.
                let authors = article
                    .author_list
                    .as_ref()
                    .map(|al| {
                        al.authors
                            .iter()
                            .filter_map(|a| {
                                let last = a.last_name.as_deref().unwrap_or("");
                                let first = a.fore_name.as_deref().unwrap_or("");
                                if !last.is_empty() {
                                    Some(format!("{} {}", first, last))
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join(", ")
                    })
                    .unwrap_or_default();

                let mut metadata = HashMap::new();
                metadata.insert("pmid".to_string(), pmid.clone());
                metadata.insert("title".to_string(), title);
                metadata.insert("abstract".to_string(), abstract_text);
                metadata.insert("authors".to_string(), authors);
                metadata.insert("source".to_string(), "pubmed".to_string());

                vectors.push(SemanticVector {
                    id: format!("PMID:{}", pmid),
                    embedding,
                    domain: Domain::Medical,
                    timestamp,
                    metadata,
                });
            }
        }

        Ok(vectors)
    }
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ClinicalTrials.gov Client
|
||||
// ============================================================================
|
||||
|
||||
/// ClinicalTrials.gov API response
///
/// Top-level payload returned by the v2 `/studies` endpoint.
#[derive(Debug, Deserialize)]
struct ClinicalTrialsResponse {
    // List of matching studies; defaults to empty when the field is absent.
    #[serde(default)]
    studies: Vec<ClinicalStudy>,
}
|
||||
|
||||
/// A single clinical study record; only the protocol section is consumed here.
#[derive(Debug, Deserialize)]
struct ClinicalStudy {
    // Holds identification, status, description and conditions sub-modules.
    #[serde(rename = "protocolSection")]
    protocol_section: ProtocolSection,
}
|
||||
|
||||
/// Protocol section of a clinical study, grouping the sub-modules this
/// module reads when building a `SemanticVector`.
#[derive(Debug, Deserialize)]
struct ProtocolSection {
    // NCT id and brief title (always present in the payload).
    #[serde(rename = "identificationModule")]
    identification: IdentificationModule,
    // Overall status and start date.
    #[serde(rename = "statusModule")]
    status: StatusModule,
    // Optional brief summary text.
    #[serde(rename = "descriptionModule", default)]
    description: Option<DescriptionModule>,
    // Optional list of medical conditions studied.
    #[serde(rename = "conditionsModule", default)]
    conditions: Option<ConditionsModule>,
}
|
||||
|
||||
/// Identification block of a study: registry number and display title.
#[derive(Debug, Deserialize)]
struct IdentificationModule {
    // ClinicalTrials.gov registry number (e.g., "NCT01234567").
    #[serde(rename = "nctId")]
    nct_id: String,
    // Short human-readable title; may be missing.
    #[serde(rename = "briefTitle", default)]
    brief_title: Option<String>,
}
|
||||
|
||||
/// Status block of a study: recruitment status and start date.
#[derive(Debug, Deserialize)]
struct StatusModule {
    // Recruitment status string (e.g., "RECRUITING"); may be missing.
    #[serde(rename = "overallStatus", default)]
    overall_status: Option<String>,
    // Structured start date; used as the vector timestamp when parseable.
    #[serde(rename = "startDateStruct", default)]
    start_date: Option<DateStruct>,
}
|
||||
|
||||
/// Wrapper around a date string; parsed downstream with the "%Y-%m-%d" format.
#[derive(Debug, Deserialize)]
struct DateStruct {
    #[serde(default)]
    date: Option<String>,
}
|
||||
|
||||
/// Description block of a study; only the brief summary is used.
#[derive(Debug, Deserialize)]
struct DescriptionModule {
    #[serde(rename = "briefSummary", default)]
    brief_summary: Option<String>,
}
|
||||
|
||||
/// Conditions block of a study: list of medical conditions under study.
#[derive(Debug, Deserialize)]
struct ConditionsModule {
    #[serde(default)]
    conditions: Vec<String>,
}
|
||||
|
||||
/// Client for ClinicalTrials.gov database
///
/// Wraps a `reqwest::Client` with a fixed per-request delay and a shared
/// text embedder used to turn study text into embeddings.
pub struct ClinicalTrialsClient {
    // Underlying HTTP client (30s timeout, see `new`).
    client: Client,
    // API root, e.g. "https://clinicaltrials.gov/api/v2".
    base_url: String,
    // Delay slept before every request for rate limiting.
    rate_limit_delay: Duration,
    // Shared embedder producing 384-dimensional embeddings.
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl ClinicalTrialsClient {
    /// Create a new ClinicalTrials.gov client
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://clinicaltrials.gov/api/v2".to_string(),
            rate_limit_delay: Duration::from_millis(CLINICALTRIALS_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(384)),
        })
    }

    /// Search clinical trials by condition
    ///
    /// Fetches up to 100 studies per call (single page, `pageSize=100`).
    ///
    /// # Arguments
    /// * `condition` - Medical condition to search (e.g., "diabetes", "cancer")
    /// * `status` - Optional recruitment status filter (e.g., "RECRUITING", "COMPLETED")
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` on request failure or if the
    /// response body cannot be decoded as the expected JSON shape.
    pub async fn search_trials(
        &self,
        condition: &str,
        status: Option<&str>,
    ) -> Result<Vec<SemanticVector>> {
        let mut url = format!(
            "{}/studies?query.cond={}&pageSize=100",
            self.base_url,
            urlencoding::encode(condition)
        );

        // NOTE(review): `status` is appended without URL-encoding; callers
        // are expected to pass plain status tokens like "RECRUITING".
        if let Some(s) = status {
            url.push_str(&format!("&filter.overallStatus={}", s));
        }

        // Respect the per-request rate limit before hitting the API.
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let trials_response: ClinicalTrialsResponse = response.json().await?;

        let mut vectors = Vec::new();
        for study in trials_response.studies {
            let vector = self.study_to_vector(study)?;
            vectors.push(vector);
        }

        Ok(vectors)
    }

    /// Convert clinical study to SemanticVector
    ///
    /// The embedding text is "title summary conditions"; the timestamp is
    /// the study start date (midnight UTC) or `Utc::now()` when absent or
    /// unparseable.
    fn study_to_vector(&self, study: ClinicalStudy) -> Result<SemanticVector> {
        let protocol = study.protocol_section;
        let nct_id = protocol.identification.nct_id;
        let title = protocol
            .identification
            .brief_title
            .unwrap_or_else(|| "Untitled Study".to_string());

        let summary = protocol
            .description
            .as_ref()
            .and_then(|d| d.brief_summary.clone())
            .unwrap_or_default();

        let conditions = protocol
            .conditions
            .as_ref()
            .map(|c| c.conditions.join(", "))
            .unwrap_or_default();

        let status = protocol
            .status
            .overall_status
            .unwrap_or_else(|| "UNKNOWN".to_string());

        // Create text for embedding
        let text = format!("{} {} {}", title, summary, conditions);
        let embedding = self.embedder.embed_text(&text);

        // Parse start date (ISO "YYYY-MM-DD"); fall back to "now".
        let timestamp = protocol
            .status
            .start_date
            .as_ref()
            .and_then(|sd| sd.date.as_ref())
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut metadata = HashMap::new();
        metadata.insert("nct_id".to_string(), nct_id.clone());
        metadata.insert("title".to_string(), title);
        metadata.insert("summary".to_string(), summary);
        metadata.insert("conditions".to_string(), conditions);
        metadata.insert("status".to_string(), status);
        metadata.insert("source".to_string(), "clinicaltrials".to_string());

        Ok(SemanticVector {
            id: format!("NCT:{}", nct_id),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }

    /// Fetch with retry logic
    ///
    /// Retries up to `MAX_RETRIES` times with linear backoff on HTTP 429
    /// and transport errors; other responses are returned unchanged.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
|
||||
|
||||
impl Default for ClinicalTrialsClient {
    /// Convenience constructor; panics if the HTTP client cannot be built.
    /// Use `ClinicalTrialsClient::new()` to handle the error instead.
    fn default() -> Self {
        Self::new().expect("Failed to create ClinicalTrials client")
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// FDA OpenFDA Client
|
||||
// ============================================================================
|
||||
|
||||
/// OpenFDA drug adverse event response
///
/// Note: `results` has no `#[serde(default)]`, so a payload without it
/// fails deserialization (404 "no results" is handled before decoding).
#[derive(Debug, Deserialize)]
struct FdaDrugEventResponse {
    results: Vec<FdaDrugEvent>,
}
|
||||
|
||||
/// A single adverse-event safety report from the OpenFDA drug/event endpoint.
#[derive(Debug, Deserialize)]
struct FdaDrugEvent {
    // Unique safety report identifier.
    #[serde(rename = "safetyreportid")]
    safety_report_id: String,
    // Receipt date; parsed downstream with the "%Y%m%d" format.
    #[serde(rename = "receivedate", default)]
    receive_date: Option<String>,
    // Drugs taken and reactions observed for this report.
    #[serde(default)]
    patient: Option<FdaPatient>,
    // Seriousness flag; `Some(1)` is treated as "serious" downstream.
    // Custom deserializer (defined elsewhere in the crate) presumably
    // tolerates string-encoded numbers — TODO confirm.
    #[serde(default, deserialize_with = "deserialize_number_from_string")]
    serious: Option<i32>,
}
|
||||
|
||||
/// Patient section of a safety report: drugs and reactions.
#[derive(Debug, Deserialize)]
struct FdaPatient {
    #[serde(default)]
    drug: Vec<FdaDrug>,
    #[serde(default)]
    reaction: Vec<FdaReaction>,
}
|
||||
|
||||
/// A drug entry within a safety report; only the product name is used.
#[derive(Debug, Deserialize)]
struct FdaDrug {
    #[serde(rename = "medicinalproduct", default)]
    medicinal_product: Option<String>,
}
|
||||
|
||||
/// A reaction entry within a safety report (MedDRA preferred term).
#[derive(Debug, Deserialize)]
struct FdaReaction {
    #[serde(rename = "reactionmeddrapt", default)]
    reaction_meddra_pt: Option<String>,
}
|
||||
|
||||
/// OpenFDA device recall response
///
/// Like `FdaDrugEventResponse`, `results` is required; the 404 "no
/// results" case is handled before deserialization.
#[derive(Debug, Deserialize)]
struct FdaRecallResponse {
    results: Vec<FdaRecall>,
}
|
||||
|
||||
/// A single device recall record from the OpenFDA device/recall endpoint.
#[derive(Debug, Deserialize)]
struct FdaRecall {
    // Unique recall identifier (e.g., "Z-1234-2020").
    #[serde(rename = "recall_number")]
    recall_number: String,
    // Free-text reason for the recall.
    #[serde(default)]
    reason_for_recall: Option<String>,
    // Description of the recalled product.
    #[serde(default)]
    product_description: Option<String>,
    // Report date; parsed downstream with the "%Y%m%d" format.
    #[serde(default)]
    report_date: Option<String>,
    // FDA recall classification (e.g., "Class I").
    #[serde(default)]
    classification: Option<String>,
}
|
||||
|
||||
/// Client for FDA OpenFDA API
///
/// Mirrors the structure of the other clients in this module: an HTTP
/// client, an API root, a fixed rate-limit delay, and a shared embedder.
pub struct FdaClient {
    // Underlying HTTP client (30s timeout, see `new`).
    client: Client,
    // API root, e.g. "https://api.fda.gov".
    base_url: String,
    // Delay slept before every request for rate limiting.
    rate_limit_delay: Duration,
    // Shared embedder producing 384-dimensional embeddings.
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl FdaClient {
    /// Create a new FDA OpenFDA client
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://api.fda.gov".to_string(),
            rate_limit_delay: Duration::from_millis(FDA_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(384)),
        })
    }

    /// Search drug adverse events
    ///
    /// Fetches up to 100 adverse-event reports for the given drug; an
    /// empty list is returned when the API answers 404 (no matches).
    ///
    /// # Arguments
    /// * `drug_name` - Name of drug to search (e.g., "aspirin", "ibuprofen")
    pub async fn search_drug_events(&self, drug_name: &str) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/drug/event.json?search=patient.drug.medicinalproduct:\"{}\"&limit=100",
            self.base_url,
            urlencoding::encode(drug_name)
        );

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;

        // FDA API may return 404 if no results - handle gracefully
        if response.status() == StatusCode::NOT_FOUND {
            return Ok(Vec::new());
        }

        let events_response: FdaDrugEventResponse = response.json().await?;

        let mut vectors = Vec::new();
        for event in events_response.results {
            let vector = self.drug_event_to_vector(event)?;
            vectors.push(vector);
        }

        Ok(vectors)
    }

    /// Search device recalls
    ///
    /// Fetches up to 100 recall records whose reason matches; an empty
    /// list is returned when the API answers 404 (no matches).
    ///
    /// # Arguments
    /// * `reason` - Reason for recall to search
    pub async fn search_recalls(&self, reason: &str) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/device/recall.json?search=reason_for_recall:\"{}\"&limit=100",
            self.base_url,
            urlencoding::encode(reason)
        );

        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;

        if response.status() == StatusCode::NOT_FOUND {
            return Ok(Vec::new());
        }

        let recalls_response: FdaRecallResponse = response.json().await?;

        let mut vectors = Vec::new();
        for recall in recalls_response.results {
            let vector = self.recall_to_vector(recall)?;
            vectors.push(vector);
        }

        Ok(vectors)
    }

    /// Convert drug event to SemanticVector
    ///
    /// Embedding text is "Drug: … Reactions: … Severity: …"; the timestamp
    /// is the report receive date (midnight UTC) or `Utc::now()`.
    fn drug_event_to_vector(&self, event: FdaDrugEvent) -> Result<SemanticVector> {
        let mut drug_names = Vec::new();
        let mut reactions = Vec::new();

        // Flatten the nested patient section into flat name lists.
        if let Some(patient) = &event.patient {
            for drug in &patient.drug {
                if let Some(name) = &drug.medicinal_product {
                    drug_names.push(name.clone());
                }
            }
            for reaction in &patient.reaction {
                if let Some(r) = &reaction.reaction_meddra_pt {
                    reactions.push(r.clone());
                }
            }
        }

        let drugs_text = drug_names.join(", ");
        let reactions_text = reactions.join(", ");
        // OpenFDA encodes seriousness as 1; anything else is non-serious.
        let serious = if event.serious == Some(1) {
            "serious"
        } else {
            "non-serious"
        };

        // Create text for embedding
        let text = format!("Drug: {} Reactions: {} Severity: {}", drugs_text, reactions_text, serious);
        let embedding = self.embedder.embed_text(&text);

        // Parse receive date (compact "YYYYMMDD" format).
        let timestamp = event
            .receive_date
            .as_ref()
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y%m%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut metadata = HashMap::new();
        metadata.insert("report_id".to_string(), event.safety_report_id.clone());
        metadata.insert("drugs".to_string(), drugs_text);
        metadata.insert("reactions".to_string(), reactions_text);
        metadata.insert("serious".to_string(), serious.to_string());
        metadata.insert("source".to_string(), "fda_drug_events".to_string());

        Ok(SemanticVector {
            id: format!("FDA_EVENT:{}", event.safety_report_id),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }

    /// Convert recall to SemanticVector
    ///
    /// Embedding text is "Product: … Reason: … Classification: …"; the
    /// timestamp is the report date (midnight UTC) or `Utc::now()`.
    fn recall_to_vector(&self, recall: FdaRecall) -> Result<SemanticVector> {
        let reason = recall.reason_for_recall.unwrap_or_else(|| "Unknown reason".to_string());
        let product = recall.product_description.unwrap_or_else(|| "Unknown product".to_string());
        let classification = recall.classification.unwrap_or_else(|| "Unknown".to_string());

        // Create text for embedding
        let text = format!("Product: {} Reason: {} Classification: {}", product, reason, classification);
        let embedding = self.embedder.embed_text(&text);

        // Parse report date (compact "YYYYMMDD" format).
        let timestamp = recall
            .report_date
            .as_ref()
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y%m%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut metadata = HashMap::new();
        metadata.insert("recall_number".to_string(), recall.recall_number.clone());
        metadata.insert("reason".to_string(), reason);
        metadata.insert("product".to_string(), product);
        metadata.insert("classification".to_string(), classification);
        metadata.insert("source".to_string(), "fda_recalls".to_string());

        Ok(SemanticVector {
            id: format!("FDA_RECALL:{}", recall.recall_number),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }

    /// Fetch with retry logic
    ///
    /// Retries up to `MAX_RETRIES` times with linear backoff on HTTP 429
    /// and transport errors; other responses are returned unchanged.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
|
||||
|
||||
impl Default for FdaClient {
    /// Convenience constructor; panics if the HTTP client cannot be built.
    /// Use `FdaClient::new()` to handle the error instead.
    fn default() -> Self {
        Self::new().expect("Failed to create FDA client")
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: PubMed client construction without an API key succeeds.
    #[tokio::test]
    async fn test_pubmed_client_creation() {
        let client = PubMedClient::new(None);
        assert!(client.is_ok());
    }

    // Smoke test: ClinicalTrials.gov client construction succeeds.
    #[tokio::test]
    async fn test_clinical_trials_client_creation() {
        let client = ClinicalTrialsClient::new();
        assert!(client.is_ok());
    }

    // Smoke test: FDA client construction succeeds.
    #[tokio::test]
    async fn test_fda_client_creation() {
        let client = FdaClient::new();
        assert!(client.is_ok());
    }

    // An API key unlocks the faster NCBI rate limit; without one the
    // conservative default applies.
    #[test]
    fn test_rate_limiting() {
        // Verify rate limits are set correctly
        let pubmed_without_key = PubMedClient::new(None).unwrap();
        assert_eq!(
            pubmed_without_key.rate_limit_delay,
            Duration::from_millis(NCBI_RATE_LIMIT_MS)
        );

        let pubmed_with_key = PubMedClient::new(Some("test_key".to_string())).unwrap();
        assert_eq!(
            pubmed_with_key.rate_limit_delay,
            Duration::from_millis(NCBI_WITH_KEY_RATE_LIMIT_MS)
        );
    }
}
|
||||
2035
vendor/ruvector/examples/data/framework/src/ml_clients.rs
vendored
Normal file
2035
vendor/ruvector/examples/data/framework/src/ml_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1606
vendor/ruvector/examples/data/framework/src/news_clients.rs
vendored
Normal file
1606
vendor/ruvector/examples/data/framework/src/news_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1093
vendor/ruvector/examples/data/framework/src/optimized.rs
vendored
Normal file
1093
vendor/ruvector/examples/data/framework/src/optimized.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
667
vendor/ruvector/examples/data/framework/src/patent_clients.rs
vendored
Normal file
667
vendor/ruvector/examples/data/framework/src/patent_clients.rs
vendored
Normal file
@@ -0,0 +1,667 @@
|
||||
//! Patent database API integrations for USPTO PatentsView and EPO
|
||||
//!
|
||||
//! This module provides async clients for fetching patent data from:
|
||||
//! - USPTO PatentsView API (Free, no authentication required)
|
||||
//! - EPO Open Patent Services (Free tier available)
|
||||
//!
|
||||
//! Converts patent data to SemanticVector format for RuVector discovery.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Rate limiting configuration
const USPTO_RATE_LIMIT_MS: u64 = 200; // ~5 requests/second
const EPO_RATE_LIMIT_MS: u64 = 1000; // Conservative 1 request/second
// Maximum retry attempts for HTTP 429 or transport failures.
const MAX_RETRIES: u32 = 3;
// Base backoff delay; multiplied by the attempt number (linear backoff).
const RETRY_DELAY_MS: u64 = 1000;
||||
|
||||
// ============================================================================
|
||||
// USPTO PatentsView API Client
|
||||
// ============================================================================
|
||||
|
||||
/// USPTO PatentsView API response
///
/// Only `patents` is consumed by this module; the count fields are kept
/// for completeness of the payload shape.
#[derive(Debug, Deserialize)]
struct UsptoPatentsResponse {
    // Matching patent records; empty when the field is absent.
    #[serde(default)]
    patents: Vec<UsptoPatent>,
    // Number of records in this page.
    #[serde(default)]
    count: i32,
    // Total matches across all pages, when reported.
    #[serde(default)]
    total_patent_count: Option<i32>,
}
|
||||
|
||||
/// USPTO Patent record
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UsptoPatent {
|
||||
/// Patent number
|
||||
patent_number: String,
|
||||
/// Patent title
|
||||
#[serde(default)]
|
||||
patent_title: Option<String>,
|
||||
/// Patent abstract
|
||||
#[serde(default)]
|
||||
patent_abstract: Option<String>,
|
||||
/// Grant date (YYYY-MM-DD)
|
||||
#[serde(default)]
|
||||
patent_date: Option<String>,
|
||||
/// Application filing date
|
||||
#[serde(default)]
|
||||
app_date: Option<String>,
|
||||
/// Assignees (organizations/companies)
|
||||
#[serde(default)]
|
||||
assignees: Vec<UsptoAssignee>,
|
||||
/// Inventors
|
||||
#[serde(default)]
|
||||
inventors: Vec<UsptoInventor>,
|
||||
/// CPC classifications
|
||||
#[serde(default)]
|
||||
cpcs: Vec<UsptoCpc>,
|
||||
/// Citation counts
|
||||
#[serde(default)]
|
||||
cited_patent_count: Option<i32>,
|
||||
#[serde(default)]
|
||||
citedby_patent_count: Option<i32>,
|
||||
}
|
||||
|
||||
/// An assignee record: either an organization or a named individual.
#[derive(Debug, Deserialize)]
struct UsptoAssignee {
    // Organization name, when the assignee is a company/institution.
    #[serde(default)]
    assignee_organization: Option<String>,
    // Individual's first name, when the assignee is a person.
    #[serde(default)]
    assignee_individual_name_first: Option<String>,
    // Individual's last name, when the assignee is a person.
    #[serde(default)]
    assignee_individual_name_last: Option<String>,
}
|
||||
|
||||
/// An inventor record (first/last name, either may be absent).
#[derive(Debug, Deserialize)]
struct UsptoInventor {
    #[serde(default)]
    inventor_name_first: Option<String>,
    #[serde(default)]
    inventor_name_last: Option<String>,
}
|
||||
|
||||
/// A CPC classification entry, from most to least specific level.
#[derive(Debug, Deserialize)]
struct UsptoCpc {
    /// CPC section (e.g., "Y02")
    #[serde(default)]
    cpc_section_id: Option<String>,
    /// CPC subclass (e.g., "Y02E")
    #[serde(default)]
    cpc_subclass_id: Option<String>,
    /// CPC group (e.g., "Y02E10/50")
    #[serde(default)]
    cpc_group_id: Option<String>,
}
|
||||
|
||||
/// USPTO citation response
///
/// Currently unused by the stubbed citation methods; kept for when the
/// `/us_patent_citation/` endpoint is wired up.
#[derive(Debug, Deserialize)]
struct UsptoCitationsResponse {
    #[serde(default)]
    patents: Vec<UsptoCitation>,
}
|
||||
|
||||
/// A minimal citation record: patent number plus optional title.
#[derive(Debug, Deserialize)]
struct UsptoCitation {
    patent_number: String,
    #[serde(default)]
    patent_title: Option<String>,
}
|
||||
|
||||
/// Client for USPTO PatentsView API (PatentSearch API v2)
///
/// PatentsView provides free access to USPTO patent data with no authentication required.
/// Uses the new PatentSearch API (ElasticSearch-based) as of May 2025.
/// API documentation: https://search.patentsview.org/docs/
pub struct UsptoPatentClient {
    // Underlying HTTP client (30s timeout, custom user agent, see `new`).
    client: Client,
    // API root, e.g. "https://search.patentsview.org/api/v1".
    base_url: String,
    // Delay slept before every request for rate limiting.
    rate_limit_delay: Duration,
    // Shared embedder producing 512-dimensional embeddings for patent text.
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl UsptoPatentClient {
|
||||
    /// Create a new USPTO PatentsView client
    ///
    /// No authentication required for the PatentsView API.
    /// Uses the new PatentSearch API at search.patentsview.org
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .user_agent("RuVector-Discovery/1.0")
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://search.patentsview.org/api/v1".to_string(),
            rate_limit_delay: Duration::from_millis(USPTO_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(512)), // Higher dimension for technical text
        })
    }
|
||||
|
||||
    /// Search patents by keyword query
    ///
    /// Matches the query as a wildcard against both the patent title and
    /// abstract fields of the PatentSearch API.
    ///
    /// # Arguments
    /// * `query` - Search keywords (e.g., "artificial intelligence", "solar cell")
    /// * `max_results` - Maximum number of results to return (max 1000 per page;
    ///   this client caps a single request at 100)
    ///
    /// # Example
    /// ```ignore
    /// let client = UsptoPatentClient::new()?;
    /// let patents = client.search_patents("quantum computing", 50).await?;
    /// ```
    pub async fn search_patents(
        &self,
        query: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let per_page = max_results.min(100);
        let encoded_query = urlencoding::encode(query);

        // New PatentSearch API uses GET with query parameters
        // Query format: q=patent_title:*query* OR patent_abstract:*query*
        let url = format!(
            "{}/patent/?q=patent_title:*{}*%20OR%20patent_abstract:*{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_query, encoded_query, per_page
        );

        // Respect the per-request rate limit before hitting the API.
        sleep(self.rate_limit_delay).await;

        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;

        self.convert_patents_to_vectors(uspto_response.patents)
    }
|
||||
|
||||
    /// Search patents by assignee (company/organization name)
    ///
    /// Matches the name as a wildcard against the assignee organization field.
    ///
    /// # Arguments
    /// * `company_name` - Company or organization name (e.g., "IBM", "Google")
    /// * `max_results` - Maximum number of results to return (capped at 100 per request)
    ///
    /// # Example
    /// ```ignore
    /// let patents = client.search_by_assignee("Tesla Inc", 100).await?;
    /// ```
    pub async fn search_by_assignee(
        &self,
        company_name: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let per_page = max_results.min(100);
        let encoded_name = urlencoding::encode(company_name);

        // New PatentSearch API format
        let url = format!(
            "{}/patent/?q=assignees.assignee_organization:*{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_name, per_page
        );

        // Respect the per-request rate limit before hitting the API.
        sleep(self.rate_limit_delay).await;

        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;

        self.convert_patents_to_vectors(uspto_response.patents)
    }
|
||||
|
||||
    /// Search patents by CPC classification code
    ///
    /// Prefix-matches the code against the `cpcs.cpc_group` field.
    ///
    /// # Arguments
    /// * `cpc_class` - CPC classification (e.g., "Y02E" for climate tech energy, "G06N" for AI)
    /// * `max_results` - Maximum number of results to return (capped at 100 per request)
    ///
    /// # Example - Climate Change Mitigation Technologies
    /// ```ignore
    /// let climate_patents = client.search_by_cpc("Y02", 200).await?;
    /// ```
    ///
    /// # Common CPC Classes
    /// * `Y02` - Climate change mitigation technologies
    /// * `Y02E` - Climate tech - Energy generation/transmission/distribution
    /// * `G06N` - Computing arrangements based on AI/ML/neural networks
    /// * `A61` - Medical or veterinary science
    /// * `H01` - Electric elements (batteries, solar cells, etc.)
    pub async fn search_by_cpc(
        &self,
        cpc_class: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let per_page = max_results.min(100);
        let encoded_cpc = urlencoding::encode(cpc_class);

        // New PatentSearch API - query cpcs.cpc_group field
        let url = format!(
            "{}/patent/?q=cpcs.cpc_group:{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_cpc, per_page
        );

        // Respect the per-request rate limit before hitting the API.
        sleep(self.rate_limit_delay).await;

        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;

        self.convert_patents_to_vectors(uspto_response.patents)
    }
|
||||
|
||||
    /// Get detailed information for a specific patent
    ///
    /// Returns `None` when the API reports no match for the number.
    ///
    /// # Arguments
    /// * `patent_number` - USPTO patent number (e.g., "10000000")
    ///
    /// # Example
    /// ```ignore
    /// let patent = client.get_patent("10123456").await?;
    /// ```
    pub async fn get_patent(&self, patent_number: &str) -> Result<Option<SemanticVector>> {
        // New PatentSearch API - direct patent lookup
        let url = format!(
            "{}/patent/?q=patent_id:{}&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":1}}",
            self.base_url, patent_number
        );

        // Respect the per-request rate limit before hitting the API.
        sleep(self.rate_limit_delay).await;

        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;

        // At most one record was requested; `pop` yields it if present.
        let mut vectors = self.convert_patents_to_vectors(uspto_response.patents)?;
        Ok(vectors.pop())
    }
|
||||
|
||||
/// Get citations for a patent (both citing and cited patents)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `patent_number` - USPTO patent number
|
||||
///
|
||||
/// # Returns
|
||||
/// Tuple of (patents that cite this patent, patents cited by this patent)
|
||||
pub async fn get_citations(
|
||||
&self,
|
||||
patent_number: &str,
|
||||
) -> Result<(Vec<SemanticVector>, Vec<SemanticVector>)> {
|
||||
// Get patents that cite this one (forward citations)
|
||||
let citing = self.get_citing_patents(patent_number).await?;
|
||||
|
||||
// Get patents cited by this one (backward citations)
|
||||
let cited = self.get_cited_patents(patent_number).await?;
|
||||
|
||||
Ok((citing, cited))
|
||||
}
|
||||
|
||||
/// Get patents that cite the given patent (forward citations)
|
||||
/// Note: Citation data requires separate API endpoints in PatentSearch API v2
|
||||
async fn get_citing_patents(&self, _patent_number: &str) -> Result<Vec<SemanticVector>> {
|
||||
// The new PatentSearch API handles citations differently
|
||||
// Forward citations are available via /api/v1/us_patent_citation/ endpoint
|
||||
// For now, return empty - full citation support requires additional implementation
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Get patents cited by the given patent (backward citations)
|
||||
/// Note: Citation data requires separate API endpoints in PatentSearch API v2
|
||||
async fn get_cited_patents(&self, _patent_number: &str) -> Result<Vec<SemanticVector>> {
|
||||
// The new PatentSearch API handles citations differently
|
||||
// Backward citations are available via /api/v1/us_patent_citation/ endpoint
|
||||
// For now, return empty - full citation support requires additional implementation
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
    /// Convert USPTO patent records to SemanticVectors
    ///
    /// The embedding text is "title abstract"; the timestamp prefers the
    /// grant date, falls back to the filing date, then to `Utc::now()`.
    /// Assignees, inventors and CPC codes are flattened into comma-joined
    /// metadata strings.
    fn convert_patents_to_vectors(&self, patents: Vec<UsptoPatent>) -> Result<Vec<SemanticVector>> {
        let mut vectors = Vec::new();

        for patent in patents {
            let title = patent.patent_title.unwrap_or_else(|| "Untitled Patent".to_string());
            let abstract_text = patent.patent_abstract.unwrap_or_default();

            // Create combined text for embedding
            let text = format!("{} {}", title, abstract_text);
            let embedding = self.embedder.embed_text(&text);

            // Parse grant date (prefer patent_date, fallback to app_date)
            let timestamp = patent
                .patent_date
                .or(patent.app_date)
                .as_ref()
                .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);

            // Extract assignee names: organization name if present,
            // otherwise "first last" of an individual assignee.
            let assignees = patent
                .assignees
                .iter()
                .map(|a| {
                    a.assignee_organization
                        .clone()
                        .or_else(|| {
                            let first = a.assignee_individual_name_first.as_deref().unwrap_or("");
                            let last = a.assignee_individual_name_last.as_deref().unwrap_or("");
                            if !first.is_empty() || !last.is_empty() {
                                Some(format!("{} {}", first, last).trim().to_string())
                            } else {
                                None
                            }
                        })
                        .unwrap_or_default()
                })
                .filter(|s| !s.is_empty())
                .collect::<Vec<_>>()
                .join(", ");

            // Extract inventor names ("first last", skipping empty records).
            let inventors = patent
                .inventors
                .iter()
                .filter_map(|i| {
                    let first = i.inventor_name_first.as_deref().unwrap_or("");
                    let last = i.inventor_name_last.as_deref().unwrap_or("");
                    if !first.is_empty() || !last.is_empty() {
                        Some(format!("{} {}", first, last).trim().to_string())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join(", ");

            // Extract CPC codes, preferring the most specific level available.
            let cpc_codes = patent
                .cpcs
                .iter()
                .filter_map(|cpc| {
                    cpc.cpc_group_id
                        .clone()
                        .or_else(|| cpc.cpc_subclass_id.clone())
                        .or_else(|| cpc.cpc_section_id.clone())
                })
                .collect::<Vec<_>>()
                .join(", ");

            // Build metadata
            let mut metadata = HashMap::new();
            metadata.insert("patent_number".to_string(), patent.patent_number.clone());
            metadata.insert("title".to_string(), title);
            metadata.insert("abstract".to_string(), abstract_text);
            metadata.insert("assignee".to_string(), assignees);
            metadata.insert("inventors".to_string(), inventors);
            metadata.insert("cpc_codes".to_string(), cpc_codes);
            metadata.insert(
                "citations_count".to_string(),
                patent.citedby_patent_count.unwrap_or(0).to_string(),
            );
            metadata.insert(
                "cited_count".to_string(),
                patent.cited_patent_count.unwrap_or(0).to_string(),
            );
            metadata.insert("source".to_string(), "uspto".to_string());

            vectors.push(SemanticVector {
                id: format!("US{}", patent.patent_number),
                embedding,
                domain: Domain::Research, // Could be Domain::Innovation if that variant exists
                timestamp,
                metadata,
            });
        }

        Ok(vectors)
    }
|
||||
|
||||
/// GET request with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
|
||||
{
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// POST request with retry logic (kept for backwards compatibility)
|
||||
#[allow(dead_code)]
|
||||
async fn post_with_retry(
|
||||
&self,
|
||||
url: &str,
|
||||
json: &serde_json::Value,
|
||||
) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.post(url).json(json).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
|
||||
{
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for UsptoPatentClient {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("Failed to create USPTO client")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// EPO Open Patent Services Client (Placeholder)
|
||||
// ============================================================================
|
||||
|
||||
/// Client for European Patent Office (EPO) Open Patent Services
///
/// Note: This is a placeholder for future EPO integration.
/// The EPO OPS API requires registration and OAuth authentication.
/// See: https://developers.epo.org/
pub struct EpoClient {
    // HTTP client (built with a 30s timeout in `EpoClient::new`).
    client: Client,
    // OPS REST endpoint root.
    base_url: String,
    // OAuth consumer key from EPO developer registration; currently stored
    // but unused until the OAuth flow is implemented.
    consumer_key: Option<String>,
    // OAuth consumer secret paired with `consumer_key`; also unused for now.
    consumer_secret: Option<String>,
    // Minimum delay between requests (EPO_RATE_LIMIT_MS).
    rate_limit_delay: Duration,
    // Shared embedder (512-dimensional, per `EpoClient::new`).
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl EpoClient {
|
||||
/// Create a new EPO client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `consumer_key` - EPO API consumer key (from developer registration)
|
||||
/// * `consumer_secret` - EPO API consumer secret
|
||||
///
|
||||
/// Registration required at: https://developers.epo.org/
|
||||
pub fn new(consumer_key: Option<String>, consumer_secret: Option<String>) -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(FrameworkError::Network)?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url: "https://ops.epo.org/3.2/rest-services".to_string(),
|
||||
consumer_key,
|
||||
consumer_secret,
|
||||
rate_limit_delay: Duration::from_millis(EPO_RATE_LIMIT_MS),
|
||||
embedder: Arc::new(SimpleEmbedder::new(512)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Search European patents
|
||||
///
|
||||
/// Note: Implementation requires OAuth authentication flow.
|
||||
/// This is a placeholder for future development.
|
||||
pub async fn search_patents(
|
||||
&self,
|
||||
_query: &str,
|
||||
_max_results: usize,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
Err(FrameworkError::Config(
|
||||
"EPO client not yet implemented. Requires OAuth authentication.".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Constructing the USPTO client should never fail with default settings.
    #[tokio::test]
    async fn test_uspto_client_creation() {
        let client = UsptoPatentClient::new();
        assert!(client.is_ok());
    }

    // The EPO placeholder client should construct even without credentials.
    #[tokio::test]
    async fn test_epo_client_creation() {
        let client = EpoClient::new(None, None);
        assert!(client.is_ok());
    }

    // `Default` must produce the same rate-limit configuration as `new()`.
    #[test]
    fn test_default_client() {
        let client = UsptoPatentClient::default();
        assert_eq!(
            client.rate_limit_delay,
            Duration::from_millis(USPTO_RATE_LIMIT_MS)
        );
    }

    #[test]
    fn test_rate_limiting() {
        let client = UsptoPatentClient::new().unwrap();
        assert_eq!(
            client.rate_limit_delay,
            Duration::from_millis(USPTO_RATE_LIMIT_MS)
        );
    }

    #[test]
    fn test_cpc_classification_mapping() {
        // Verify we handle different CPC code lengths correctly
        // (section <= 3 chars, subclass <= 4 chars, anything longer = group).
        let test_cases = vec![
            ("Y02", "cpc_section_id"),
            ("G06N", "cpc_subclass_id"),
            ("Y02E10/50", "cpc_group_id"),
        ];

        for (code, expected_field) in test_cases {
            let field = if code.len() <= 3 {
                "cpc_section_id"
            } else if code.len() <= 4 {
                "cpc_subclass_id"
            } else {
                "cpc_group_id"
            };
            assert_eq!(field, expected_field, "Failed for CPC code: {}", code);
        }
    }

    #[tokio::test]
    #[ignore] // Requires network access
    async fn test_search_patents_integration() {
        let client = UsptoPatentClient::new().unwrap();

        // Test basic search
        let result = client.search_patents("quantum computing", 5).await;

        // Should either succeed or fail with network error, not panic
        match result {
            Ok(patents) => {
                assert!(patents.len() <= 5);
                for patent in patents {
                    // IDs are produced as "US{patent_number}" by the converter.
                    assert!(patent.id.starts_with("US"));
                    assert_eq!(patent.domain, Domain::Research);
                    assert!(!patent.metadata.is_empty());
                }
            }
            Err(e) => {
                // Network errors are acceptable in tests
                println!("Network test skipped: {}", e);
            }
        }
    }

    #[tokio::test]
    #[ignore] // Requires network access
    async fn test_search_by_cpc_integration() {
        let client = UsptoPatentClient::new().unwrap();

        // Test CPC search for AI/ML patents
        let result = client.search_by_cpc("G06N", 5).await;

        match result {
            Ok(patents) => {
                assert!(patents.len() <= 5);
                for patent in patents {
                    let cpc_codes = patent.metadata.get("cpc_codes").map(|s| s.as_str()).unwrap_or("");
                    // Should contain G06N classification
                    // (empty is tolerated: some records carry no CPC data).
                    assert!(
                        cpc_codes.contains("G06N") || cpc_codes.is_empty(),
                        "Expected G06N in CPC codes, got: {}",
                        cpc_codes
                    );
                }
            }
            Err(e) => {
                println!("Network test skipped: {}", e);
            }
        }
    }

    #[test]
    fn test_embedding_dimension() {
        let client = UsptoPatentClient::new().unwrap();
        // Verify embedding dimension is set correctly for technical text
        let embedding = client.embedder.embed_text("test patent description");
        assert_eq!(embedding.len(), 512);
    }
}
|
||||
638
vendor/ruvector/examples/data/framework/src/persistence.rs
vendored
Normal file
638
vendor/ruvector/examples/data/framework/src/persistence.rs
vendored
Normal file
@@ -0,0 +1,638 @@
|
||||
//! Persistence Layer for RuVector Discovery Framework
|
||||
//!
|
||||
//! This module provides serialization/deserialization for the OptimizedDiscoveryEngine
|
||||
//! and discovered patterns. Supports:
|
||||
//! - Full engine state save/load
|
||||
//! - Pattern-only save/load/append
|
||||
//! - Optional gzip compression for large datasets
|
||||
//! - Incremental pattern appends without rewriting entire files
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter, Read, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use flate2::Compression;
|
||||
use flate2::read::GzDecoder;
|
||||
use flate2::write::GzEncoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine, SignificantPattern};
|
||||
use crate::ruvector_native::{
|
||||
CoherenceSnapshot, Domain, GraphEdge, GraphNode, SemanticVector,
|
||||
};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Serializable state of the OptimizedDiscoveryEngine
///
/// This struct excludes non-serializable fields like AtomicU64 metrics
/// and caches, focusing on the core graph and history state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineState {
    /// Engine configuration
    pub config: OptimizedConfig,
    /// All semantic vectors
    pub vectors: Vec<SemanticVector>,
    /// Graph nodes, keyed by node id
    pub nodes: HashMap<u32, GraphNode>,
    /// Graph edges
    pub edges: Vec<GraphEdge>,
    /// Coherence history (timestamp, mincut value, snapshot)
    pub coherence_history: Vec<(DateTime<Utc>, f64, CoherenceSnapshot)>,
    /// Next node ID counter (monotonically increasing; starts at 0)
    pub next_node_id: u32,
    /// Domain-specific node indices
    pub domain_nodes: HashMap<Domain, Vec<u32>>,
    /// Temporal analysis state: per-domain (timestamp, value) series
    pub domain_timeseries: HashMap<Domain, Vec<(DateTime<Utc>, f64)>>,
    /// Metadata about when this state was saved
    pub saved_at: DateTime<Utc>,
    /// Version for compatibility checking (taken from CARGO_PKG_VERSION)
    pub version: String,
}
|
||||
|
||||
impl EngineState {
|
||||
/// Create a new empty engine state
|
||||
pub fn new(config: OptimizedConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
vectors: Vec::new(),
|
||||
nodes: HashMap::new(),
|
||||
edges: Vec::new(),
|
||||
coherence_history: Vec::new(),
|
||||
next_node_id: 0,
|
||||
domain_nodes: HashMap::new(),
|
||||
domain_timeseries: HashMap::new(),
|
||||
saved_at: Utc::now(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for saving/loading with compression
#[derive(Debug, Clone, Copy)]
pub struct PersistenceOptions {
    /// Enable gzip compression
    pub compress: bool,
    /// Compression level (0-9, higher = better compression but slower);
    /// only consulted when `compress` is true
    pub compression_level: u32,
    /// Pretty-print JSON (larger files, more readable)
    pub pretty: bool,
}
|
||||
|
||||
impl Default for PersistenceOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
compress: false,
|
||||
compression_level: 6,
|
||||
pretty: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PersistenceOptions {
|
||||
/// Create options with compression enabled
|
||||
pub fn compressed() -> Self {
|
||||
Self {
|
||||
compress: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create options with pretty-printed JSON
|
||||
pub fn pretty() -> Self {
|
||||
Self {
|
||||
pretty: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save the OptimizedDiscoveryEngine state to a file
///
/// # Arguments
/// * `engine` - The engine to save
/// * `path` - Path to save to (will be created/overwritten)
/// * `options` - Persistence options (compression, formatting)
///
/// # Errors
/// Propagates `FrameworkError::Discovery` from file creation/serialization.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::{OptimizedConfig, OptimizedDiscoveryEngine};
/// # use ruvector_data_framework::persistence::{save_engine, PersistenceOptions};
/// # use std::path::Path;
/// let engine = OptimizedDiscoveryEngine::new(OptimizedConfig::default());
/// save_engine(&engine, Path::new("engine_state.json"), &PersistenceOptions::default())?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn save_engine(
    engine: &OptimizedDiscoveryEngine,
    path: &Path,
    options: &PersistenceOptions,
) -> Result<()> {
    // Extract serializable state
    // NOTE(review): `extract_state` is currently a placeholder that returns an
    // empty state (see its TODO), so the saved file will not yet contain the
    // live graph — confirm before relying on round-trips.
    let state = extract_state(engine);

    // Save to file
    save_state(&state, path, options)?;

    tracing::info!(
        "Saved engine state to {} ({} nodes, {} edges)",
        path.display(),
        state.nodes.len(),
        state.edges.len()
    );

    Ok(())
}
|
||||
|
||||
/// Load an OptimizedDiscoveryEngine from a saved state file
///
/// # Arguments
/// * `path` - Path to the saved state file
///
/// # Returns
/// A new OptimizedDiscoveryEngine with the loaded state
///
/// # Errors
/// Propagates `FrameworkError::Discovery` / serde errors from `load_state`.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::persistence::load_engine;
/// # use std::path::Path;
/// let engine = load_engine(Path::new("engine_state.json"))?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn load_engine(path: &Path) -> Result<OptimizedDiscoveryEngine> {
    let state = load_state(path)?;

    tracing::info!(
        "Loaded engine state from {} ({} nodes, {} edges)",
        path.display(),
        state.nodes.len(),
        state.edges.len()
    );

    // Reconstruct engine from state
    // NOTE(review): `reconstruct_engine` currently only restores the config
    // (see its TODO), so graph contents are not yet rebuilt.
    Ok(reconstruct_engine(state))
}
|
||||
|
||||
/// Save discovered patterns to a JSON file
///
/// # Arguments
/// * `patterns` - Patterns to save
/// * `path` - Path to save to
/// * `options` - Persistence options
///
/// # Errors
/// `FrameworkError::Discovery` on file-creation or compression failures;
/// serde errors (via `?`) on serialization failures.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::SignificantPattern;
/// # use ruvector_data_framework::persistence::{save_patterns, PersistenceOptions};
/// # use std::path::Path;
/// let patterns: Vec<SignificantPattern> = vec![];
/// save_patterns(&patterns, Path::new("patterns.json"), &PersistenceOptions::default())?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn save_patterns(
    patterns: &[SignificantPattern],
    path: &Path,
    options: &PersistenceOptions,
) -> Result<()> {
    let file = File::create(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to create file {}: {}", path.display(), e))
    })?;

    let writer = BufWriter::new(file);

    if options.compress {
        // Compressed path: serialize to a String first, then stream the bytes
        // through the gzip encoder.
        let mut encoder = GzEncoder::new(writer, Compression::new(options.compression_level));
        let json = if options.pretty {
            serde_json::to_string_pretty(patterns)?
        } else {
            serde_json::to_string(patterns)?
        };
        encoder.write_all(json.as_bytes()).map_err(|e| {
            FrameworkError::Discovery(format!("Failed to write compressed patterns: {}", e))
        })?;
        // `finish` flushes the gzip trailer; without it the file is truncated.
        encoder.finish().map_err(|e| {
            FrameworkError::Discovery(format!("Failed to finish compression: {}", e))
        })?;
    } else {
        // Uncompressed path: serde can write straight into the buffered writer.
        if options.pretty {
            serde_json::to_writer_pretty(writer, patterns)?;
        } else {
            serde_json::to_writer(writer, patterns)?;
        }
    }

    tracing::info!("Saved {} patterns to {}", patterns.len(), path.display());

    Ok(())
}
|
||||
|
||||
/// Load patterns from a JSON file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the patterns file
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of loaded patterns
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// # use ruvector_data_framework::persistence::load_patterns;
|
||||
/// # use std::path::Path;
|
||||
/// let patterns = load_patterns(Path::new("patterns.json"))?;
|
||||
/// # Ok::<(), Box<dyn std::error::Error>>(())
|
||||
/// ```
|
||||
pub fn load_patterns(path: &Path) -> Result<Vec<SignificantPattern>> {
|
||||
let file = File::open(path).map_err(|e| {
|
||||
FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
|
||||
})?;
|
||||
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Try to detect if file is gzip-compressed by reading magic bytes
|
||||
let mut peeker = BufReader::new(File::open(path).unwrap());
|
||||
let mut magic = [0u8; 2];
|
||||
let is_gzip = peeker.read_exact(&mut magic).is_ok() && magic == [0x1f, 0x8b];
|
||||
|
||||
let patterns: Vec<SignificantPattern> = if is_gzip {
|
||||
let file = File::open(path).unwrap();
|
||||
let decoder = GzDecoder::new(BufReader::new(file));
|
||||
serde_json::from_reader(decoder)?
|
||||
} else {
|
||||
serde_json::from_reader(reader)?
|
||||
};
|
||||
|
||||
tracing::info!("Loaded {} patterns from {}", patterns.len(), path.display());
|
||||
|
||||
Ok(patterns)
|
||||
}
|
||||
|
||||
/// Append new patterns to an existing patterns file
///
/// NOTE(review): despite the name, the current implementation loads the
/// existing patterns, extends the list in memory, and rewrites the whole
/// file — it is a convenience wrapper, not an O(1) append. (The earlier doc
/// claimed it avoided rewriting the entire file, which was inaccurate.)
///
/// # Arguments
/// * `patterns` - New patterns to append
/// * `path` - Path to the existing patterns file
///
/// # Note
/// If the file doesn't exist, it will be created with the given patterns.
/// If the existing file is gzip-compressed, the rewritten file stays
/// compressed; otherwise it stays plain JSON.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::SignificantPattern;
/// # use ruvector_data_framework::persistence::append_patterns;
/// # use std::path::Path;
/// let new_patterns: Vec<SignificantPattern> = vec![];
/// append_patterns(&new_patterns, Path::new("patterns.json"))?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn append_patterns(patterns: &[SignificantPattern], path: &Path) -> Result<()> {
    // Nothing to do for an empty batch.
    if patterns.is_empty() {
        return Ok(());
    }

    // Check if file exists
    if !path.exists() {
        // Create new file
        return save_patterns(patterns, path, &PersistenceOptions::default());
    }

    // Load existing patterns
    let mut existing = load_patterns(path)?;

    // Append new patterns
    existing.extend_from_slice(patterns);

    // Save combined patterns
    // Preserve compression if original was compressed
    let options = if is_compressed(path)? {
        PersistenceOptions::compressed()
    } else {
        PersistenceOptions::default()
    };

    save_patterns(&existing, path, &options)?;

    tracing::info!(
        "Appended {} patterns to {} (total: {})",
        patterns.len(),
        path.display(),
        existing.len()
    );

    Ok(())
}
|
||||
|
||||
// ============================================================================
|
||||
// Internal Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Extract serializable state from an OptimizedDiscoveryEngine
///
/// PLACEHOLDER: the engine does not yet expose getters for its internal
/// state, so this returns an essentially empty `EngineState` (default
/// config, empty collections) stamped with the current time and crate
/// version. Callers such as `save_engine` therefore do not yet persist the
/// live graph.
fn extract_state(_engine: &OptimizedDiscoveryEngine) -> EngineState {
    // Note: This requires the OptimizedDiscoveryEngine to expose its internal state
    // via getter methods. For now, we'll use a placeholder that you'll need to implement.

    // Since we can't directly access private fields, we need the engine to provide
    // a method like `pub fn extract_state(&self) -> EngineState`

    // For now, return a minimal state with what we can access
    // TODO: Uncomment when OptimizedDiscoveryEngine provides getter methods
    // let _stats = engine.stats();

    EngineState {
        config: OptimizedConfig::default(), // Would need engine.config() method
        vectors: Vec::new(),                // Would need engine.vectors() method
        nodes: HashMap::new(),              // Would need engine.nodes() method
        edges: Vec::new(),                  // Would need engine.edges() method
        coherence_history: Vec::new(),      // Would need engine.coherence_history() method
        next_node_id: 0,                    // Would need engine.next_node_id() method
        domain_nodes: HashMap::new(),       // Would need engine.domain_nodes() method
        domain_timeseries: HashMap::new(),  // Would need engine.domain_timeseries() method
        saved_at: Utc::now(),
        version: env!("CARGO_PKG_VERSION").to_string(),
    }

    // TODO: Implement proper state extraction once OptimizedDiscoveryEngine
    // exposes the necessary getter methods
}
|
||||
|
||||
/// Reconstruct an OptimizedDiscoveryEngine from saved state
///
/// PLACEHOLDER: only the configuration is restored; the saved vectors,
/// graph, and history are currently discarded because the engine lacks a
/// `from_state`-style constructor.
fn reconstruct_engine(state: EngineState) -> OptimizedDiscoveryEngine {
    // Similarly, this would require OptimizedDiscoveryEngine to have
    // a constructor like `pub fn from_state(state: EngineState) -> Self`

    // For now, create a new engine and note that full reconstruction
    // would require additional methods
    OptimizedDiscoveryEngine::new(state.config)

    // TODO: Implement proper engine reconstruction once OptimizedDiscoveryEngine
    // provides the necessary methods to restore state
}
|
||||
|
||||
/// Save engine state to a file with optional compression
///
/// Mirrors the logic of `save_patterns` but for a whole `EngineState`.
///
/// # Errors
/// `FrameworkError::Discovery` on file-creation or compression failures;
/// serde errors (via `?`) on serialization failures.
fn save_state(state: &EngineState, path: &Path, options: &PersistenceOptions) -> Result<()> {
    let file = File::create(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to create file {}: {}", path.display(), e))
    })?;

    let writer = BufWriter::new(file);

    if options.compress {
        // Serialize to a String, then stream through the gzip encoder.
        let mut encoder = GzEncoder::new(writer, Compression::new(options.compression_level));
        let json = if options.pretty {
            serde_json::to_string_pretty(state)?
        } else {
            serde_json::to_string(state)?
        };
        encoder.write_all(json.as_bytes()).map_err(|e| {
            FrameworkError::Discovery(format!("Failed to write compressed state: {}", e))
        })?;
        // `finish` writes the gzip trailer; skipping it would truncate the file.
        encoder.finish().map_err(|e| {
            FrameworkError::Discovery(format!("Failed to finish compression: {}", e))
        })?;
    } else {
        if options.pretty {
            serde_json::to_writer_pretty(writer, state)?;
        } else {
            serde_json::to_writer(writer, state)?;
        }
    }

    Ok(())
}
|
||||
|
||||
/// Load engine state from a file with automatic compression detection
|
||||
fn load_state(path: &Path) -> Result<EngineState> {
|
||||
let file = File::open(path).map_err(|e| {
|
||||
FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
|
||||
})?;
|
||||
|
||||
// Detect compression by reading magic bytes
|
||||
let is_gzip = is_compressed(path)?;
|
||||
|
||||
let state = if is_gzip {
|
||||
let file = File::open(path).unwrap();
|
||||
let decoder = GzDecoder::new(BufReader::new(file));
|
||||
serde_json::from_reader(decoder)?
|
||||
} else {
|
||||
let reader = BufReader::new(file);
|
||||
serde_json::from_reader(reader)?
|
||||
};
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Check if a file is gzip-compressed by reading magic bytes
|
||||
fn is_compressed(path: &Path) -> Result<bool> {
|
||||
let mut file = File::open(path).map_err(|e| {
|
||||
FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
|
||||
})?;
|
||||
|
||||
let mut magic = [0u8; 2];
|
||||
match file.read_exact(&mut magic) {
|
||||
Ok(_) => Ok(magic == [0x1f, 0x8b]),
|
||||
Err(_) => Ok(false), // File too small or empty
|
||||
}
|
||||
}
|
||||
|
||||
/// Get file size in bytes
|
||||
pub fn get_file_size(path: &Path) -> Result<u64> {
|
||||
let metadata = std::fs::metadata(path).map_err(|e| {
|
||||
FrameworkError::Discovery(format!("Failed to get file metadata: {}", e))
|
||||
})?;
|
||||
Ok(metadata.len())
|
||||
}
|
||||
|
||||
/// Calculate compression ratio for a file
|
||||
///
|
||||
/// Returns (compressed_size, uncompressed_size, ratio)
|
||||
pub fn compression_info(path: &Path) -> Result<(u64, u64, f64)> {
|
||||
let compressed_size = get_file_size(path)?;
|
||||
|
||||
if is_compressed(path)? {
|
||||
// Decompress and count bytes
|
||||
let file = File::open(path).unwrap();
|
||||
let mut decoder = GzDecoder::new(BufReader::new(file));
|
||||
let mut buffer = Vec::new();
|
||||
let uncompressed_size = decoder.read_to_end(&mut buffer).map_err(|e| {
|
||||
FrameworkError::Discovery(format!("Failed to decompress: {}", e))
|
||||
})? as u64;
|
||||
|
||||
let ratio = compressed_size as f64 / uncompressed_size as f64;
|
||||
Ok((compressed_size, uncompressed_size, ratio))
|
||||
} else {
|
||||
Ok((compressed_size, compressed_size, 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimized::OptimizedConfig;
    use crate::ruvector_native::{DiscoveredPattern, PatternType, Evidence};
    use tempfile::NamedTempFile;

    // A fresh EngineState should be empty and carry the supplied config.
    #[test]
    fn test_engine_state_creation() {
        let config = OptimizedConfig::default();
        let state = EngineState::new(config.clone());

        assert_eq!(state.next_node_id, 0);
        assert_eq!(state.nodes.len(), 0);
        assert_eq!(state.config.similarity_threshold, config.similarity_threshold);
    }

    // The constructors should toggle exactly one flag each.
    #[test]
    fn test_persistence_options() {
        let default = PersistenceOptions::default();
        assert!(!default.compress);
        assert!(!default.pretty);

        let compressed = PersistenceOptions::compressed();
        assert!(compressed.compress);

        let pretty = PersistenceOptions::pretty();
        assert!(pretty.pretty);
    }

    // Round-trip a single pattern through plain JSON.
    #[test]
    fn test_save_load_patterns() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();

        let patterns = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "test-1".to_string(),
                    pattern_type: PatternType::CoherenceBreak,
                    confidence: 0.85,
                    affected_nodes: vec![1, 2, 3],
                    detected_at: Utc::now(),
                    description: "Test pattern".to_string(),
                    evidence: vec![
                        Evidence {
                            evidence_type: "test".to_string(),
                            value: 1.0,
                            description: "Test evidence".to_string(),
                        }
                    ],
                    cross_domain_links: vec![],
                },
                p_value: 0.03,
                effect_size: 1.2,
                confidence_interval: (0.5, 1.5),
                is_significant: true,
            }
        ];

        // Save patterns
        save_patterns(&patterns, path, &PersistenceOptions::default()).unwrap();

        // Load patterns
        let loaded = load_patterns(path).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].pattern.id, "test-1");
        assert_eq!(loaded[0].p_value, 0.03);
    }

    // Round-trip through gzip, verifying the magic-byte detection too.
    #[test]
    fn test_save_load_patterns_compressed() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();

        let patterns = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "test-compressed".to_string(),
                    pattern_type: PatternType::Consolidation,
                    confidence: 0.90,
                    affected_nodes: vec![4, 5, 6],
                    detected_at: Utc::now(),
                    description: "Compressed test pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.01,
                effect_size: 2.0,
                confidence_interval: (1.0, 3.0),
                is_significant: true,
            }
        ];

        // Save with compression
        save_patterns(&patterns, path, &PersistenceOptions::compressed()).unwrap();

        // Verify compression
        assert!(is_compressed(path).unwrap());

        // Load and verify
        let loaded = load_patterns(path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].pattern.id, "test-compressed");
    }

    // Appending to an existing file keeps the earlier patterns and preserves
    // insertion order.
    #[test]
    fn test_append_patterns() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path();

        let pattern1 = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "pattern-1".to_string(),
                    pattern_type: PatternType::EmergingCluster,
                    confidence: 0.75,
                    affected_nodes: vec![1],
                    detected_at: Utc::now(),
                    description: "First pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.05,
                effect_size: 1.0,
                confidence_interval: (0.0, 2.0),
                is_significant: false,
            }
        ];

        let pattern2 = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "pattern-2".to_string(),
                    pattern_type: PatternType::Cascade,
                    confidence: 0.95,
                    affected_nodes: vec![2],
                    detected_at: Utc::now(),
                    description: "Second pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.001,
                effect_size: 3.0,
                confidence_interval: (2.0, 4.0),
                is_significant: true,
            }
        ];

        // Save first pattern
        save_patterns(&pattern1, path, &PersistenceOptions::default()).unwrap();

        // Append second pattern
        append_patterns(&pattern2, path).unwrap();

        // Load and verify both are present
        let loaded = load_patterns(path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].pattern.id, "pattern-1");
        assert_eq!(loaded[1].pattern.id, "pattern-2");
    }
}
|
||||
1155
vendor/ruvector/examples/data/framework/src/physics_clients.rs
vendored
Normal file
1155
vendor/ruvector/examples/data/framework/src/physics_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
558
vendor/ruvector/examples/data/framework/src/realtime.rs
vendored
Normal file
558
vendor/ruvector/examples/data/framework/src/realtime.rs
vendored
Normal file
@@ -0,0 +1,558 @@
|
||||
//! Real-Time Data Feed Integration
|
||||
//!
|
||||
//! RSS/Atom feed parsing, WebSocket streaming, and REST API polling
|
||||
//! for continuous data ingestion into RuVector discovery framework.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::interval;
|
||||
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Real-time engine for streaming data feeds
|
||||
pub struct RealTimeEngine {
|
||||
feeds: Vec<FeedSource>,
|
||||
update_interval: Duration,
|
||||
on_new_data: Option<Arc<dyn Fn(Vec<SemanticVector>) + Send + Sync>>,
|
||||
dedup_cache: Arc<RwLock<HashSet<String>>>,
|
||||
running: Arc<RwLock<bool>>,
|
||||
}
|
||||
|
||||
/// Types of feed sources
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FeedSource {
|
||||
/// RSS or Atom feed
|
||||
Rss { url: String, category: String },
|
||||
/// REST API with polling
|
||||
RestPolling { url: String, interval: Duration },
|
||||
/// WebSocket streaming endpoint
|
||||
WebSocket { url: String },
|
||||
}
|
||||
|
||||
/// News aggregator for multiple RSS feeds
|
||||
pub struct NewsAggregator {
|
||||
sources: Vec<NewsSource>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
/// Individual news source configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NewsSource {
|
||||
pub name: String,
|
||||
pub feed_url: String,
|
||||
pub domain: Domain,
|
||||
}
|
||||
|
||||
/// Parsed feed item
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FeedItem {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub description: String,
|
||||
pub link: String,
|
||||
pub published: Option<chrono::DateTime<Utc>>,
|
||||
pub author: Option<String>,
|
||||
pub categories: Vec<String>,
|
||||
}
|
||||
|
||||
impl RealTimeEngine {
|
||||
/// Create a new real-time engine
|
||||
pub fn new(update_interval: Duration) -> Self {
|
||||
Self {
|
||||
feeds: Vec::new(),
|
||||
update_interval,
|
||||
on_new_data: None,
|
||||
dedup_cache: Arc::new(RwLock::new(HashSet::new())),
|
||||
running: Arc::new(RwLock::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a feed source to monitor
|
||||
pub fn add_feed(&mut self, source: FeedSource) {
|
||||
self.feeds.push(source);
|
||||
}
|
||||
|
||||
/// Set callback for new data
|
||||
pub fn set_callback<F>(&mut self, callback: F)
|
||||
where
|
||||
F: Fn(Vec<SemanticVector>) + Send + Sync + 'static,
|
||||
{
|
||||
self.on_new_data = Some(Arc::new(callback));
|
||||
}
|
||||
|
||||
/// Start the real-time engine
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
{
|
||||
let mut running = self.running.write().await;
|
||||
if *running {
|
||||
return Err(FrameworkError::Config(
|
||||
"Engine already running".to_string(),
|
||||
));
|
||||
}
|
||||
*running = true;
|
||||
}
|
||||
|
||||
let feeds = self.feeds.clone();
|
||||
let callback = self.on_new_data.clone();
|
||||
let dedup_cache = self.dedup_cache.clone();
|
||||
let update_interval = self.update_interval;
|
||||
let running = self.running.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = interval(update_interval);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Check if we should stop
|
||||
{
|
||||
let is_running = running.read().await;
|
||||
if !*is_running {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Process all feeds
|
||||
for feed in &feeds {
|
||||
match Self::process_feed(feed, &dedup_cache).await {
|
||||
Ok(vectors) => {
|
||||
if !vectors.is_empty() {
|
||||
if let Some(ref cb) = callback {
|
||||
cb(vectors);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Feed processing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the real-time engine
|
||||
pub async fn stop(&mut self) {
|
||||
let mut running = self.running.write().await;
|
||||
*running = false;
|
||||
}
|
||||
|
||||
/// Process a single feed source
|
||||
async fn process_feed(
|
||||
feed: &FeedSource,
|
||||
dedup_cache: &Arc<RwLock<HashSet<String>>>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
match feed {
|
||||
FeedSource::Rss { url, category } => {
|
||||
Self::process_rss_feed(url, category, dedup_cache).await
|
||||
}
|
||||
FeedSource::RestPolling { url, .. } => {
|
||||
Self::process_rest_feed(url, dedup_cache).await
|
||||
}
|
||||
FeedSource::WebSocket { url } => Self::process_websocket_feed(url, dedup_cache).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process RSS/Atom feed
|
||||
async fn process_rss_feed(
|
||||
url: &str,
|
||||
category: &str,
|
||||
dedup_cache: &Arc<RwLock<HashSet<String>>>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let client = reqwest::Client::new();
|
||||
let response = client.get(url).send().await?;
|
||||
let content = response.text().await?;
|
||||
|
||||
// Parse RSS/Atom feed
|
||||
let items = Self::parse_rss(&content)?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
let mut cache = dedup_cache.write().await;
|
||||
|
||||
for item in items {
|
||||
// Check for duplicates
|
||||
if cache.contains(&item.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add to dedup cache
|
||||
cache.insert(item.id.clone());
|
||||
|
||||
// Convert to SemanticVector
|
||||
let domain = Self::category_to_domain(category);
|
||||
let vector = Self::item_to_vector(item, domain);
|
||||
vectors.push(vector);
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Process REST API polling
|
||||
async fn process_rest_feed(
|
||||
url: &str,
|
||||
dedup_cache: &Arc<RwLock<HashSet<String>>>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
let client = reqwest::Client::new();
|
||||
let response = client.get(url).send().await?;
|
||||
let items: Vec<FeedItem> = response.json().await?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
let mut cache = dedup_cache.write().await;
|
||||
|
||||
for item in items {
|
||||
if cache.contains(&item.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cache.insert(item.id.clone());
|
||||
let vector = Self::item_to_vector(item, Domain::Research);
|
||||
vectors.push(vector);
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Process WebSocket stream (simplified implementation)
|
||||
async fn process_websocket_feed(
|
||||
_url: &str,
|
||||
_dedup_cache: &Arc<RwLock<HashSet<String>>>,
|
||||
) -> Result<Vec<SemanticVector>> {
|
||||
// WebSocket implementation would require tokio-tungstenite
|
||||
// For now, return empty - can be extended with actual WebSocket client
|
||||
tracing::warn!("WebSocket feeds not yet implemented");
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Parse RSS/Atom XML into feed items
|
||||
fn parse_rss(content: &str) -> Result<Vec<FeedItem>> {
|
||||
// Simple XML parsing for RSS 2.0
|
||||
// In production, use feed-rs or rss crate
|
||||
let mut items = Vec::new();
|
||||
|
||||
// Basic RSS parsing (simplified)
|
||||
for item_block in content.split("<item>").skip(1) {
|
||||
if let Some(end) = item_block.find("</item>") {
|
||||
let item_xml = &item_block[..end];
|
||||
if let Some(item) = Self::parse_rss_item(item_xml) {
|
||||
items.push(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
/// Parse a single RSS item from XML
|
||||
fn parse_rss_item(xml: &str) -> Option<FeedItem> {
|
||||
let title = Self::extract_tag(xml, "title")?;
|
||||
let description = Self::extract_tag(xml, "description").unwrap_or_default();
|
||||
let link = Self::extract_tag(xml, "link").unwrap_or_default();
|
||||
let guid = Self::extract_tag(xml, "guid").unwrap_or_else(|| link.clone());
|
||||
|
||||
let published = Self::extract_tag(xml, "pubDate")
|
||||
.and_then(|date_str| chrono::DateTime::parse_from_rfc2822(&date_str).ok())
|
||||
.map(|dt| dt.with_timezone(&Utc));
|
||||
|
||||
let author = Self::extract_tag(xml, "author");
|
||||
|
||||
Some(FeedItem {
|
||||
id: guid,
|
||||
title,
|
||||
description,
|
||||
link,
|
||||
published,
|
||||
author,
|
||||
categories: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract content between XML tags
|
||||
fn extract_tag(xml: &str, tag: &str) -> Option<String> {
|
||||
let start_tag = format!("<{}>", tag);
|
||||
let end_tag = format!("</{}>", tag);
|
||||
|
||||
let start = xml.find(&start_tag)? + start_tag.len();
|
||||
let end = xml.find(&end_tag)?;
|
||||
|
||||
if start < end {
|
||||
let content = &xml[start..end];
|
||||
// Basic HTML entity decoding
|
||||
let decoded = content
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace("&", "&")
|
||||
.replace(""", "\"")
|
||||
.replace("'", "'");
|
||||
Some(decoded.trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert category string to Domain enum
|
||||
fn category_to_domain(category: &str) -> Domain {
|
||||
match category.to_lowercase().as_str() {
|
||||
"climate" | "weather" | "environment" => Domain::Climate,
|
||||
"finance" | "economy" | "market" | "stock" => Domain::Finance,
|
||||
"research" | "science" | "academic" | "medical" => Domain::Research,
|
||||
_ => Domain::CrossDomain,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert FeedItem to SemanticVector
|
||||
fn item_to_vector(item: FeedItem, domain: Domain) -> SemanticVector {
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Create a simple embedding from title + description
|
||||
// In production, use actual embedding model
|
||||
let text = format!("{} {}", item.title, item.description);
|
||||
let embedding = Self::simple_embedding(&text);
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("title".to_string(), item.title.clone());
|
||||
metadata.insert("link".to_string(), item.link.clone());
|
||||
if let Some(author) = item.author {
|
||||
metadata.insert("author".to_string(), author);
|
||||
}
|
||||
|
||||
SemanticVector {
|
||||
id: item.id,
|
||||
embedding,
|
||||
domain,
|
||||
timestamp: item.published.unwrap_or_else(Utc::now),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple embedding generation (hash-based for demo)
|
||||
fn simple_embedding(text: &str) -> Vec<f32> {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
// Create 384-dimensional embedding from text hash
|
||||
let mut embedding = vec![0.0f32; 384];
|
||||
|
||||
for (i, word) in text.split_whitespace().take(384).enumerate() {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
word.hash(&mut hasher);
|
||||
let hash = hasher.finish();
|
||||
embedding[i] = (hash as f32 / u64::MAX as f32) * 2.0 - 1.0;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if magnitude > 0.0 {
|
||||
for val in &mut embedding {
|
||||
*val /= magnitude;
|
||||
}
|
||||
}
|
||||
|
||||
embedding
|
||||
}
|
||||
}
|
||||
|
||||
impl NewsAggregator {
|
||||
/// Create a new news aggregator
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sources: Vec::new(),
|
||||
client: reqwest::Client::builder()
|
||||
.user_agent("RuVector/1.0")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a news source
|
||||
pub fn add_source(&mut self, source: NewsSource) {
|
||||
self.sources.push(source);
|
||||
}
|
||||
|
||||
/// Add default free news sources
|
||||
pub fn add_default_sources(&mut self) {
|
||||
// Climate sources
|
||||
self.add_source(NewsSource {
|
||||
name: "NASA Earth Observatory".to_string(),
|
||||
feed_url: "https://earthobservatory.nasa.gov/feeds/image-of-the-day.rss".to_string(),
|
||||
domain: Domain::Climate,
|
||||
});
|
||||
|
||||
// Financial sources
|
||||
self.add_source(NewsSource {
|
||||
name: "Yahoo Finance - Top Stories".to_string(),
|
||||
feed_url: "https://finance.yahoo.com/news/rssindex".to_string(),
|
||||
domain: Domain::Finance,
|
||||
});
|
||||
|
||||
// Medical/Research sources
|
||||
self.add_source(NewsSource {
|
||||
name: "PubMed Recent".to_string(),
|
||||
feed_url: "https://pubmed.ncbi.nlm.nih.gov/rss/search/1nKx2zx8g-9UCGpQD5qVmN6jTvSRRxYqjD3T_nA-pSMjDlXr4u/?limit=100&utm_campaign=pubmed-2&fc=20210421200858".to_string(),
|
||||
domain: Domain::Research,
|
||||
});
|
||||
|
||||
// General news sources
|
||||
self.add_source(NewsSource {
|
||||
name: "Reuters Top News".to_string(),
|
||||
feed_url: "https://www.reutersagency.com/feed/?taxonomy=best-topics&post_type=best".to_string(),
|
||||
domain: Domain::CrossDomain,
|
||||
});
|
||||
|
||||
self.add_source(NewsSource {
|
||||
name: "AP News Top Stories".to_string(),
|
||||
feed_url: "https://apnews.com/index.rss".to_string(),
|
||||
domain: Domain::CrossDomain,
|
||||
});
|
||||
}
|
||||
|
||||
/// Fetch latest items from all sources
|
||||
pub async fn fetch_latest(&self, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let mut all_vectors = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for source in &self.sources {
|
||||
match self.fetch_source(source, limit).await {
|
||||
Ok(vectors) => {
|
||||
for vector in vectors {
|
||||
if !seen.contains(&vector.id) {
|
||||
seen.insert(vector.id.clone());
|
||||
all_vectors.push(vector);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch {}: {}", source.name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by timestamp, most recent first
|
||||
all_vectors.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
|
||||
|
||||
// Limit results
|
||||
all_vectors.truncate(limit);
|
||||
|
||||
Ok(all_vectors)
|
||||
}
|
||||
|
||||
/// Fetch from a single source
|
||||
async fn fetch_source(&self, source: &NewsSource, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let response = self.client.get(&source.feed_url).send().await?;
|
||||
let content = response.text().await?;
|
||||
|
||||
let items = RealTimeEngine::parse_rss(&content)?;
|
||||
let mut vectors = Vec::new();
|
||||
|
||||
for item in items.into_iter().take(limit) {
|
||||
let vector = RealTimeEngine::item_to_vector(item, source.domain);
|
||||
vectors.push(vector);
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NewsAggregator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_tag() {
        let xml = "<title>Test Title</title><description>Test Description</description>";

        assert_eq!(
            RealTimeEngine::extract_tag(xml, "title"),
            Some("Test Title".to_string())
        );
        assert_eq!(
            RealTimeEngine::extract_tag(xml, "description"),
            Some("Test Description".to_string())
        );
        // Absent tags yield None rather than an empty string.
        assert_eq!(RealTimeEngine::extract_tag(xml, "missing"), None);
    }

    #[test]
    fn test_category_to_domain() {
        // Mapping is case-insensitive; unknown categories fall through to
        // CrossDomain.
        let cases = [
            ("climate", Domain::Climate),
            ("Finance", Domain::Finance),
            ("research", Domain::Research),
            ("other", Domain::CrossDomain),
        ];
        for (input, expected) in cases {
            assert_eq!(RealTimeEngine::category_to_domain(input), expected);
        }
    }

    #[test]
    fn test_simple_embedding() {
        let embedding = RealTimeEngine::simple_embedding("climate change impacts");
        assert_eq!(embedding.len(), 384);

        // The embedding must be (approximately) unit length.
        let magnitude = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_realtime_engine_lifecycle() {
        let mut engine = RealTimeEngine::new(Duration::from_secs(1));

        engine.add_feed(FeedSource::Rss {
            url: "https://example.com/feed.rss".to_string(),
            category: "climate".to_string(),
        });

        // Start must succeed once; stop must be callable afterwards.
        assert!(engine.start().await.is_ok());
        engine.stop().await;
    }

    #[test]
    fn test_news_aggregator() {
        let mut aggregator = NewsAggregator::new();
        aggregator.add_default_sources();
        assert!(aggregator.sources.len() >= 5);
    }

    #[test]
    fn test_parse_rss_item() {
        let xml = r#"
            <title>Test Article</title>
            <description>This is a test article</description>
            <link>https://example.com/article</link>
            <guid>article-123</guid>
            <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
        "#;

        let item = RealTimeEngine::parse_rss_item(xml).expect("item should parse");

        assert_eq!(item.title, "Test Article");
        assert_eq!(item.description, "This is a test article");
        assert_eq!(item.link, "https://example.com/article");
        // The guid, when present, is preferred over the link as id.
        assert_eq!(item.id, "article-123");
    }
}
|
||||
854
vendor/ruvector/examples/data/framework/src/ruvector_native.rs
vendored
Normal file
854
vendor/ruvector/examples/data/framework/src/ruvector_native.rs
vendored
Normal file
@@ -0,0 +1,854 @@
|
||||
//! RuVector-Native Discovery Engine
|
||||
//!
|
||||
//! Deep integration with ruvector-core, ruvector-graph, and ruvector-mincut
|
||||
//! for production-grade coherence analysis and pattern discovery.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::utils::cosine_similarity;
|
||||
|
||||
/// Vector embedding for semantic similarity
|
||||
/// Uses RuVector's native vector storage format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SemanticVector {
|
||||
/// Vector ID
|
||||
pub id: String,
|
||||
/// Dense embedding (typically 384-1536 dimensions)
|
||||
pub embedding: Vec<f32>,
|
||||
/// Source domain
|
||||
pub domain: Domain,
|
||||
/// Timestamp
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Metadata for filtering
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Discovery domains
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum Domain {
|
||||
Climate,
|
||||
Finance,
|
||||
Research,
|
||||
Medical,
|
||||
Economic,
|
||||
Genomics,
|
||||
Physics,
|
||||
Seismic,
|
||||
Ocean,
|
||||
Space,
|
||||
Transportation,
|
||||
Geospatial,
|
||||
Government,
|
||||
CrossDomain,
|
||||
}
|
||||
|
||||
/// RuVector-native graph node
|
||||
/// Designed to work with ruvector-graph's adjacency structures
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GraphNode {
|
||||
/// Node ID (u32 for ruvector compatibility)
|
||||
pub id: u32,
|
||||
/// String identifier for external reference
|
||||
pub external_id: String,
|
||||
/// Domain
|
||||
pub domain: Domain,
|
||||
/// Associated vector embedding index
|
||||
pub vector_idx: Option<usize>,
|
||||
/// Node weight (for weighted min-cut)
|
||||
pub weight: f64,
|
||||
/// Attributes
|
||||
pub attributes: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
/// RuVector-native graph edge
|
||||
/// Compatible with ruvector-mincut's edge format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GraphEdge {
|
||||
/// Source node ID
|
||||
pub source: u32,
|
||||
/// Target node ID
|
||||
pub target: u32,
|
||||
/// Edge weight (capacity for min-cut)
|
||||
pub weight: f64,
|
||||
/// Edge type
|
||||
pub edge_type: EdgeType,
|
||||
/// Timestamp when edge was created/updated
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Types of edges in the discovery graph
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum EdgeType {
|
||||
/// Correlation-based (e.g., temperature correlation)
|
||||
Correlation,
|
||||
/// Similarity-based (e.g., vector cosine similarity)
|
||||
Similarity,
|
||||
/// Citation/reference link
|
||||
Citation,
|
||||
/// Causal relationship
|
||||
Causal,
|
||||
/// Cross-domain bridge
|
||||
CrossDomain,
|
||||
}
|
||||
|
||||
/// Configuration for the native discovery engine
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NativeEngineConfig {
|
||||
/// Minimum edge weight to include
|
||||
pub min_edge_weight: f64,
|
||||
/// Vector similarity threshold
|
||||
pub similarity_threshold: f64,
|
||||
/// Min-cut sensitivity (lower = more sensitive to breaks)
|
||||
pub mincut_sensitivity: f64,
|
||||
/// Enable cross-domain discovery
|
||||
pub cross_domain: bool,
|
||||
/// Window size for temporal analysis (seconds)
|
||||
pub window_seconds: i64,
|
||||
/// HNSW parameters
|
||||
pub hnsw_m: usize,
|
||||
pub hnsw_ef_construction: usize,
|
||||
pub hnsw_ef_search: usize,
|
||||
/// Vector dimension
|
||||
pub dimension: usize,
|
||||
/// Batch size for processing
|
||||
pub batch_size: usize,
|
||||
/// Checkpoint interval (records)
|
||||
pub checkpoint_interval: u64,
|
||||
/// Number of parallel workers
|
||||
pub parallel_workers: usize,
|
||||
}
|
||||
|
||||
impl Default for NativeEngineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_edge_weight: 0.3,
|
||||
similarity_threshold: 0.7,
|
||||
mincut_sensitivity: 0.15,
|
||||
cross_domain: true,
|
||||
window_seconds: 86400 * 30, // 30 days
|
||||
hnsw_m: 16,
|
||||
hnsw_ef_construction: 200,
|
||||
hnsw_ef_search: 50,
|
||||
dimension: 384,
|
||||
batch_size: 1000,
|
||||
checkpoint_interval: 10_000,
|
||||
parallel_workers: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The main RuVector-native discovery engine
|
||||
///
|
||||
/// This engine uses RuVector's core algorithms:
|
||||
/// - Vector similarity via HNSW index
|
||||
/// - Graph coherence via Stoer-Wagner min-cut
|
||||
/// - Temporal windowing for streaming analysis
|
||||
pub struct NativeDiscoveryEngine {
|
||||
config: NativeEngineConfig,
|
||||
|
||||
/// Vector storage (would use ruvector-core in production)
|
||||
vectors: Vec<SemanticVector>,
|
||||
|
||||
/// Graph nodes
|
||||
nodes: HashMap<u32, GraphNode>,
|
||||
|
||||
/// Graph edges (adjacency list format for ruvector-mincut)
|
||||
edges: Vec<GraphEdge>,
|
||||
|
||||
/// Historical coherence values for change detection
|
||||
coherence_history: Vec<(DateTime<Utc>, f64, CoherenceSnapshot)>,
|
||||
|
||||
/// Next node ID
|
||||
next_node_id: u32,
|
||||
|
||||
/// Domain-specific subgraph indices
|
||||
domain_nodes: HashMap<Domain, Vec<u32>>,
|
||||
}
|
||||
|
||||
/// Snapshot of coherence state for historical comparison
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CoherenceSnapshot {
|
||||
/// Min-cut value
|
||||
pub mincut_value: f64,
|
||||
/// Number of nodes
|
||||
pub node_count: usize,
|
||||
/// Number of edges
|
||||
pub edge_count: usize,
|
||||
/// Partition sizes after min-cut
|
||||
pub partition_sizes: (usize, usize),
|
||||
/// Boundary nodes (nodes on the cut)
|
||||
pub boundary_nodes: Vec<u32>,
|
||||
/// Average edge weight
|
||||
pub avg_edge_weight: f64,
|
||||
}
|
||||
|
||||
/// A detected pattern or anomaly
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DiscoveredPattern {
|
||||
/// Pattern ID
|
||||
pub id: String,
|
||||
/// Pattern type
|
||||
pub pattern_type: PatternType,
|
||||
/// Confidence score (0-1)
|
||||
pub confidence: f64,
|
||||
/// Affected nodes
|
||||
pub affected_nodes: Vec<u32>,
|
||||
/// Timestamp of detection
|
||||
pub detected_at: DateTime<Utc>,
|
||||
/// Description
|
||||
pub description: String,
|
||||
/// Evidence
|
||||
pub evidence: Vec<Evidence>,
|
||||
/// Cross-domain connections if applicable
|
||||
pub cross_domain_links: Vec<CrossDomainLink>,
|
||||
}
|
||||
|
||||
/// Types of discoverable patterns
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum PatternType {
|
||||
/// Network coherence break (min-cut dropped)
|
||||
CoherenceBreak,
|
||||
/// Network consolidation (min-cut increased)
|
||||
Consolidation,
|
||||
/// Emerging cluster (new dense subgraph)
|
||||
EmergingCluster,
|
||||
/// Dissolving cluster
|
||||
DissolvingCluster,
|
||||
/// Bridge formation (cross-domain connection)
|
||||
BridgeFormation,
|
||||
/// Anomalous node (outlier in vector space)
|
||||
AnomalousNode,
|
||||
/// Temporal shift (pattern change over time)
|
||||
TemporalShift,
|
||||
/// Cascade (change propagating through network)
|
||||
Cascade,
|
||||
}
|
||||
|
||||
/// Evidence supporting a pattern detection
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Evidence {
|
||||
pub evidence_type: String,
|
||||
pub value: f64,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Cross-domain link discovered
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrossDomainLink {
|
||||
pub source_domain: Domain,
|
||||
pub target_domain: Domain,
|
||||
pub source_nodes: Vec<u32>,
|
||||
pub target_nodes: Vec<u32>,
|
||||
pub link_strength: f64,
|
||||
pub link_type: String,
|
||||
}
|
||||
|
||||
impl NativeDiscoveryEngine {
|
||||
/// Create a new engine with the given configuration
|
||||
pub fn new(config: NativeEngineConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
vectors: Vec::new(),
|
||||
nodes: HashMap::new(),
|
||||
edges: Vec::new(),
|
||||
coherence_history: Vec::new(),
|
||||
next_node_id: 0,
|
||||
domain_nodes: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a vector to the engine
|
||||
/// In production, this would use ruvector-core's vector storage
|
||||
pub fn add_vector(&mut self, vector: SemanticVector) -> u32 {
|
||||
let node_id = self.next_node_id;
|
||||
self.next_node_id += 1;
|
||||
|
||||
let vector_idx = self.vectors.len();
|
||||
self.vectors.push(vector.clone());
|
||||
|
||||
let node = GraphNode {
|
||||
id: node_id,
|
||||
external_id: vector.id.clone(),
|
||||
domain: vector.domain,
|
||||
vector_idx: Some(vector_idx),
|
||||
weight: 1.0,
|
||||
attributes: HashMap::new(),
|
||||
};
|
||||
|
||||
self.nodes.insert(node_id, node);
|
||||
self.domain_nodes.entry(vector.domain).or_default().push(node_id);
|
||||
|
||||
// Auto-connect to similar vectors
|
||||
self.connect_similar_vectors(node_id);
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Connect a node to similar vectors using cosine similarity
|
||||
/// In production, this would use ruvector-hnsw for O(log n) search
|
||||
fn connect_similar_vectors(&mut self, node_id: u32) {
|
||||
let node = match self.nodes.get(&node_id) {
|
||||
Some(n) => n.clone(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let vector_idx = match node.vector_idx {
|
||||
Some(idx) => idx,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let source_vec = &self.vectors[vector_idx].embedding;
|
||||
|
||||
// Find similar vectors (brute force - would use HNSW in production)
|
||||
for (other_id, other_node) in &self.nodes {
|
||||
if *other_id == node_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(other_idx) = other_node.vector_idx {
|
||||
let other_vec = &self.vectors[other_idx].embedding;
|
||||
let similarity = cosine_similarity(source_vec, other_vec);
|
||||
|
||||
if similarity >= self.config.similarity_threshold as f32 {
|
||||
// Determine edge type
|
||||
let edge_type = if node.domain != other_node.domain {
|
||||
EdgeType::CrossDomain
|
||||
} else {
|
||||
EdgeType::Similarity
|
||||
};
|
||||
|
||||
self.edges.push(GraphEdge {
|
||||
source: node_id,
|
||||
target: *other_id,
|
||||
weight: similarity as f64,
|
||||
edge_type,
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a correlation-based edge
|
||||
pub fn add_correlation_edge(&mut self, source: u32, target: u32, correlation: f64) {
|
||||
if correlation.abs() >= self.config.min_edge_weight {
|
||||
self.edges.push(GraphEdge {
|
||||
source,
|
||||
target,
|
||||
weight: correlation.abs(),
|
||||
edge_type: EdgeType::Correlation,
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute current coherence using Stoer-Wagner min-cut
|
||||
///
|
||||
/// The min-cut value represents the "weakest link" in the network.
|
||||
/// A drop in min-cut indicates the network is becoming fragmented.
|
||||
pub fn compute_coherence(&self) -> CoherenceSnapshot {
|
||||
if self.nodes.is_empty() || self.edges.is_empty() {
|
||||
return CoherenceSnapshot {
|
||||
mincut_value: 0.0,
|
||||
node_count: self.nodes.len(),
|
||||
edge_count: self.edges.len(),
|
||||
partition_sizes: (0, 0),
|
||||
boundary_nodes: vec![],
|
||||
avg_edge_weight: 0.0,
|
||||
};
|
||||
}
|
||||
|
||||
// Build adjacency matrix for min-cut
|
||||
// In production, this would call ruvector-mincut directly
|
||||
let mincut_result = self.stoer_wagner_mincut();
|
||||
|
||||
let avg_edge_weight = if self.edges.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
self.edges.iter().map(|e| e.weight).sum::<f64>() / self.edges.len() as f64
|
||||
};
|
||||
|
||||
CoherenceSnapshot {
|
||||
mincut_value: mincut_result.0,
|
||||
node_count: self.nodes.len(),
|
||||
edge_count: self.edges.len(),
|
||||
partition_sizes: mincut_result.1,
|
||||
boundary_nodes: mincut_result.2,
|
||||
avg_edge_weight,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Stoer-Wagner minimum cut algorithm.
    ///
    /// Returns `(min_cut_value, partition_sizes, boundary_nodes)`.
    /// Runs n-1 "minimum cut phases"; each phase performs a maximum-adjacency
    /// search, records the cut-of-the-phase, and contracts the last two
    /// visited vertices. The best cut over all phases is the global min-cut.
    /// O(n^3) with this dense-matrix implementation.
    fn stoer_wagner_mincut(&self) -> (f64, (usize, usize), Vec<u32>) {
        let n = self.nodes.len();
        // Fewer than two nodes: no cut exists.
        if n < 2 {
            return (0.0, (n, 0), vec![]);
        }

        // Build a dense adjacency matrix. Parallel edges accumulate weight;
        // the graph is treated as undirected (weight added symmetrically).
        let node_ids: Vec<u32> = self.nodes.keys().copied().collect();
        let id_to_idx: HashMap<u32, usize> = node_ids.iter()
            .enumerate()
            .map(|(i, &id)| (id, i))
            .collect();

        let mut adj = vec![vec![0.0; n]; n];
        for edge in &self.edges {
            if let (Some(&i), Some(&j)) = (id_to_idx.get(&edge.source), id_to_idx.get(&edge.target)) {
                adj[i][j] += edge.weight;
                adj[j][i] += edge.weight;
            }
        }

        // Best (smallest) cut-of-the-phase found so far.
        let mut best_cut = f64::INFINITY;
        let mut best_partition = (0, 0);
        let mut best_boundary = vec![];

        // `active[i]`: vertex i has not yet been contracted away.
        // `merged[i]`: the original vertex indices contracted into i.
        let mut active: Vec<bool> = vec![true; n];
        let mut merged: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();

        for phase in 0..(n - 1) {
            // Maximum adjacency search: repeatedly add the vertex most
            // tightly connected to the growing set A.
            // `key[j]` = total edge weight from j into A.
            let mut in_a = vec![false; n];
            let mut key = vec![0.0; n];

            // Seed A with the first still-active vertex.
            let start = (0..n).find(|&i| active[i]).unwrap();
            in_a[start] = true;

            // Initialize keys from the seed vertex.
            for j in 0..n {
                if active[j] && !in_a[j] {
                    key[j] = adj[start][j];
                }
            }

            // `s` and `t` end up as the last two vertices added to A.
            let mut s = start;
            let mut t = start;

            for _ in 1..=(n - 1 - phase) {
                // Pick the vertex outside A with the largest key.
                let mut max_key = f64::NEG_INFINITY;
                let mut max_node = 0;

                for j in 0..n {
                    if active[j] && !in_a[j] && key[j] > max_key {
                        max_key = key[j];
                        max_node = j;
                    }
                }

                s = t;
                t = max_node;
                in_a[t] = true;

                // Fold t's edges into the remaining keys.
                for j in 0..n {
                    if active[j] && !in_a[j] {
                        key[j] += adj[t][j];
                    }
                }
            }

            // Cut-of-the-phase: weight separating t from the rest of A.
            let cut_weight = key[t];

            if cut_weight < best_cut {
                best_cut = cut_weight;

                // The phase's cut separates everything merged into t from
                // every other active super-vertex.
                let partition_a: Vec<usize> = merged[t].clone();
                let partition_b: Vec<usize> = (0..n)
                    .filter(|&i| active[i] && i != t)
                    .flat_map(|i| merged[i].iter().copied())
                    .collect();

                best_partition = (partition_a.len(), partition_b.len());

                // NOTE(review): this records ALL nodes of the t-side
                // partition as "boundary nodes", not just those with edges
                // crossing the cut — confirm that is the intended meaning.
                best_boundary = partition_a.iter()
                    .map(|&i| node_ids[i])
                    .collect();
            }

            // Contract s and t: t becomes inactive, its merge-set and edge
            // weights are folded into s.
            active[t] = false;
            let to_merge: Vec<usize> = merged[t].clone();
            merged[s].extend(to_merge);

            for i in 0..n {
                if active[i] && i != s {
                    adj[s][i] += adj[t][i];
                    adj[i][s] += adj[i][t];
                }
            }
        }

        (best_cut, best_partition, best_boundary)
    }
|
||||
|
||||
/// Detect patterns by comparing the current coherence state to history.
///
/// Computes a fresh coherence snapshot, diffs it against the most recent
/// history entry (if any), optionally runs cross-domain detection, then
/// appends the new snapshot to `coherence_history`.
///
/// Emitted pattern kinds (all thresholds come from `self.config`):
/// - `CoherenceBreak` when min-cut drops by more than `mincut_sensitivity`
/// - `Consolidation` when min-cut rises by more than `mincut_sensitivity`
/// - `EmergingCluster` when partition imbalance grows by more than 0.2
/// - plus whatever `detect_cross_domain_patterns()` reports when enabled
pub fn detect_patterns(&mut self) -> Vec<DiscoveredPattern> {
    let mut patterns = Vec::new();

    let current = self.compute_coherence();
    let now = Utc::now();

    // Compare to previous state (no-op on the very first call, when history is empty).
    // `prev_time` is currently unused; only the min-cut value and snapshot are diffed.
    if let Some((prev_time, prev_mincut, prev_snapshot)) = self.coherence_history.last() {
        let mincut_delta = current.mincut_value - prev_mincut;
        // Relative change, guarded against division by a zero previous min-cut.
        let relative_change = if *prev_mincut > 0.0 {
            mincut_delta.abs() / prev_mincut
        } else {
            mincut_delta.abs()
        };

        // Detect coherence break: min-cut fell by more than the configured sensitivity.
        if mincut_delta < -self.config.mincut_sensitivity {
            patterns.push(DiscoveredPattern {
                id: format!("coherence_break_{}", now.timestamp()),
                pattern_type: PatternType::CoherenceBreak,
                // Maps relative change [0, 1] onto confidence [0.5, 1.0].
                confidence: (relative_change.min(1.0) * 0.5 + 0.5),
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Network coherence dropped from {:.3} to {:.3} ({:.1}% decrease)",
                    prev_mincut, current.mincut_value, relative_change * 100.0
                ),
                evidence: vec![
                    Evidence {
                        evidence_type: "mincut_delta".to_string(),
                        value: mincut_delta,
                        description: "Change in min-cut value".to_string(),
                    },
                    Evidence {
                        evidence_type: "boundary_size".to_string(),
                        value: current.boundary_nodes.len() as f64,
                        description: "Number of nodes on the cut".to_string(),
                    },
                ],
                cross_domain_links: self.find_cross_domain_at_boundary(&current.boundary_nodes),
            });
        }

        // Detect consolidation: min-cut rose by more than the configured sensitivity.
        if mincut_delta > self.config.mincut_sensitivity {
            patterns.push(DiscoveredPattern {
                id: format!("consolidation_{}", now.timestamp()),
                pattern_type: PatternType::Consolidation,
                confidence: (relative_change.min(1.0) * 0.5 + 0.5),
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Network coherence increased from {:.3} to {:.3} ({:.1}% increase)",
                    prev_mincut, current.mincut_value, relative_change * 100.0
                ),
                evidence: vec![
                    Evidence {
                        evidence_type: "mincut_delta".to_string(),
                        value: mincut_delta,
                        description: "Change in min-cut value".to_string(),
                    },
                ],
                cross_domain_links: vec![],
            });
        }

        // Detect partition imbalance (emerging cluster): normalized size gap
        // between the two sides of the cut, compared to the previous snapshot.
        let (part_a, part_b) = current.partition_sizes;
        let imbalance = (part_a as f64 - part_b as f64).abs() / (part_a + part_b) as f64;
        let (prev_a, prev_b) = prev_snapshot.partition_sizes;
        let prev_imbalance = if prev_a + prev_b > 0 {
            (prev_a as f64 - prev_b as f64).abs() / (prev_a + prev_b) as f64
        } else {
            0.0
        };

        // 0.2 is a fixed growth threshold for reporting an emerging cluster.
        if imbalance > prev_imbalance + 0.2 {
            patterns.push(DiscoveredPattern {
                id: format!("emerging_cluster_{}", now.timestamp()),
                pattern_type: PatternType::EmergingCluster,
                confidence: 0.7,
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Partition imbalance increased: {} vs {} nodes (was {} vs {})",
                    part_a, part_b, prev_a, prev_b
                ),
                evidence: vec![],
                cross_domain_links: vec![],
            });
        }
    }

    // Cross-domain pattern detection (bridge formation between domains).
    if self.config.cross_domain {
        patterns.extend(self.detect_cross_domain_patterns());
    }

    // Store current state in history so the next call can diff against it.
    self.coherence_history.push((now, current.mincut_value, current));

    patterns
}
|
||||
|
||||
/// Find cross-domain links at boundary nodes
|
||||
fn find_cross_domain_at_boundary(&self, boundary: &[u32]) -> Vec<CrossDomainLink> {
|
||||
let mut links = Vec::new();
|
||||
|
||||
// Find cross-domain edges involving boundary nodes
|
||||
for edge in &self.edges {
|
||||
if edge.edge_type == EdgeType::CrossDomain {
|
||||
if boundary.contains(&edge.source) || boundary.contains(&edge.target) {
|
||||
if let (Some(src_node), Some(tgt_node)) =
|
||||
(self.nodes.get(&edge.source), self.nodes.get(&edge.target))
|
||||
{
|
||||
links.push(CrossDomainLink {
|
||||
source_domain: src_node.domain,
|
||||
target_domain: tgt_node.domain,
|
||||
source_nodes: vec![edge.source],
|
||||
target_nodes: vec![edge.target],
|
||||
link_strength: edge.weight,
|
||||
link_type: "boundary_crossing".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
links
|
||||
}
|
||||
|
||||
/// Detect patterns that span multiple domains.
///
/// Groups `CrossDomain` edges by (unordered) domain pair, then reports a
/// `BridgeFormation` pattern for any pair with at least 3 connections whose
/// mean edge weight exceeds `config.similarity_threshold`.
fn detect_cross_domain_patterns(&self) -> Vec<DiscoveredPattern> {
    let mut patterns = Vec::new();

    // Bucket cross-domain edges by domain pair. The pair key is ordered
    // (smaller domain first) so (A, B) and (B, A) land in the same bucket.
    let mut cross_counts: HashMap<(Domain, Domain), Vec<&GraphEdge>> = HashMap::new();

    for edge in &self.edges {
        if edge.edge_type == EdgeType::CrossDomain {
            // Edges with unknown endpoints are skipped.
            if let (Some(src), Some(tgt)) =
                (self.nodes.get(&edge.source), self.nodes.get(&edge.target))
            {
                let key = if src.domain < tgt.domain {
                    (src.domain, tgt.domain)
                } else {
                    (tgt.domain, src.domain)
                };
                cross_counts.entry(key).or_default().push(edge);
            }
        }
    }

    // Report significant cross-domain bridges: >= 3 edges and mean weight
    // above the similarity threshold.
    for ((domain_a, domain_b), edges) in cross_counts {
        if edges.len() >= 3 {
            let avg_strength = edges.iter().map(|e| e.weight).sum::<f64>() / edges.len() as f64;

            if avg_strength > self.config.similarity_threshold as f64 {
                patterns.push(DiscoveredPattern {
                    id: format!("bridge_{:?}_{:?}_{}", domain_a, domain_b, Utc::now().timestamp()),
                    pattern_type: PatternType::BridgeFormation,
                    confidence: avg_strength,
                    // Both endpoints of every bridge edge count as affected.
                    affected_nodes: edges.iter()
                        .flat_map(|e| vec![e.source, e.target])
                        .collect(),
                    detected_at: Utc::now(),
                    description: format!(
                        "Cross-domain bridge detected: {:?} ↔ {:?} ({} connections, avg strength {:.3})",
                        domain_a, domain_b, edges.len(), avg_strength
                    ),
                    evidence: vec![
                        Evidence {
                            evidence_type: "edge_count".to_string(),
                            value: edges.len() as f64,
                            description: "Number of cross-domain connections".to_string(),
                        },
                    ],
                    cross_domain_links: vec![CrossDomainLink {
                        source_domain: domain_a,
                        target_domain: domain_b,
                        source_nodes: edges.iter().map(|e| e.source).collect(),
                        target_nodes: edges.iter().map(|e| e.target).collect(),
                        link_strength: avg_strength,
                        link_type: "semantic_bridge".to_string(),
                    }],
                });
            }
        }
    }

    patterns
}
|
||||
|
||||
/// Get domain-specific coherence
|
||||
pub fn domain_coherence(&self, domain: Domain) -> Option<f64> {
|
||||
let domain_node_ids = self.domain_nodes.get(&domain)?;
|
||||
|
||||
if domain_node_ids.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Count edges within domain
|
||||
let mut internal_weight = 0.0;
|
||||
let mut edge_count = 0;
|
||||
|
||||
for edge in &self.edges {
|
||||
if domain_node_ids.contains(&edge.source) && domain_node_ids.contains(&edge.target) {
|
||||
internal_weight += edge.weight;
|
||||
edge_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if edge_count == 0 {
|
||||
return Some(0.0);
|
||||
}
|
||||
|
||||
Some(internal_weight / edge_count as f64)
|
||||
}
|
||||
|
||||
/// Get statistics about the current state
|
||||
pub fn stats(&self) -> EngineStats {
|
||||
let mut domain_counts = HashMap::new();
|
||||
for domain in self.domain_nodes.keys() {
|
||||
domain_counts.insert(*domain, self.domain_nodes[domain].len());
|
||||
}
|
||||
|
||||
let mut cross_domain_edges = 0;
|
||||
for edge in &self.edges {
|
||||
if edge.edge_type == EdgeType::CrossDomain {
|
||||
cross_domain_edges += 1;
|
||||
}
|
||||
}
|
||||
|
||||
EngineStats {
|
||||
total_nodes: self.nodes.len(),
|
||||
total_edges: self.edges.len(),
|
||||
total_vectors: self.vectors.len(),
|
||||
domain_counts,
|
||||
cross_domain_edges,
|
||||
history_length: self.coherence_history.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all detected patterns from the latest detection run
|
||||
pub fn get_patterns(&self) -> Vec<DiscoveredPattern> {
|
||||
// For now, return an empty vec. In production, this would store
|
||||
// patterns from the last detect_patterns() call
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Export the current graph structure
|
||||
pub fn export_graph(&self) -> GraphExport {
|
||||
GraphExport {
|
||||
nodes: self.nodes.values().cloned().collect(),
|
||||
edges: self.edges.clone(),
|
||||
domains: self.domain_nodes.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the coherence history
|
||||
pub fn get_coherence_history(&self) -> Vec<CoherenceHistoryEntry> {
|
||||
self.coherence_history.iter()
|
||||
.map(|(timestamp, mincut, snapshot)| {
|
||||
CoherenceHistoryEntry {
|
||||
timestamp: *timestamp,
|
||||
mincut_value: *mincut,
|
||||
snapshot: snapshot.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Engine statistics, as returned by `NativeDiscoveryEngine::stats()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineStats {
    /// Total number of graph nodes.
    pub total_nodes: usize,
    /// Total number of graph edges (all edge types).
    pub total_edges: usize,
    /// Total number of stored semantic vectors.
    pub total_vectors: usize,
    /// Node count per domain.
    pub domain_counts: HashMap<Domain, usize>,
    /// Number of edges whose type is `EdgeType::CrossDomain`.
    pub cross_domain_edges: usize,
    /// Number of coherence snapshots recorded so far.
    pub history_length: usize,
}
|
||||
|
||||
/// Exported graph structure produced by `export_graph()`; an owned snapshot
/// of the engine's nodes, edges, and domain index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExport {
    /// All graph nodes (cloned).
    pub nodes: Vec<GraphNode>,
    /// All graph edges (cloned).
    pub edges: Vec<GraphEdge>,
    /// Mapping from domain to the node ids belonging to it.
    pub domains: HashMap<Domain, Vec<u32>>,
}
|
||||
|
||||
/// Coherence history entry: one recorded (time, min-cut, snapshot) triple,
/// as returned by `get_coherence_history()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceHistoryEntry {
    /// When this snapshot was taken.
    pub timestamp: DateTime<Utc>,
    /// Min-cut value of the network at that time.
    pub mincut_value: f64,
    /// Full coherence snapshot recorded at that time.
    pub snapshot: CoherenceSnapshot,
}
|
||||
|
||||
// Note: cosine_similarity is imported from crate::utils
|
||||
|
||||
// Implement ordering for Domain so it can be used as a (sortable) key,
// e.g. for the ordered domain-pair keys in detect_cross_domain_patterns.
impl PartialOrd for Domain {
    // Canonical delegation to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
|
||||
|
||||
impl Ord for Domain {
    // Order domains by their discriminant value (declaration order).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (*self as u8).cmp(&(*other as u8))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity-check cosine_similarity (imported via crate::utils through super::*):
    // identical vectors score ~1.0, orthogonal vectors score ~0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 0.001);
    }

    // Smoke test: adding vectors creates matching node and vector counts.
    #[test]
    fn test_engine_basic() {
        let config = NativeEngineConfig::default();
        let mut engine = NativeDiscoveryEngine::new(config);

        // Add some vectors in the same domain.
        let v1 = SemanticVector {
            id: "climate_1".to_string(),
            embedding: vec![1.0, 0.5, 0.2],
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        };

        let v2 = SemanticVector {
            id: "climate_2".to_string(),
            embedding: vec![0.9, 0.6, 0.3],
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        };

        engine.add_vector(v1);
        engine.add_vector(v2);

        // Each added vector should produce exactly one graph node.
        let stats = engine.stats();
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.total_vectors, 2);
    }
}
|
||||
841
vendor/ruvector/examples/data/framework/src/semantic_scholar.rs
vendored
Normal file
841
vendor/ruvector/examples/data/framework/src/semantic_scholar.rs
vendored
Normal file
@@ -0,0 +1,841 @@
|
||||
//! Semantic Scholar API Integration
|
||||
//!
|
||||
//! This module provides an async client for fetching academic papers from Semantic Scholar,
|
||||
//! converting responses to SemanticVector format for RuVector discovery.
|
||||
//!
|
||||
//! # Semantic Scholar API Details
|
||||
//! - Base URL: https://api.semanticscholar.org/graph/v1
|
||||
//! - Free tier: 100 requests per 5 minutes without API key
|
||||
//! - With API key: Higher limits (contact Semantic Scholar)
|
||||
//! - Returns JSON responses
|
||||
//!
|
||||
//! # Example
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_data_framework::semantic_scholar::SemanticScholarClient;
|
||||
//!
|
||||
//! let client = SemanticScholarClient::new(None); // No API key
|
||||
//!
|
||||
//! // Search papers by keywords
|
||||
//! let vectors = client.search_papers("machine learning", 10).await?;
|
||||
//!
|
||||
//! // Get paper details
|
||||
//! let paper = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await?;
|
||||
//!
|
||||
//! // Get citations
|
||||
//! let citations = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 20).await?;
|
||||
//!
|
||||
//! // Search by field of study
|
||||
//! let cs_papers = client.search_by_field("Computer Science", 50).await?;
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{DateTime, NaiveDate, Utc};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
use crate::ruvector_native::{Domain, SemanticVector};
|
||||
use crate::{FrameworkError, Result};
|
||||
|
||||
/// Rate limiting configuration for Semantic Scholar API
// Unauthenticated tier: 100 requests / 5 minutes ≈ 20 req/min, so 3 s per request.
const S2_RATE_LIMIT_MS: u64 = 3000; // 3 seconds between requests (100 req / 5 min = ~20 req/min = 3s/req)
// Authenticated tier allows a much shorter spacing between requests.
const S2_WITH_KEY_RATE_LIMIT_MS: u64 = 200; // More aggressive with API key
// Retry policy used by fetch_with_retry: up to 3 attempts with exponential backoff.
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 2000;
// Default dimension for SimpleEmbedder text embeddings.
const DEFAULT_EMBEDDING_DIM: usize = 384;
|
||||
|
||||
// ============================================================================
|
||||
// Semantic Scholar API Response Structures
|
||||
// ============================================================================
|
||||
|
||||
/// Search response from Semantic Scholar's `/paper/search` endpoint.
// `total`/`offset`/`next` are deserialized for completeness; only `data`
// is consumed by the client methods visible in this module.
#[derive(Debug, Deserialize)]
struct SearchResponse {
    /// Total matches reported by the API (if present).
    #[serde(default)]
    total: Option<i32>,
    /// Offset of this page in the full result set.
    #[serde(default)]
    offset: Option<i32>,
    /// Offset of the next page, when more results exist.
    #[serde(default)]
    next: Option<i32>,
    /// The papers in this page of results.
    #[serde(default)]
    data: Vec<PaperData>,
}
|
||||
|
||||
/// Paper data structure mirroring the Semantic Scholar paper schema.
/// All fields except `paperId` are optional/defaulted, since the API only
/// returns the fields explicitly requested via the `fields` query parameter.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PaperData {
    /// Semantic Scholar's stable paper identifier (always present).
    #[serde(rename = "paperId")]
    paper_id: String,

    #[serde(default)]
    title: Option<String>,

    // `abstract` is a Rust keyword, hence the rename.
    #[serde(rename = "abstract", default)]
    abstract_text: Option<String>,

    /// Publication year.
    #[serde(default)]
    year: Option<i32>,

    /// Number of papers citing this one.
    #[serde(rename = "citationCount", default)]
    citation_count: Option<i32>,

    /// Number of papers this one references.
    #[serde(rename = "referenceCount", default)]
    reference_count: Option<i32>,

    /// Semantic Scholar's "influential citation" metric.
    #[serde(rename = "influentialCitationCount", default)]
    influential_citation_count: Option<i32>,

    #[serde(default)]
    authors: Vec<AuthorData>,

    #[serde(rename = "fieldsOfStudy", default)]
    fields_of_study: Vec<String>,

    /// Venue name as a plain string (legacy field).
    #[serde(default)]
    venue: Option<String>,

    /// Structured venue record; used as a fallback when `venue` is absent.
    #[serde(rename = "publicationVenue", default)]
    publication_venue: Option<PublicationVenue>,

    #[serde(default)]
    url: Option<String>,

    #[serde(rename = "openAccessPdf", default)]
    open_access_pdf: Option<OpenAccessPdf>,
}
|
||||
|
||||
/// Author information attached to a paper record.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct AuthorData {
    /// Semantic Scholar author id (may be absent for unresolved authors).
    #[serde(rename = "authorId", default)]
    author_id: Option<String>,

    /// Display name.
    #[serde(default)]
    name: Option<String>,
}
|
||||
|
||||
/// Publication venue details (structured form of a paper's venue).
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PublicationVenue {
    /// Venue name; used as a fallback for `PaperData::venue`.
    #[serde(default)]
    name: Option<String>,

    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type", default)]
    venue_type: Option<String>,
}
|
||||
|
||||
/// Open access PDF information for a paper.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct OpenAccessPdf {
    /// Direct URL of the open-access PDF, when available.
    #[serde(default)]
    url: Option<String>,

    /// Open-access status label reported by the API.
    #[serde(default)]
    status: Option<String>,
}
|
||||
|
||||
/// Citation/reference response from the `/paper/{id}/citations` and
/// `/paper/{id}/references` endpoints (both share this paged shape).
#[derive(Debug, Deserialize)]
struct CitationResponse {
    /// Offset of this page in the full result set.
    #[serde(default)]
    offset: Option<i32>,

    /// Offset of the next page, when more results exist.
    #[serde(default)]
    next: Option<i32>,

    /// The citation/reference entries in this page.
    #[serde(default)]
    data: Vec<CitationData>,
}
|
||||
|
||||
/// Citation data wrapper. The citations endpoint populates `citingPaper`,
/// while the references endpoint populates `citedPaper`; the unused field
/// defaults to `None`.
#[derive(Debug, Deserialize)]
struct CitationData {
    /// The paper that cites the queried paper (citations endpoint).
    #[serde(rename = "citingPaper", default)]
    citing_paper: Option<PaperData>,

    /// The paper cited by the queried paper (references endpoint).
    #[serde(rename = "citedPaper", default)]
    cited_paper: Option<PaperData>,
}
|
||||
|
||||
/// Author details response from the `/author/{id}` endpoint.
// NOTE(review): `get_author` only consumes `papers`; the scalar fields are
// deserialized but currently unused — kept for future author-level metrics.
#[derive(Debug, Deserialize)]
struct AuthorResponse {
    #[serde(rename = "authorId")]
    author_id: String,

    #[serde(default)]
    name: Option<String>,

    #[serde(rename = "paperCount", default)]
    paper_count: Option<i32>,

    #[serde(rename = "citationCount", default)]
    citation_count: Option<i32>,

    #[serde(rename = "hIndex", default)]
    h_index: Option<i32>,

    /// The author's papers (with the sub-fields requested in the URL).
    #[serde(default)]
    papers: Vec<PaperData>,
}
|
||||
|
||||
// ============================================================================
|
||||
// Semantic Scholar Client
|
||||
// ============================================================================
|
||||
|
||||
/// Client for Semantic Scholar API
///
/// Provides methods to search for academic papers, retrieve citations and references,
/// filter by fields of study, and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces rate limits:
/// - Without API key: 100 requests per 5 minutes (3 seconds between requests)
/// - With API key: Higher limits (200ms between requests)
///
/// # API Key
/// Set the `SEMANTIC_SCHOLAR_API_KEY` environment variable to use authenticated requests.
pub struct SemanticScholarClient {
    /// Shared HTTP client (30 s timeout, custom user agent).
    client: Client,
    /// Deterministic text embedder used to vectorize title + abstract.
    embedder: Arc<SimpleEmbedder>,
    /// API base URL (https://api.semanticscholar.org/graph/v1).
    base_url: String,
    /// Optional API key sent as the `x-api-key` header.
    api_key: Option<String>,
    /// Delay slept before each request (depends on whether a key is set).
    rate_limit_delay: Duration,
}
|
||||
|
||||
impl SemanticScholarClient {
|
||||
/// Create a new Semantic Scholar API client
///
/// Delegates to [`Self::with_embedding_dim`] using the default embedding
/// dimension (`DEFAULT_EMBEDDING_DIM` = 384).
///
/// # Arguments
/// * `api_key` - Optional API key. If None, checks SEMANTIC_SCHOLAR_API_KEY env var
///
/// # Example
/// ```rust,ignore
/// // Without API key
/// let client = SemanticScholarClient::new(None);
///
/// // With API key
/// let client = SemanticScholarClient::new(Some("your-api-key".to_string()));
/// ```
pub fn new(api_key: Option<String>) -> Self {
    Self::with_embedding_dim(api_key, DEFAULT_EMBEDDING_DIM)
}
|
||||
|
||||
/// Create a new client with custom embedding dimension
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `api_key` - Optional API key
|
||||
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
|
||||
pub fn with_embedding_dim(api_key: Option<String>, embedding_dim: usize) -> Self {
|
||||
// Try API key from parameter, then environment variable
|
||||
let api_key = api_key.or_else(|| env::var("SEMANTIC_SCHOLAR_API_KEY").ok());
|
||||
|
||||
let rate_limit_delay = if api_key.is_some() {
|
||||
Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS)
|
||||
} else {
|
||||
Duration::from_millis(S2_RATE_LIMIT_MS)
|
||||
};
|
||||
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent("RuVector-Discovery/1.0")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client"),
|
||||
embedder: Arc::new(SimpleEmbedder::new(embedding_dim)),
|
||||
base_url: "https://api.semanticscholar.org/graph/v1".to_string(),
|
||||
api_key,
|
||||
rate_limit_delay,
|
||||
}
|
||||
}
|
||||
|
||||
/// Search papers by keywords
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query (keywords, title, etc.)
|
||||
/// * `limit` - Maximum number of results to return (max 100 per request)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let vectors = client.search_papers("deep learning transformers", 50).await?;
|
||||
/// ```
|
||||
pub async fn search_papers(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let limit = limit.min(100); // API limit
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
|
||||
let url = format!(
|
||||
"{}/paper/search?query={}&limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
|
||||
self.base_url, encoded_query, limit
|
||||
);
|
||||
|
||||
let response: SearchResponse = self.fetch_json(&url).await?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
for paper in response.data {
|
||||
if let Some(vector) = self.paper_to_vector(paper) {
|
||||
vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Get a single paper by Semantic Scholar paper ID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `paper_id` - Semantic Scholar paper ID (e.g., "649def34f8be52c8b66281af98ae884c09aef38b")
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let paper = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await?;
|
||||
/// ```
|
||||
pub async fn get_paper(&self, paper_id: &str) -> Result<Option<SemanticVector>> {
|
||||
let url = format!(
|
||||
"{}/paper/{}?fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
|
||||
self.base_url, paper_id
|
||||
);
|
||||
|
||||
let paper: PaperData = self.fetch_json(&url).await?;
|
||||
Ok(self.paper_to_vector(paper))
|
||||
}
|
||||
|
||||
/// Get papers that cite this paper
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `paper_id` - Semantic Scholar paper ID
|
||||
/// * `limit` - Maximum number of citations to return (max 1000)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let citations = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 50).await?;
|
||||
/// ```
|
||||
pub async fn get_citations(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let limit = limit.min(1000); // API limit
|
||||
|
||||
let url = format!(
|
||||
"{}/paper/{}/citations?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
|
||||
self.base_url, paper_id, limit
|
||||
);
|
||||
|
||||
let response: CitationResponse = self.fetch_json(&url).await?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
for citation in response.data {
|
||||
if let Some(citing_paper) = citation.citing_paper {
|
||||
if let Some(vector) = self.paper_to_vector(citing_paper) {
|
||||
vectors.push(vector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Get papers this paper references
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `paper_id` - Semantic Scholar paper ID
|
||||
/// * `limit` - Maximum number of references to return (max 1000)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let references = client.get_references("649def34f8be52c8b66281af98ae884c09aef38b", 50).await?;
|
||||
/// ```
|
||||
pub async fn get_references(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
|
||||
let limit = limit.min(1000); // API limit
|
||||
|
||||
let url = format!(
|
||||
"{}/paper/{}/references?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
|
||||
self.base_url, paper_id, limit
|
||||
);
|
||||
|
||||
let response: CitationResponse = self.fetch_json(&url).await?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
for reference in response.data {
|
||||
if let Some(cited_paper) = reference.cited_paper {
|
||||
if let Some(vector) = self.paper_to_vector(cited_paper) {
|
||||
vectors.push(vector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Search papers by field of study
///
/// # Arguments
/// * `field_of_study` - Field name (e.g., "Computer Science", "Medicine", "Biology", "Physics", "Economics")
/// * `limit` - Maximum number of results to return
///
/// NOTE(review): this embeds `fieldsOfStudy:<name>` in the free-text query
/// rather than using the API's dedicated `fieldsOfStudy` filter parameter —
/// verify the search endpoint actually honors this syntax. Results come back
/// in the endpoint's default order; no sorting is applied here.
///
/// # Example
/// ```rust,ignore
/// let cs_papers = client.search_by_field("Computer Science", 100).await?;
/// let medical_papers = client.search_by_field("Medicine", 50).await?;
/// ```
pub async fn search_by_field(&self, field_of_study: &str, limit: usize) -> Result<Vec<SemanticVector>> {
    // Fold the field name into the keyword query and reuse search_papers.
    let query = format!("fieldsOfStudy:{}", field_of_study);
    self.search_papers(&query, limit).await
}
|
||||
|
||||
/// Get author details and their papers
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `author_id` - Semantic Scholar author ID
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// let author_papers = client.get_author("1741101").await?;
|
||||
/// ```
|
||||
pub async fn get_author(&self, author_id: &str) -> Result<Vec<SemanticVector>> {
|
||||
let url = format!(
|
||||
"{}/author/{}?fields=authorId,name,paperCount,citationCount,hIndex,papers.paperId,papers.title,papers.abstract,papers.year,papers.citationCount,papers.fieldsOfStudy",
|
||||
self.base_url, author_id
|
||||
);
|
||||
|
||||
let author: AuthorResponse = self.fetch_json(&url).await?;
|
||||
|
||||
let mut vectors = Vec::new();
|
||||
for paper in author.papers {
|
||||
if let Some(vector) = self.paper_to_vector(paper) {
|
||||
vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vectors)
|
||||
}
|
||||
|
||||
/// Search recent papers published after a minimum year
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query
|
||||
/// * `year_min` - Minimum publication year (e.g., 2020)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,ignore
|
||||
/// // Get papers about "climate change" published since 2020
|
||||
/// let recent = client.search_recent("climate change", 2020).await?;
|
||||
/// ```
|
||||
pub async fn search_recent(&self, query: &str, year_min: i32) -> Result<Vec<SemanticVector>> {
|
||||
let all_results = self.search_papers(query, 100).await?;
|
||||
|
||||
// Filter by year
|
||||
Ok(all_results
|
||||
.into_iter()
|
||||
.filter(|v| {
|
||||
v.metadata
|
||||
.get("year")
|
||||
.and_then(|y| y.parse::<i32>().ok())
|
||||
.map(|year| year >= year_min)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Build citation graph for a paper
///
/// Returns a tuple of (paper, citations, references) as SemanticVectors
///
/// # Arguments
/// * `paper_id` - Semantic Scholar paper ID
/// * `max_citations` - Maximum citations to retrieve
/// * `max_references` - Maximum references to retrieve
///
/// # Example
/// ```rust,ignore
/// let (paper, citations, references) = client.build_citation_graph(
///     "649def34f8be52c8b66281af98ae884c09aef38b",
///     50,
///     50
/// ).await?;
/// ```
pub async fn build_citation_graph(
    &self,
    paper_id: &str,
    max_citations: usize,
    max_references: usize,
) -> Result<(Option<SemanticVector>, Vec<SemanticVector>, Vec<SemanticVector>)> {
    // Rust futures are lazy: constructing these does NOT start the requests.
    // Each request only runs when awaited below, so the three calls execute
    // sequentially — which is intentional here, to respect the rate limit.
    let paper_result = self.get_paper(paper_id);
    let citations_result = self.get_citations(paper_id, max_citations);
    let references_result = self.get_references(paper_id, max_references);

    // Await one at a time, sleeping between requests for rate-limit spacing
    // (each call also sleeps internally via fetch_json).
    let paper = paper_result.await?;
    sleep(self.rate_limit_delay).await;

    let citations = citations_result.await?;
    sleep(self.rate_limit_delay).await;

    let references = references_result.await?;

    Ok((paper, citations, references))
}
|
||||
|
||||
/// Convert PaperData to SemanticVector.
///
/// Returns `None` for papers without a title. The embedding is generated
/// from "title abstract" via the client's `SimpleEmbedder`; the paper's
/// publication year (Jan 1st, midnight UTC) is used as the timestamp, or
/// "now" when the year is missing/invalid. All scalar fields are flattened
/// into string metadata.
fn paper_to_vector(&self, paper: PaperData) -> Option<SemanticVector> {
    let title = paper.title.clone().unwrap_or_default();
    let abstract_text = paper.abstract_text.clone().unwrap_or_default();

    // Skip papers without title — there is nothing meaningful to embed.
    if title.is_empty() {
        return None;
    }

    // Generate embedding from title + abstract.
    let combined_text = format!("{} {}", title, abstract_text);
    let embedding = self.embedder.embed_text(&combined_text);

    // Convert year to a timestamp (Jan 1st, 00:00:00 UTC of that year).
    // `.unwrap()` on and_hms_opt(0, 0, 0) cannot fail: midnight is always valid.
    let timestamp = paper.year
        .and_then(|y| NaiveDate::from_ymd_opt(y, 1, 1))
        .map(|d| DateTime::from_naive_utc_and_offset(d.and_hms_opt(0, 0, 0).unwrap(), Utc))
        .unwrap_or_else(Utc::now);

    // Build metadata; optional fields are inserted only when present.
    let mut metadata = HashMap::new();
    metadata.insert("paper_id".to_string(), paper.paper_id.clone());
    metadata.insert("title".to_string(), title);

    if !abstract_text.is_empty() {
        metadata.insert("abstract".to_string(), abstract_text);
    }

    if let Some(year) = paper.year {
        metadata.insert("year".to_string(), year.to_string());
    }

    if let Some(count) = paper.citation_count {
        metadata.insert("citationCount".to_string(), count.to_string());
    }

    if let Some(count) = paper.reference_count {
        metadata.insert("referenceCount".to_string(), count.to_string());
    }

    if let Some(count) = paper.influential_citation_count {
        metadata.insert("influentialCitationCount".to_string(), count.to_string());
    }

    // Authors: comma-separated list of names (authors without a name skipped).
    let authors = paper
        .authors
        .iter()
        .filter_map(|a| a.name.as_ref())
        .cloned()
        .collect::<Vec<_>>()
        .join(", ");
    if !authors.is_empty() {
        metadata.insert("authors".to_string(), authors);
    }

    // Fields of study, comma-separated.
    if !paper.fields_of_study.is_empty() {
        metadata.insert("fieldsOfStudy".to_string(), paper.fields_of_study.join(", "));
    }

    // Venue: prefer the plain string, fall back to the structured record's name.
    if let Some(venue) = paper.venue.or_else(|| paper.publication_venue.and_then(|pv| pv.name)) {
        metadata.insert("venue".to_string(), venue);
    }

    // URL: fall back to the canonical Semantic Scholar paper page.
    if let Some(url) = paper.url {
        metadata.insert("url".to_string(), url);
    } else {
        metadata.insert(
            "url".to_string(),
            format!("https://www.semanticscholar.org/paper/{}", paper.paper_id),
        );
    }

    // Open access PDF link, when available.
    if let Some(pdf) = paper.open_access_pdf.and_then(|p| p.url) {
        metadata.insert("pdf_url".to_string(), pdf);
    }

    metadata.insert("source".to_string(), "semantic_scholar".to_string());

    Some(SemanticVector {
        // Prefix ids with "s2:" so they cannot collide with other sources.
        id: format!("s2:{}", paper.paper_id),
        embedding,
        domain: Domain::Research,
        timestamp,
        metadata,
    })
}
|
||||
|
||||
/// Fetch JSON from URL with rate limiting and retry logic
|
||||
async fn fetch_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
|
||||
// Rate limiting
|
||||
sleep(self.rate_limit_delay).await;
|
||||
|
||||
let response = self.fetch_with_retry(url).await?;
|
||||
let json = response.json::<T>().await?;
|
||||
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let mut request = self.client.get(url);
|
||||
|
||||
// Add API key header if available
|
||||
if let Some(ref api_key) = self.api_key {
|
||||
request = request.header("x-api-key", api_key);
|
||||
}
|
||||
|
||||
match request.send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
|
||||
retries += 1;
|
||||
let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); // Exponential backoff
|
||||
tracing::warn!(
|
||||
"Rate limited by Semantic Scholar, retrying in {}ms",
|
||||
delay
|
||||
);
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
continue;
|
||||
}
|
||||
if !response.status().is_success() {
|
||||
return Err(FrameworkError::Network(
|
||||
reqwest::Error::from(response.error_for_status().unwrap_err()),
|
||||
));
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); // Exponential backoff
|
||||
tracing::warn!("Request failed, retrying ({}/{}) in {}ms", retries, MAX_RETRIES, delay);
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SemanticScholarClient {
    /// Create a client without an API key (subject to the shared public
    /// rate limit).
    fn default() -> Self {
        Self::new(None)
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Without an API key the client should target the public endpoint and
    // use the unauthenticated (slower) rate limit.
    #[test]
    fn test_client_creation() {
        let client = SemanticScholarClient::new(None);
        assert_eq!(client.base_url, "https://api.semanticscholar.org/graph/v1");
        assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_RATE_LIMIT_MS));
    }

    // Providing an API key should store it and switch to the faster
    // authenticated rate limit.
    #[test]
    fn test_client_with_api_key() {
        let client = SemanticScholarClient::new(Some("test-key".to_string()));
        assert_eq!(client.api_key, Some("test-key".to_string()));
        assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS));
    }

    // The embedder should honor a caller-supplied embedding dimension.
    #[test]
    fn test_custom_embedding_dim() {
        let client = SemanticScholarClient::with_embedding_dim(None, 512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    // Full conversion: a fully-populated PaperData should yield a
    // SemanticVector with every optional metadata field present.
    #[test]
    fn test_paper_to_vector() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "649def34f8be52c8b66281af98ae884c09aef38b".to_string(),
            title: Some("Attention Is All You Need".to_string()),
            abstract_text: Some("The dominant sequence transduction models...".to_string()),
            year: Some(2017),
            citation_count: Some(50000),
            reference_count: Some(35),
            influential_citation_count: Some(5000),
            authors: vec![
                AuthorData {
                    author_id: Some("1741101".to_string()),
                    name: Some("Ashish Vaswani".to_string()),
                },
                AuthorData {
                    author_id: Some("1699545".to_string()),
                    name: Some("Noam Shazeer".to_string()),
                },
            ],
            fields_of_study: vec!["Computer Science".to_string(), "Mathematics".to_string()],
            venue: Some("NeurIPS".to_string()),
            publication_venue: None,
            url: Some("https://arxiv.org/abs/1706.03762".to_string()),
            open_access_pdf: Some(OpenAccessPdf {
                url: Some("https://arxiv.org/pdf/1706.03762.pdf".to_string()),
                status: Some("GREEN".to_string()),
            }),
        };

        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());

        let v = vector.unwrap();
        // The vector id is namespaced with "s2:" to distinguish sources.
        assert_eq!(v.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("paper_id").unwrap(), "649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.metadata.get("title").unwrap(), "Attention Is All You Need");
        assert_eq!(v.metadata.get("year").unwrap(), "2017");
        assert_eq!(v.metadata.get("citationCount").unwrap(), "50000");
        assert_eq!(v.metadata.get("referenceCount").unwrap(), "35");
        assert_eq!(v.metadata.get("authors").unwrap(), "Ashish Vaswani, Noam Shazeer");
        assert_eq!(v.metadata.get("fieldsOfStudy").unwrap(), "Computer Science, Mathematics");
        assert_eq!(v.metadata.get("venue").unwrap(), "NeurIPS");
        assert!(v.metadata.contains_key("pdf_url"));
    }

    // A paper with only the required fields should still convert; the URL
    // falls back to a semanticscholar.org link built from the paper id.
    #[test]
    fn test_paper_to_vector_minimal() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "test123".to_string(),
            title: Some("Minimal Paper".to_string()),
            abstract_text: None,
            year: None,
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };

        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "s2:test123");
        assert_eq!(v.metadata.get("title").unwrap(), "Minimal Paper");
        assert!(v.metadata.get("url").unwrap().contains("semanticscholar.org"));
    }

    // A missing title makes the paper unusable for embedding.
    #[test]
    fn test_paper_without_title() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "test456".to_string(),
            title: None,
            abstract_text: Some("Has abstract but no title".to_string()),
            year: Some(2020),
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };

        // Papers without titles should be skipped
        let vector = client.paper_to_vector(paper);
        assert!(vector.is_none());
    }

    // Integration tests below hit the live Semantic Scholar API and are
    // `#[ignore]`d by default; run with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API in tests
    async fn test_search_papers_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_papers("machine learning", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("s2:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("paper_id"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_get_paper_integration() {
        let client = SemanticScholarClient::new(None);

        // Well-known paper: "Attention Is All You Need"
        let result = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await;
        assert!(result.is_ok());

        let paper = result.unwrap();
        assert!(paper.is_some());

        let p = paper.unwrap();
        assert_eq!(p.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert!(p.metadata.get("title").unwrap().contains("Attention"));
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_get_citations_integration() {
        let client = SemanticScholarClient::new(None);

        // Get citations for "Attention Is All You Need"
        let result = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 10).await;
        assert!(result.is_ok());

        let citations = result.unwrap();
        assert!(citations.len() <= 10);
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_search_by_field_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_by_field("Computer Science", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_build_citation_graph_integration() {
        let client = SemanticScholarClient::new(None);

        let result = client.build_citation_graph(
            "649def34f8be52c8b66281af98ae884c09aef38b",
            5,
            5
        ).await;
        assert!(result.is_ok());

        let (paper, citations, references) = result.unwrap();
        assert!(paper.is_some());
    }
}
|
||||
1284
vendor/ruvector/examples/data/framework/src/space_clients.rs
vendored
Normal file
1284
vendor/ruvector/examples/data/framework/src/space_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
703
vendor/ruvector/examples/data/framework/src/streaming.rs
vendored
Normal file
703
vendor/ruvector/examples/data/framework/src/streaming.rs
vendored
Normal file
@@ -0,0 +1,703 @@
|
||||
//! Real-time Streaming Data Ingestion
|
||||
//!
|
||||
//! Provides async stream processing with windowed analysis, real-time pattern
|
||||
//! detection, backpressure handling, and comprehensive metrics collection.
|
||||
//!
|
||||
//! ## Features
|
||||
//! - Async stream processing for continuous data ingestion
|
||||
//! - Tumbling and sliding window analysis
|
||||
//! - Real-time pattern detection with callbacks
|
||||
//! - Automatic backpressure handling
|
||||
//! - Throughput and latency metrics
|
||||
//!
|
||||
//! ## Example
|
||||
//! ```rust,ignore
|
||||
//! use futures::stream;
|
||||
//! use std::time::Duration;
|
||||
//!
|
||||
//! let config = StreamingConfig {
|
||||
//! window_size: Duration::from_secs(60),
|
||||
//! slide_interval: Duration::from_secs(30),
|
||||
//! max_buffer_size: 10000,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! let mut engine = StreamingEngine::new(config);
|
||||
//!
|
||||
//! // Set pattern callback
|
||||
//! engine.set_pattern_callback(|pattern| {
|
||||
//! println!("Pattern detected: {:?}", pattern);
|
||||
//! });
|
||||
//!
|
||||
//! // Ingest stream
|
||||
//! let stream = stream::iter(vectors);
|
||||
//! engine.ingest_stream(stream).await?;
|
||||
//!
|
||||
//! // Get metrics
|
||||
//! let metrics = engine.metrics();
|
||||
//! println!("Processed: {} vectors, {} patterns",
|
||||
//! metrics.vectors_processed, metrics.patterns_detected);
|
||||
//! ```
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration as StdDuration, Instant};
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use futures::{Stream, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{RwLock, Semaphore};
|
||||
|
||||
use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine, SignificantPattern};
|
||||
use crate::ruvector_native::SemanticVector;
|
||||
use crate::Result;
|
||||
|
||||
/// Configuration for the streaming engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Discovery engine configuration
    pub discovery_config: OptimizedConfig,

    /// Window size for temporal analysis
    pub window_size: StdDuration,

    /// Slide interval for sliding windows (if None, use tumbling windows)
    pub slide_interval: Option<StdDuration>,

    /// Maximum buffer size before applying backpressure
    /// (used to size the engine's backpressure semaphore)
    pub max_buffer_size: usize,

    /// Timeout for processing a single vector (None = no timeout)
    /// NOTE(review): not enforced anywhere in the visible StreamingEngine
    /// code — confirm whether enforcement was intended.
    pub processing_timeout: Option<StdDuration>,

    /// Batch size for parallel processing
    pub batch_size: usize,

    /// Enable automatic pattern detection
    pub auto_detect_patterns: bool,

    /// Pattern detection interval (check every N vectors)
    pub detection_interval: usize,

    /// Maximum concurrent processing tasks
    pub max_concurrency: usize,
}
|
||||
|
||||
impl Default for StreamingConfig {
    /// Sensible defaults: 60s sliding windows that advance every 30s,
    /// batches of 100, pattern detection every 100 vectors, and up to 4
    /// concurrent batch-processing tasks.
    fn default() -> Self {
        Self {
            discovery_config: OptimizedConfig::default(),
            window_size: StdDuration::from_secs(60),
            slide_interval: Some(StdDuration::from_secs(30)),
            max_buffer_size: 10000,
            processing_timeout: Some(StdDuration::from_secs(5)),
            batch_size: 100,
            auto_detect_patterns: true,
            detection_interval: 100,
            max_concurrency: 4,
        }
    }
}
|
||||
|
||||
/// Streaming metrics for monitoring performance
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamingMetrics {
    /// Total vectors processed
    pub vectors_processed: u64,

    /// Total patterns detected
    pub patterns_detected: u64,

    /// Average latency in milliseconds (rolling average over the last
    /// 1000 per-vector measurements)
    pub avg_latency_ms: f64,

    /// Throughput (vectors per second); refreshed by `calculate_throughput`
    pub throughput_per_sec: f64,

    /// Current window count
    pub windows_processed: u64,

    /// Total bytes processed (if available)
    /// NOTE(review): never updated by the visible engine code.
    pub bytes_processed: u64,

    /// Backpressure events (times buffer was full)
    /// NOTE(review): never incremented by the visible engine code.
    pub backpressure_events: u64,

    /// Processing errors
    pub errors: u64,

    /// Peak vectors in buffer
    /// NOTE(review): never updated by the visible engine code.
    pub peak_buffer_size: usize,

    /// Start time
    pub start_time: Option<DateTime<Utc>>,

    /// Last update time
    pub last_update: Option<DateTime<Utc>>,
}
|
||||
|
||||
impl StreamingMetrics {
|
||||
/// Calculate uptime in seconds
|
||||
pub fn uptime_secs(&self) -> f64 {
|
||||
if let (Some(start), Some(last)) = (self.start_time, self.last_update) {
|
||||
(last - start).num_milliseconds() as f64 / 1000.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate average throughput
|
||||
pub fn calculate_throughput(&mut self) {
|
||||
let uptime = self.uptime_secs();
|
||||
if uptime > 0.0 {
|
||||
self.throughput_per_sec = self.vectors_processed as f64 / uptime;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Time window for analysis: a half-open interval `[start, end)` together
/// with the vectors whose timestamps fall inside it.
#[derive(Debug, Clone)]
struct TimeWindow {
    // Inclusive start of the window
    start: DateTime<Utc>,
    // Exclusive end of the window
    end: DateTime<Utc>,
    // Vectors collected for this window (each vector may appear in
    // multiple overlapping windows)
    vectors: Vec<SemanticVector>,
}
|
||||
|
||||
impl TimeWindow {
|
||||
fn new(start: DateTime<Utc>, duration: ChronoDuration) -> Self {
|
||||
Self {
|
||||
start,
|
||||
end: start + duration,
|
||||
vectors: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn contains(&self, timestamp: DateTime<Utc>) -> bool {
|
||||
timestamp >= self.start && timestamp < self.end
|
||||
}
|
||||
|
||||
fn add_vector(&mut self, vector: SemanticVector) {
|
||||
self.vectors.push(vector);
|
||||
}
|
||||
|
||||
fn is_complete(&self, now: DateTime<Utc>) -> bool {
|
||||
now >= self.end
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming engine for real-time data ingestion and pattern detection.
///
/// All shared state is wrapped in `Arc<RwLock<…>>` so the engine's methods
/// can be used from async tasks.
pub struct StreamingEngine {
    /// Configuration
    config: StreamingConfig,

    /// Underlying discovery engine (wrapped in Arc<RwLock> for async access)
    engine: Arc<RwLock<OptimizedDiscoveryEngine>>,

    /// Pattern callback, invoked synchronously for each significant pattern
    on_pattern: Arc<RwLock<Option<Box<dyn Fn(SignificantPattern) + Send + Sync>>>>,

    /// Metrics
    metrics: Arc<RwLock<StreamingMetrics>>,

    /// Current windows (for sliding window analysis)
    windows: Arc<RwLock<Vec<TimeWindow>>>,

    /// Backpressure semaphore, sized from `config.max_buffer_size`
    semaphore: Arc<Semaphore>,

    /// Latency tracking (rolling buffer of the last 1000 measurements, ms)
    latencies: Arc<RwLock<Vec<f64>>>,
}
|
||||
|
||||
impl StreamingEngine {
    /// Create a new streaming engine from `config`, recording the start
    /// time for later throughput calculations.
    pub fn new(config: StreamingConfig) -> Self {
        let discovery_config = config.discovery_config.clone();
        let max_buffer = config.max_buffer_size;

        let mut metrics = StreamingMetrics::default();
        metrics.start_time = Some(Utc::now());

        Self {
            config,
            engine: Arc::new(RwLock::new(OptimizedDiscoveryEngine::new(discovery_config))),
            on_pattern: Arc::new(RwLock::new(None)),
            metrics: Arc::new(RwLock::new(metrics)),
            windows: Arc::new(RwLock::new(Vec::new())),
            semaphore: Arc::new(Semaphore::new(max_buffer)),
            latencies: Arc::new(RwLock::new(Vec::with_capacity(1000))),
        }
    }

    /// Set the pattern detection callback (replaces any previous one).
    /// The callback is invoked synchronously from `detect_patterns` for
    /// each pattern whose `is_significant` flag is set.
    pub async fn set_pattern_callback<F>(&mut self, callback: F)
    where
        F: Fn(SignificantPattern) + Send + Sync + 'static,
    {
        let mut on_pattern = self.on_pattern.write().await;
        *on_pattern = Some(Box::new(callback));
    }

    /// Ingest a stream of vectors with windowed analysis.
    ///
    /// Drives the full pipeline per vector: sliding-window maintenance,
    /// batched insertion into the discovery engine, periodic pattern
    /// detection, window closing, and latency/metrics bookkeeping. After
    /// the stream ends, flushes the final partial batch, runs a last
    /// pattern detection, closes all windows, and finalizes throughput.
    ///
    /// # Errors
    /// Returns `FrameworkError::Config` for invalid window/slide durations
    /// and `FrameworkError::Ingestion` if the backpressure semaphore is
    /// closed.
    pub async fn ingest_stream<S>(&mut self, stream: S) -> Result<()>
    where
        S: Stream<Item = SemanticVector> + Send,
    {
        let mut stream = Box::pin(stream);
        let mut vector_count = 0_u64;
        let mut current_batch = Vec::with_capacity(self.config.batch_size);

        // Initialize first window
        let window_duration = ChronoDuration::from_std(self.config.window_size)
            .map_err(|e| crate::FrameworkError::Config(format!("Invalid window size: {}", e)))?;

        let mut last_window_start = Utc::now();
        self.create_window(last_window_start, window_duration).await;

        while let Some(vector) = stream.next().await {
            // Backpressure handling.
            // NOTE(review): the permit is dropped at the end of this loop
            // iteration, so at most one permit is ever held and the
            // semaphore never actually limits buffering; `backpressure_events`
            // is also never incremented. Confirm intended behavior.
            let _permit = self.semaphore.acquire().await.map_err(|e| {
                crate::FrameworkError::Ingestion(format!("Backpressure semaphore error: {}", e))
            })?;

            let start = Instant::now();

            // Check if we need to create a new window (sliding)
            if let Some(slide_interval) = self.config.slide_interval {
                let slide_duration = ChronoDuration::from_std(slide_interval)
                    .map_err(|e| crate::FrameworkError::Config(format!("Invalid slide interval: {}", e)))?;

                let now = Utc::now();
                if (now - last_window_start) >= slide_duration {
                    self.create_window(now, window_duration).await;
                    last_window_start = now;
                }
            }

            // Add vector to appropriate windows (cloned: windows keep a
            // copy, the batch takes ownership)
            self.add_to_windows(vector.clone()).await;
            current_batch.push(vector);
            vector_count += 1;

            // Process batch once it reaches the configured size
            if current_batch.len() >= self.config.batch_size {
                self.process_batch(&current_batch).await?;
                current_batch.clear();
            }

            // Pattern detection every `detection_interval` vectors
            if self.config.auto_detect_patterns && vector_count % self.config.detection_interval as u64 == 0 {
                self.detect_patterns().await?;
            }

            // Close completed windows
            self.close_completed_windows().await?;

            // Record latency (microsecond resolution, reported in ms)
            let latency_ms = start.elapsed().as_micros() as f64 / 1000.0;
            self.record_latency(latency_ms).await;

            // Update metrics
            let mut metrics = self.metrics.write().await;
            metrics.vectors_processed = vector_count;
            metrics.last_update = Some(Utc::now());
        }

        // Process remaining partial batch
        if !current_batch.is_empty() {
            self.process_batch(&current_batch).await?;
        }

        // Final pattern detection
        if self.config.auto_detect_patterns {
            self.detect_patterns().await?;
        }

        // Close all remaining windows
        self.close_all_windows().await?;

        // Calculate final metrics
        let mut metrics = self.metrics.write().await;
        metrics.calculate_throughput();

        Ok(())
    }

    /// Process a batch of vectors in parallel.
    ///
    /// Splits `vectors` into `batch_size` chunks and inserts each chunk
    /// into the discovery engine from a spawned task, with concurrency
    /// capped by a fresh semaphore of `max_concurrency` permits. Task
    /// failures are logged and counted in `metrics.errors` rather than
    /// propagated.
    /// NOTE(review): each task takes the engine's write lock, so chunk
    /// insertion is effectively serialized despite the spawned tasks.
    async fn process_batch(&self, vectors: &[SemanticVector]) -> Result<()> {
        let batch_size = self.config.batch_size;
        let chunks: Vec<_> = vectors.chunks(batch_size).collect();

        // Process chunks with controlled concurrency
        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
        let mut tasks = Vec::new();

        for chunk in chunks {
            let chunk_vec = chunk.to_vec();
            let engine = self.engine.clone();
            let sem = semaphore.clone();

            let task = tokio::spawn(async move {
                let _permit = sem.acquire().await.ok()?;
                let mut engine_guard = engine.write().await;

                // Bulk insertion when the `parallel` feature is enabled,
                // one-by-one otherwise.
                #[cfg(feature = "parallel")]
                {
                    engine_guard.add_vectors_batch(chunk_vec);
                }

                #[cfg(not(feature = "parallel"))]
                {
                    for vector in chunk_vec {
                        engine_guard.add_vector(vector);
                    }
                }

                Some(())
            });

            tasks.push(task);
        }

        // Wait for all tasks to complete; count join failures as errors
        for task in tasks {
            if let Err(e) = task.await {
                tracing::warn!("Batch processing task failed: {}", e);
                let mut metrics = self.metrics.write().await;
                metrics.errors += 1;
            }
        }

        Ok(())
    }

    /// Create a new time window starting at `start`.
    async fn create_window(&self, start: DateTime<Utc>, duration: ChronoDuration) {
        let window = TimeWindow::new(start, duration);
        let mut windows = self.windows.write().await;
        windows.push(window);
    }

    /// Add vector to all active windows whose interval contains its
    /// timestamp (overlapping sliding windows may each get a copy).
    async fn add_to_windows(&self, vector: SemanticVector) {
        let timestamp = vector.timestamp;
        let mut windows = self.windows.write().await;

        for window in windows.iter_mut() {
            if window.contains(timestamp) {
                window.add_vector(vector.clone());
            }
        }
    }

    /// Close completed windows and analyze them.
    async fn close_completed_windows(&self) -> Result<()> {
        let now = Utc::now();
        let mut windows = self.windows.write().await;

        // Find completed windows
        let (completed, active): (Vec<_>, Vec<_>) = windows
            .drain(..)
            .partition(|w| w.is_complete(now));

        *windows = active;
        drop(windows); // Release lock before processing

        // Process completed windows
        for window in completed {
            self.process_window(window).await?;

            let mut metrics = self.metrics.write().await;
            metrics.windows_processed += 1;
        }

        Ok(())
    }

    /// Close all remaining windows (used at end-of-stream; note these do
    /// NOT increment `windows_processed`, unlike `close_completed_windows`).
    async fn close_all_windows(&self) -> Result<()> {
        let mut windows = self.windows.write().await;
        let all_windows: Vec<_> = windows.drain(..).collect();
        drop(windows);

        for window in all_windows {
            self.process_window(window).await?;
        }

        Ok(())
    }

    /// Process a completed window: re-insert its vectors into the engine
    /// and optionally run pattern detection. Empty windows are skipped.
    async fn process_window(&self, window: TimeWindow) -> Result<()> {
        if window.vectors.is_empty() {
            return Ok(());
        }

        tracing::debug!(
            "Processing window: {} vectors from {} to {}",
            window.vectors.len(),
            window.start,
            window.end
        );

        // Add vectors to engine
        self.process_batch(&window.vectors).await?;

        // Detect patterns for this window
        if self.config.auto_detect_patterns {
            self.detect_patterns().await?;
        }

        Ok(())
    }

    /// Detect patterns and trigger callbacks.
    ///
    /// `patterns_detected` counts ALL returned patterns; the callback only
    /// fires for those flagged `is_significant`.
    async fn detect_patterns(&self) -> Result<()> {
        let patterns = {
            let mut engine = self.engine.write().await;
            engine.detect_patterns_with_significance()
        };

        let pattern_count = patterns.len();

        // Trigger callback for each significant pattern
        let on_pattern = self.on_pattern.read().await;
        if let Some(callback) = on_pattern.as_ref() {
            for pattern in patterns {
                if pattern.is_significant {
                    callback(pattern);
                }
            }
        }

        // Update metrics
        let mut metrics = self.metrics.write().await;
        metrics.patterns_detected += pattern_count as u64;

        Ok(())
    }

    /// Record a latency measurement and refresh the rolling average over
    /// the most recent 1000 samples.
    async fn record_latency(&self, latency_ms: f64) {
        let mut latencies = self.latencies.write().await;
        latencies.push(latency_ms);

        // Keep only last 1000 measurements
        let len = latencies.len();
        if len > 1000 {
            latencies.drain(0..len - 1000);
        }

        // Update average latency
        let avg = latencies.iter().sum::<f64>() / latencies.len() as f64;
        let mut metrics = self.metrics.write().await;
        metrics.avg_latency_ms = avg;
    }

    /// Get a snapshot of the current metrics with throughput recomputed.
    pub async fn metrics(&self) -> StreamingMetrics {
        let mut metrics = self.metrics.read().await.clone();
        metrics.calculate_throughput();
        metrics
    }

    /// Get statistics from the underlying discovery engine.
    pub async fn engine_stats(&self) -> crate::optimized::OptimizedStats {
        let engine = self.engine.read().await;
        engine.stats()
    }

    /// Reset metrics and latency samples; the start time restarts at now.
    pub async fn reset_metrics(&self) {
        let mut metrics = self.metrics.write().await;
        *metrics = StreamingMetrics::default();
        metrics.start_time = Some(Utc::now());

        let mut latencies = self.latencies.write().await;
        latencies.clear();
    }
}
|
||||
|
||||
/// Builder for StreamingEngine with a fluent (consuming) API.
pub struct StreamingEngineBuilder {
    // Configuration accumulated by the setter methods; starts from
    // `StreamingConfig::default()`.
    config: StreamingConfig,
}
|
||||
|
||||
impl StreamingEngineBuilder {
|
||||
/// Create a new builder
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: StreamingConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set window size
|
||||
pub fn window_size(mut self, duration: StdDuration) -> Self {
|
||||
self.config.window_size = duration;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set slide interval (for sliding windows)
|
||||
pub fn slide_interval(mut self, duration: StdDuration) -> Self {
|
||||
self.config.slide_interval = Some(duration);
|
||||
self
|
||||
}
|
||||
|
||||
/// Use tumbling windows (no overlap)
|
||||
pub fn tumbling_windows(mut self) -> Self {
|
||||
self.config.slide_interval = None;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max buffer size
|
||||
pub fn max_buffer_size(mut self, size: usize) -> Self {
|
||||
self.config.max_buffer_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set batch size
|
||||
pub fn batch_size(mut self, size: usize) -> Self {
|
||||
self.config.batch_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max concurrency
|
||||
pub fn max_concurrency(mut self, concurrency: usize) -> Self {
|
||||
self.config.max_concurrency = concurrency;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set detection interval
|
||||
pub fn detection_interval(mut self, interval: usize) -> Self {
|
||||
self.config.detection_interval = interval;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set discovery config
|
||||
pub fn discovery_config(mut self, config: OptimizedConfig) -> Self {
|
||||
self.config.discovery_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the streaming engine
|
||||
pub fn build(self) -> StreamingEngine {
|
||||
StreamingEngine::new(self.config)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StreamingEngineBuilder {
    /// Equivalent to `StreamingEngineBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use crate::ruvector_native::Domain;
    use std::collections::HashMap;

    // Build a minimal 4-dimensional vector stamped with the current time.
    fn create_test_vector(id: &str, domain: Domain) -> SemanticVector {
        SemanticVector {
            id: id.to_string(),
            embedding: vec![0.1, 0.2, 0.3, 0.4],
            domain,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    // A freshly constructed engine should report zeroed metrics.
    #[tokio::test]
    async fn test_streaming_engine_creation() {
        let config = StreamingConfig::default();
        let engine = StreamingEngine::new(config);
        let metrics = engine.metrics().await;

        assert_eq!(metrics.vectors_processed, 0);
        assert_eq!(metrics.patterns_detected, 0);
    }

    // Ingesting with a registered callback should process all vectors.
    // NOTE(review): `pattern_count` is incremented from spawned tasks but
    // never asserted — the callback's effect is not actually verified.
    #[tokio::test]
    async fn test_pattern_callback() {
        let config = StreamingConfig {
            auto_detect_patterns: true,
            detection_interval: 2,
            ..Default::default()
        };

        let mut engine = StreamingEngine::new(config);

        let pattern_count = Arc::new(RwLock::new(0_u64));
        let pc = pattern_count.clone();

        engine.set_pattern_callback(move |_pattern| {
            // Callback is sync; defer the async lock to a spawned task.
            let pc = pc.clone();
            tokio::spawn(async move {
                let mut count = pc.write().await;
                *count += 1;
            });
        }).await;

        // Create a stream of vectors
        let vectors = vec![
            create_test_vector("v1", Domain::Climate),
            create_test_vector("v2", Domain::Climate),
            create_test_vector("v3", Domain::Finance),
        ];

        let vector_stream = stream::iter(vectors);
        engine.ingest_stream(vector_stream).await.unwrap();

        let metrics = engine.metrics().await;
        assert!(metrics.vectors_processed >= 3);
    }

    // Short sliding windows should not affect the processed-vector count.
    #[tokio::test]
    async fn test_windowed_processing() {
        let config = StreamingConfig {
            window_size: StdDuration::from_millis(100),
            slide_interval: Some(StdDuration::from_millis(50)),
            auto_detect_patterns: false,
            ..Default::default()
        };

        let mut engine = StreamingEngine::new(config);

        let vectors = vec![
            create_test_vector("v1", Domain::Climate),
            create_test_vector("v2", Domain::Climate),
        ];

        let vector_stream = stream::iter(vectors);
        engine.ingest_stream(vector_stream).await.unwrap();

        let metrics = engine.metrics().await;
        assert_eq!(metrics.vectors_processed, 2);
    }

    // The fluent builder should produce a working engine.
    #[tokio::test]
    async fn test_builder() {
        let engine = StreamingEngineBuilder::new()
            .window_size(StdDuration::from_secs(30))
            .slide_interval(StdDuration::from_secs(15))
            .max_buffer_size(5000)
            .batch_size(50)
            .build();

        let metrics = engine.metrics().await;
        assert_eq!(metrics.vectors_processed, 0);
    }

    // Throughput is vectors / uptime; uptime is derived from timestamps.
    #[tokio::test]
    async fn test_metrics_calculation() {
        let mut metrics = StreamingMetrics {
            vectors_processed: 1000,
            start_time: Some(Utc::now() - ChronoDuration::seconds(10)),
            last_update: Some(Utc::now()),
            ..Default::default()
        };

        metrics.calculate_throughput();
        assert!(metrics.throughput_per_sec > 0.0);
        assert!(metrics.uptime_secs() >= 9.0 && metrics.uptime_secs() <= 11.0);
    }
}
|
||||
1720
vendor/ruvector/examples/data/framework/src/transportation_clients.rs
vendored
Normal file
1720
vendor/ruvector/examples/data/framework/src/transportation_clients.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
171
vendor/ruvector/examples/data/framework/src/utils.rs
vendored
Normal file
171
vendor/ruvector/examples/data/framework/src/utils.rs
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
//! Shared utility functions for the RuVector Data Framework
|
||||
//!
|
||||
//! This module contains common utilities used across multiple modules,
|
||||
//! including vector operations and mathematical functions.
|
||||
|
||||
/// Compute cosine similarity between two vectors
///
/// Returns a value in [-1, 1] where:
/// - 1 = identical direction
/// - 0 = orthogonal
/// - -1 = opposite direction
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector (must be same length as `a`)
///
/// # Returns
///
/// Cosine similarity score, or 0.0 if the slices are empty, have mismatched
/// lengths, or either vector has (near-)zero magnitude.
///
/// # Example
///
/// ```
/// use ruvector_data_framework::utils::cosine_similarity;
///
/// let a = vec![1.0, 0.0, 0.0];
/// let b = vec![1.0, 0.0, 0.0];
/// assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
///
/// let c = vec![0.0, 1.0, 0.0];
/// assert!(cosine_similarity(&a, &c).abs() < 1e-6);
/// ```
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One sequential pass over both slices. Accumulation proceeds in strictly
    // increasing index order, so float rounding matches a manual index loop
    // exactly, and the zip lets the compiler drop bounds checks / vectorize.
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        dot += ai * bi;
        norm_a += ai * ai;
        norm_b += bi * bi;
    }

    // Guard against division by (near-)zero magnitudes.
    let denom = (norm_a * norm_b).sqrt();
    if denom > 1e-10 {
        dot / denom
    } else {
        0.0
    }
}
|
||||
|
||||
/// Compute Euclidean (L2) distance between two vectors
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector (must be same length as `a`)
///
/// # Returns
///
/// Euclidean distance, or 0.0 if vectors are empty or different lengths
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate squared component differences, then take the root once.
    let mut sum_sq = 0.0f32;
    for (ai, bi) in a.iter().zip(b.iter()) {
        let diff = ai - bi;
        sum_sq += diff * diff;
    }

    sum_sq.sqrt()
}
|
||||
|
||||
/// Normalize a vector to unit length (L2 normalization)
///
/// Vectors with (near-)zero magnitude are left untouched to avoid
/// dividing by zero.
///
/// # Arguments
///
/// * `v` - Vector to normalize (modified in place)
#[inline]
pub fn normalize_vector(v: &mut [f32]) {
    let magnitude = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 1e-10 {
        v.iter_mut().for_each(|x| *x /= magnitude);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Parallel unit vectors -> similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    // Perpendicular unit vectors -> similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    // Anti-parallel vectors -> similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    // Degenerate input: empty slices fall back to 0.0 rather than NaN.
    #[test]
    fn test_cosine_similarity_empty() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // Degenerate input: mismatched lengths are rejected with 0.0.
    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // Classic 3-4-5 right triangle.
    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-6);
    }

    // (3, 4) has magnitude 5, so the unit form is (0.6, 0.8).
    #[test]
    fn test_normalize_vector() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }
}
|
||||
576
vendor/ruvector/examples/data/framework/src/visualization.rs
vendored
Normal file
576
vendor/ruvector/examples/data/framework/src/visualization.rs
vendored
Normal file
@@ -0,0 +1,576 @@
|
||||
//! ASCII Art Visualization for Discovery Framework
|
||||
//!
|
||||
//! Provides terminal-based graph visualization with ANSI colors, domain clustering,
|
||||
//! coherence heatmaps, and pattern timeline displays.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
use crate::optimized::{OptimizedDiscoveryEngine, SignificantPattern};
|
||||
use crate::ruvector_native::{Domain, PatternType};
|
||||
|
||||
/// ANSI color codes for domains
|
||||
const COLOR_CLIMATE: &str = "\x1b[34m"; // Blue
|
||||
const COLOR_FINANCE: &str = "\x1b[32m"; // Green
|
||||
const COLOR_RESEARCH: &str = "\x1b[33m"; // Yellow
|
||||
const COLOR_MEDICAL: &str = "\x1b[36m"; // Cyan
|
||||
const COLOR_CROSS: &str = "\x1b[35m"; // Magenta
|
||||
const COLOR_RESET: &str = "\x1b[0m";
|
||||
const COLOR_BRIGHT: &str = "\x1b[1m";
|
||||
const COLOR_DIM: &str = "\x1b[2m";
|
||||
|
||||
/// Box-drawing characters
|
||||
const BOX_H: char = '─';
|
||||
const BOX_V: char = '│';
|
||||
const BOX_TL: char = '┌';
|
||||
const BOX_TR: char = '┐';
|
||||
const BOX_BL: char = '└';
|
||||
const BOX_BR: char = '┘';
|
||||
const BOX_CROSS: char = '┼';
|
||||
const BOX_T_DOWN: char = '┬';
|
||||
const BOX_T_UP: char = '┴';
|
||||
const BOX_T_RIGHT: char = '├';
|
||||
const BOX_T_LEFT: char = '┤';
|
||||
|
||||
/// Get ANSI color for a domain
|
||||
fn domain_color(domain: Domain) -> &'static str {
|
||||
match domain {
|
||||
Domain::Climate => COLOR_CLIMATE,
|
||||
Domain::Finance => COLOR_FINANCE,
|
||||
Domain::Research => COLOR_RESEARCH,
|
||||
Domain::Medical => COLOR_MEDICAL,
|
||||
Domain::Economic => "\x1b[38;5;214m", // Orange color for Economic
|
||||
Domain::Genomics => "\x1b[38;5;46m", // Green color for Genomics
|
||||
Domain::Physics => "\x1b[38;5;33m", // Blue color for Physics
|
||||
Domain::Seismic => "\x1b[38;5;130m", // Brown color for Seismic
|
||||
Domain::Ocean => "\x1b[38;5;39m", // Cyan color for Ocean
|
||||
Domain::Space => "\x1b[38;5;141m", // Purple color for Space
|
||||
Domain::Transportation => "\x1b[38;5;208m", // Orange color for Transportation
|
||||
Domain::Geospatial => "\x1b[38;5;118m", // Light green for Geospatial
|
||||
Domain::Government => "\x1b[38;5;243m", // Gray color for Government
|
||||
Domain::CrossDomain => COLOR_CROSS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a character representation for a domain
|
||||
fn domain_char(domain: Domain) -> char {
|
||||
match domain {
|
||||
Domain::Climate => 'C',
|
||||
Domain::Finance => 'F',
|
||||
Domain::Research => 'R',
|
||||
Domain::Medical => 'M',
|
||||
Domain::Economic => 'E',
|
||||
Domain::Genomics => 'G',
|
||||
Domain::Physics => 'P',
|
||||
Domain::Seismic => 'S',
|
||||
Domain::Ocean => 'O',
|
||||
Domain::Space => 'A', // A for Astronomy/Aerospace
|
||||
Domain::Transportation => 'T',
|
||||
Domain::Geospatial => 'L', // L for Location
|
||||
Domain::Government => 'V', // V for goVernment
|
||||
Domain::CrossDomain => 'X',
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the graph as ASCII art with colored domain nodes
///
/// Draws a boxed title, clusters nodes by domain into fixed screen regions,
/// and sketches cross-domain connectivity as dim L-shaped connector lines.
///
/// # Arguments
/// * `engine` - The discovery engine containing the graph
/// * `width` - Canvas width in characters
/// * `height` - Canvas height in characters
///
/// # Returns
/// A string containing the ASCII art representation
///
/// NOTE(review): `width - 2` underflows for `width < 2` and `height - 2`
/// for `height < 2`; callers appear to pass 80x20 — confirm minimum sizes.
pub fn render_graph_ascii(engine: &OptimizedDiscoveryEngine, width: usize, height: usize) -> String {
    let stats = engine.stats();
    let mut output = String::new();

    // Draw title box (top border, centered title line, bottom border).
    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_TL));
    output.push_str(&BOX_H.to_string().repeat(width - 2));
    output.push_str(&format!("{}{}\n", BOX_TR, COLOR_RESET));

    let title = format!(" Discovery Graph ({} nodes, {} edges) ", stats.total_nodes, stats.total_edges);
    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_V));
    output.push_str(&format!("{:^width$}", title, width = width - 2));
    output.push_str(&format!("{}{}\n", BOX_V, COLOR_RESET));

    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_BL));
    output.push_str(&BOX_H.to_string().repeat(width - 2));
    output.push_str(&format!("{}{}\n\n", BOX_BR, COLOR_RESET));

    // If no nodes, show empty message and skip layout entirely.
    if stats.total_nodes == 0 {
        output.push_str(&format!("{} (empty graph){}\n", COLOR_DIM, COLOR_RESET));
        return output;
    }

    // Create a simple layout by domain: each domain gets a list of (x, y)
    // canvas coordinates for its nodes.
    let mut domain_positions: HashMap<Domain, Vec<(usize, usize)>> = HashMap::new();

    // Layout domains in quadrants
    let mid_x = width / 2;
    let mid_y = height / 2;

    // Assign domain regions; domains not listed here fall back to the
    // Research region below.
    let domain_regions = [
        (Domain::Climate, 10, 2),          // Top-left
        (Domain::Finance, mid_x + 10, 2),  // Top-right
        (Domain::Research, 10, mid_y + 2), // Bottom-left
    ];

    for (domain, count) in &stats.domain_counts {
        let (_, base_x, base_y) = domain_regions.iter()
            .find(|(d, _, _)| d == domain)
            .unwrap_or(&(Domain::Research, 10, 2));

        let mut positions = Vec::new();

        // Arrange nodes in a roughly square cluster: ceil(sqrt(count)) per row,
        // 3 columns / 2 rows of spacing between node glyphs.
        let nodes_per_row = ((*count as f64).sqrt().ceil() as usize).max(1);
        for i in 0..*count {
            let row = i / nodes_per_row;
            let col = i % nodes_per_row;
            let x = base_x + col * 3;
            let y = base_y + row * 2;

            // Clip nodes that would overflow the canvas margins.
            if x < width - 5 && y < height - 2 {
                positions.push((x, y));
            }
        }

        domain_positions.insert(*domain, positions);
    }

    // Create canvas: each cell holds a ready-to-print (possibly colored) string.
    let mut canvas: Vec<Vec<String>> = vec![" ".to_string(); width]; height].into_iter().collect::<Vec<_>>();

    // Draw nodes
    for (domain, positions) in &domain_positions {
        let color = domain_color(*domain);
        let ch = domain_char(*domain);

        for (x, y) in positions {
            if *x < width && *y < height {
                canvas[*y][*x] = format!("{}{}{}", color, ch, COLOR_RESET);
            }
        }
    }

    // Draw edges (simplified - show connections between domains)
    if stats.cross_domain_edges > 0 {
        // Draw one L-shaped connector between the first node of every pair of
        // distinct domains; only empty cells are painted, so nodes stay visible.
        for (domain_a, positions_a) in &domain_positions {
            for (domain_b, positions_b) in &domain_positions {
                if domain_a == domain_b {
                    continue;
                }

                // Draw one connection line
                if let (Some(pos_a), Some(pos_b)) = (positions_a.first(), positions_b.first()) {
                    let (x1, y1) = pos_a;
                    let (x2, y2) = pos_b;

                    // Simple line drawing (horizontal then vertical)
                    let color = COLOR_DIM;

                    // Horizontal part
                    let (min_x, max_x) = if x1 < x2 { (*x1, *x2) } else { (*x2, *x1) };
                    for x in min_x..=max_x {
                        if x < width && *y1 < height && canvas[*y1][x] == " " {
                            canvas[*y1][x] = format!("{}{}{}", color, BOX_H, COLOR_RESET);
                        }
                    }

                    // Vertical part
                    let (min_y, max_y) = if y1 < y2 { (*y1, *y2) } else { (*y2, *y1) };
                    for y in min_y..=max_y {
                        if *x2 < width && y < height && canvas[y][*x2] == " " {
                            canvas[y][*x2] = format!("{}{}{}", color, BOX_V, COLOR_RESET);
                        }
                    }
                }
            }
        }
    }

    // Render canvas to string, row by row.
    for row in canvas {
        for cell in row {
            output.push_str(&cell);
        }
        output.push('\n');
    }

    output.push('\n');

    // Legend
    output.push_str(&format!("{}Legend:{}\n", COLOR_BRIGHT, COLOR_RESET));
    output.push_str(&format!(" {}C{} = Climate ", COLOR_CLIMATE, COLOR_RESET));
    output.push_str(&format!("{}F{} = Finance ", COLOR_FINANCE, COLOR_RESET));
    output.push_str(&format!("{}R{} = Research\n", COLOR_RESEARCH, COLOR_RESET));
    output.push_str(&format!(" Cross-domain bridges: {}\n", stats.cross_domain_edges));

    output
}
|
||||
|
||||
/// Render a domain connectivity matrix
///
/// Shows the strength of connections between different domains.
/// Diagonal cells show per-domain node counts; off-diagonal cells are
/// currently placeholders (always 0) until edge iteration is wired in.
pub fn render_domain_matrix(engine: &OptimizedDiscoveryEngine) -> String {
    let stats = engine.stats();
    let mut output = String::new();

    output.push_str(&format!("\n{}{}Domain Connectivity Matrix{}{}\n",
        COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
    output.push_str(&format!("{}\n", BOX_H.to_string().repeat(50)));

    // Calculate connections between domains.
    // Only the three core domains are shown in the matrix.
    let domains = [Domain::Climate, Domain::Finance, Domain::Research];
    let mut matrix: HashMap<(Domain, Domain), usize> = HashMap::new();

    // Initialize matrix with zero counts for every domain pair.
    for &d1 in &domains {
        for &d2 in &domains {
            matrix.insert((d1, d2), 0);
        }
    }

    // This is a placeholder - in real implementation, we'd iterate through edges
    // and count connections between domains
    output.push_str(&format!(" {}Climate{} {}Finance{} {}Research{}\n",
        COLOR_CLIMATE, COLOR_RESET,
        COLOR_FINANCE, COLOR_RESET,
        COLOR_RESEARCH, COLOR_RESET));

    for &domain_a in &domains {
        let color_a = domain_color(domain_a);
        output.push_str(&format!("{}{:9}{} ", color_a, format!("{:?}", domain_a), COLOR_RESET));

        for &domain_b in &domains {
            let count = matrix.get(&(domain_a, domain_b)).unwrap_or(&0);
            let display = if domain_a == domain_b {
                // Diagonal: bracketed node count for the domain itself.
                format!("{}[{:3}]{}", COLOR_BRIGHT, stats.domain_counts.get(&domain_a).unwrap_or(&0), COLOR_RESET)
            } else {
                format!(" {:3} ", count)
            };
            output.push_str(&display);
        }
        output.push('\n');
    }

    output.push_str(&format!("\n{}Note:{} Diagonal = node count, Off-diagonal = cross-domain edges\n",
        COLOR_DIM, COLOR_RESET));
    output.push_str(&format!("Total cross-domain edges: {}\n", stats.cross_domain_edges));

    output
}
|
||||
|
||||
/// Render coherence timeline as ASCII sparkline/chart
///
/// Downsamples the series to at most 60 columns, normalizes values into a
/// 10-row band, and appends start/end date labels plus the covered duration.
///
/// # Arguments
/// * `history` - Time series of (timestamp, coherence_value) pairs
pub fn render_coherence_timeline(history: &[(DateTime<Utc>, f64)]) -> String {
    let mut output = String::new();

    output.push_str(&format!("\n{}{}Coherence Timeline{}{}\n",
        COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
    output.push_str(&format!("{}\n", BOX_H.to_string().repeat(70)));

    if history.is_empty() {
        output.push_str(&format!("{} (no coherence history){}\n", COLOR_DIM, COLOR_RESET));
        return output;
    }

    let values: Vec<f64> = history.iter().map(|(_, v)| *v).collect();
    let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    output.push_str(&format!(" Coherence range: {:.4} - {:.4}\n", min_val, max_val));
    output.push_str(&format!(" Data points: {}\n\n", history.len()));

    // ASCII sparkline
    let chart_height = 10;
    let chart_width = 60.min(history.len());

    // Sample data if too many points (keep every `step`-th value).
    let step = if history.len() > chart_width {
        history.len() / chart_width
    } else {
        1
    };

    let sampled: Vec<f64> = history.iter()
        .step_by(step)
        .take(chart_width)
        .map(|(_, v)| *v)
        .collect();

    // Normalize values to chart height (0..chart_height-1); a flat series
    // (range ~ 0) is drawn as a mid-height line to avoid dividing by zero.
    let range = max_val - min_val;
    let normalized: Vec<usize> = if range > 1e-10 {
        sampled.iter()
            .map(|v| {
                let normalized = ((v - min_val) / range * (chart_height - 1) as f64) as usize;
                normalized.min(chart_height - 1)
            })
            .collect()
    } else {
        vec![chart_height / 2; sampled.len()]
    };

    // Draw chart from the top row down; each column is filled up to its
    // normalized height (▓ = filled, ▒ = cap just above the fill).
    for row in (0..chart_height).rev() {
        let value = min_val + (row as f64 / (chart_height - 1) as f64) * range;
        output.push_str(&format!("{:6.3} {} ", value, BOX_V));

        for &height in &normalized {
            let ch = if height >= row {
                format!("{}▓{}", COLOR_CLIMATE, COLOR_RESET)
            } else if height + 1 == row {
                format!("{}▒{}", COLOR_DIM, COLOR_RESET)
            } else {
                " ".to_string()
            };
            output.push_str(&ch);
        }
        output.push('\n');
    }

    // X-axis
    output.push_str(" ");
    output.push_str(&BOX_BL.to_string());
    output.push_str(&BOX_H.to_string().repeat(chart_width));
    output.push('\n');

    // Time labels: first date left-aligned, last date right-aligned to the
    // chart width, then a human-friendly duration line.
    if let (Some(first), Some(last)) = (history.first(), history.last()) {
        let duration = last.0.signed_duration_since(first.0);
        let width_val = if chart_width > 12 { chart_width - 12 } else { 0 };
        output.push_str(&format!(" {} {:>width$}\n",
            first.0.format("%Y-%m-%d"),
            last.0.format("%Y-%m-%d"),
            width = width_val));
        output.push_str(&format!(" {}Duration: {}{}\n",
            COLOR_DIM,
            if duration.num_days() > 0 {
                format!("{} days", duration.num_days())
            } else if duration.num_hours() > 0 {
                format!("{} hours", duration.num_hours())
            } else {
                format!("{} minutes", duration.num_minutes())
            },
            COLOR_RESET));
    }

    output
}
|
||||
|
||||
/// Render a summary of discovered patterns
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `patterns` - List of significant patterns to summarize
|
||||
pub fn render_pattern_summary(patterns: &[SignificantPattern]) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
output.push_str(&format!("\n{}{}Pattern Discovery Summary{}{}\n",
|
||||
COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
|
||||
output.push_str(&format!("{}\n", BOX_H.to_string().repeat(80)));
|
||||
|
||||
if patterns.is_empty() {
|
||||
output.push_str(&format!("{} No patterns discovered yet{}\n", COLOR_DIM, COLOR_RESET));
|
||||
return output;
|
||||
}
|
||||
|
||||
output.push_str(&format!(" Total patterns detected: {}\n", patterns.len()));
|
||||
|
||||
// Count by type
|
||||
let mut type_counts: HashMap<PatternType, usize> = HashMap::new();
|
||||
let mut significant_count = 0;
|
||||
|
||||
for pattern in patterns {
|
||||
*type_counts.entry(pattern.pattern.pattern_type).or_default() += 1;
|
||||
if pattern.is_significant {
|
||||
significant_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
output.push_str(&format!(" Statistically significant: {} ({:.1}%)\n\n",
|
||||
significant_count,
|
||||
(significant_count as f64 / patterns.len() as f64) * 100.0));
|
||||
|
||||
// Pattern type breakdown
|
||||
output.push_str(&format!("{}Pattern Types:{}\n", COLOR_BRIGHT, COLOR_RESET));
|
||||
for (pattern_type, count) in type_counts.iter() {
|
||||
let icon = match pattern_type {
|
||||
PatternType::CoherenceBreak => "⚠️ ",
|
||||
PatternType::Consolidation => "📈",
|
||||
PatternType::EmergingCluster => "🌟",
|
||||
PatternType::DissolvingCluster => "💫",
|
||||
PatternType::BridgeFormation => "🌉",
|
||||
PatternType::AnomalousNode => "🔴",
|
||||
PatternType::TemporalShift => "⏰",
|
||||
PatternType::Cascade => "🌊",
|
||||
};
|
||||
|
||||
let bar_length = ((*count as f64 / patterns.len() as f64) * 30.0) as usize;
|
||||
let bar = "█".repeat(bar_length);
|
||||
|
||||
output.push_str(&format!(" {} {:20} {:3} {}{}{}\n",
|
||||
icon,
|
||||
format!("{:?}", pattern_type),
|
||||
count,
|
||||
COLOR_CLIMATE,
|
||||
bar,
|
||||
COLOR_RESET));
|
||||
}
|
||||
|
||||
output.push('\n');
|
||||
|
||||
// Top patterns by confidence
|
||||
output.push_str(&format!("{}Top Patterns (by confidence):{}\n", COLOR_BRIGHT, COLOR_RESET));
|
||||
|
||||
let mut sorted_patterns: Vec<_> = patterns.iter().collect();
|
||||
sorted_patterns.sort_by(|a, b| b.pattern.confidence.partial_cmp(&a.pattern.confidence).unwrap());
|
||||
|
||||
for (i, pattern) in sorted_patterns.iter().take(5).enumerate() {
|
||||
let significance_marker = if pattern.is_significant {
|
||||
format!("{}*{}", COLOR_BRIGHT, COLOR_RESET)
|
||||
} else {
|
||||
" ".to_string()
|
||||
};
|
||||
|
||||
let color = if pattern.pattern.confidence > 0.8 {
|
||||
COLOR_CLIMATE
|
||||
} else if pattern.pattern.confidence > 0.5 {
|
||||
COLOR_FINANCE
|
||||
} else {
|
||||
COLOR_DIM
|
||||
};
|
||||
|
||||
output.push_str(&format!(" {}{}.{} {}{:?}{} (p={:.4}, effect={:.3}, conf={:.2})\n",
|
||||
significance_marker,
|
||||
i + 1,
|
||||
COLOR_RESET,
|
||||
color,
|
||||
pattern.pattern.pattern_type,
|
||||
COLOR_RESET,
|
||||
pattern.p_value,
|
||||
pattern.effect_size,
|
||||
pattern.pattern.confidence));
|
||||
|
||||
output.push_str(&format!(" {}{}{}\n",
|
||||
COLOR_DIM,
|
||||
pattern.pattern.description,
|
||||
COLOR_RESET));
|
||||
}
|
||||
|
||||
output.push_str(&format!("\n{}Note:{} * = statistically significant (p < 0.05)\n",
|
||||
COLOR_DIM, COLOR_RESET));
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Render a complete dashboard combining all visualizations
///
/// Concatenates, in order: a banner, quick stats, the ASCII graph (80x20),
/// the domain connectivity matrix, the coherence timeline, and the pattern
/// summary, closing with a dim footer rule.
pub fn render_dashboard(
    engine: &OptimizedDiscoveryEngine,
    patterns: &[SignificantPattern],
    coherence_history: &[(DateTime<Utc>, f64)],
) -> String {
    let mut output = String::new();

    // Title banner
    output.push_str(&format!("\n{}{}═══════════════════════════════════════════════════════════════════════════════{}\n",
        COLOR_BRIGHT, BOX_TL, COLOR_RESET));
    output.push_str(&format!("{}{} RuVector Discovery Framework - Live Dashboard {}\n",
        COLOR_BRIGHT, BOX_V, COLOR_RESET));
    output.push_str(&format!("{}{}═══════════════════════════════════════════════════════════════════════════════{}\n\n",
        COLOR_BRIGHT, BOX_BL, COLOR_RESET));

    // Stats overview: graph size, pattern/coherence counts, cache health.
    let stats = engine.stats();
    output.push_str(&format!("{}Quick Stats:{}\n", COLOR_BRIGHT, COLOR_RESET));
    output.push_str(&format!(" Nodes: {} │ Edges: {} │ Vectors: {} │ Cross-domain: {}\n",
        stats.total_nodes,
        stats.total_edges,
        stats.total_vectors,
        stats.cross_domain_edges));
    output.push_str(&format!(" Patterns: {} │ Coherence samples: {} │ Cache hit rate: {:.1}%\n\n",
        patterns.len(),
        coherence_history.len(),
        stats.cache_hit_rate * 100.0));

    // Graph visualization
    output.push_str(&render_graph_ascii(engine, 80, 20));
    output.push('\n');

    // Domain matrix
    output.push_str(&render_domain_matrix(engine));
    output.push('\n');

    // Coherence timeline
    output.push_str(&render_coherence_timeline(coherence_history));
    output.push('\n');

    // Pattern summary
    output.push_str(&render_pattern_summary(patterns));

    // Footer rule
    output.push_str(&format!("\n{}{}═══════════════════════════════════════════════════════════════════════════════{}\n",
        COLOR_DIM, BOX_BL, COLOR_RESET));

    output
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine};
    use crate::ruvector_native::SemanticVector;
    use chrono::Utc;

    // Core domains must map to their named ANSI constants.
    #[test]
    fn test_domain_color() {
        assert_eq!(domain_color(Domain::Climate), COLOR_CLIMATE);
        assert_eq!(domain_color(Domain::Finance), COLOR_FINANCE);
    }

    // Domain glyphs are the documented single letters.
    #[test]
    fn test_domain_char() {
        assert_eq!(domain_char(Domain::Climate), 'C');
        assert_eq!(domain_char(Domain::Finance), 'F');
        assert_eq!(domain_char(Domain::Research), 'R');
    }

    // A freshly created engine has no nodes, so the renderer takes the
    // early-return "(empty graph)" path.
    #[test]
    fn test_render_empty_graph() {
        let config = OptimizedConfig::default();
        let engine = OptimizedDiscoveryEngine::new(config);
        let output = render_graph_ascii(&engine, 80, 20);
        assert!(output.contains("empty graph"));
    }

    // Empty pattern slice hits the "No patterns" placeholder.
    #[test]
    fn test_render_pattern_summary_empty() {
        let output = render_pattern_summary(&[]);
        assert!(output.contains("No patterns"));
    }

    // Empty history hits the "no coherence history" placeholder.
    #[test]
    fn test_render_coherence_timeline_empty() {
        let output = render_coherence_timeline(&[]);
        assert!(output.contains("no coherence history"));
    }

    // With three samples, the chart header and data-point count must render.
    #[test]
    fn test_render_coherence_timeline_with_data() {
        let now = Utc::now();
        let history = vec![
            (now, 0.5),
            (now + chrono::Duration::hours(1), 0.6),
            (now + chrono::Duration::hours(2), 0.7),
        ];
        let output = render_coherence_timeline(&history);
        assert!(output.contains("Coherence Timeline"));
        assert!(output.contains("Data points: 3"));
    }
}
|
||||
906
vendor/ruvector/examples/data/framework/src/wiki_clients.rs
vendored
Normal file
906
vendor/ruvector/examples/data/framework/src/wiki_clients.rs
vendored
Normal file
@@ -0,0 +1,906 @@
|
||||
//! Wikipedia and Wikidata API clients for knowledge graph building
|
||||
//!
|
||||
//! This module provides async clients for:
|
||||
//! - Wikipedia: Article content, categories, links, and search
|
||||
//! - Wikidata: Entity lookup, SPARQL queries, and structured knowledge
|
||||
//!
|
||||
//! Both clients convert responses into RuVector's DataRecord format with
|
||||
//! semantic embeddings for vector search and graph analysis.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::{DataRecord, DataSource, FrameworkError, Relationship, Result};
|
||||
use crate::api_clients::SimpleEmbedder;
|
||||
|
||||
/// Rate limiting configuration
|
||||
const DEFAULT_RATE_LIMIT_DELAY_MS: u64 = 100;
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
|
||||
// ============================================================================
|
||||
// Wikipedia API Client
|
||||
// ============================================================================
|
||||
|
||||
/// Wikipedia API search response
///
/// Top-level envelope for `action=query&list=search` responses; only the
/// fields this client consumes are deserialized.
#[derive(Debug, Deserialize)]
struct WikiSearchResponse {
    query: WikiSearchQuery,
}

// Inner `query` object holding the list of search hits.
#[derive(Debug, Deserialize)]
struct WikiSearchQuery {
    search: Vec<WikiSearchResult>,
}

// One search hit: title and page id identify the article; `snippet` is an
// HTML-highlighted excerpt (currently unused beyond deserialization).
#[derive(Debug, Deserialize)]
struct WikiSearchResult {
    title: String,
    pageid: u64,
    snippet: String,
}
|
||||
|
||||
/// Wikipedia API page response
///
/// Envelope for `action=query&prop=extracts|categories|links` responses.
#[derive(Debug, Deserialize)]
struct WikiPageResponse {
    query: WikiPageQuery,
}

// `pages` is keyed by page-id string (MediaWiki convention), so a map is
// required even when only a single page is requested.
#[derive(Debug, Deserialize)]
struct WikiPageQuery {
    pages: HashMap<String, WikiPage>,
}

// A single article. `extract`, `categories`, and `links` are only present
// when the corresponding `prop` was requested, hence `#[serde(default)]`.
#[derive(Debug, Deserialize)]
struct WikiPage {
    pageid: u64,
    title: String,
    #[serde(default)]
    extract: String,
    #[serde(default)]
    categories: Vec<WikiCategory>,
    #[serde(default)]
    links: Vec<WikiLink>,
}

// Category membership entry (title includes the "Category:" prefix as
// returned by the API).
#[derive(Debug, Deserialize)]
struct WikiCategory {
    title: String,
}

// Outgoing wiki-link entry (target article title).
#[derive(Debug, Deserialize)]
struct WikiLink {
    title: String,
}
|
||||
|
||||
/// Client for Wikipedia API
///
/// Wraps a `reqwest::Client` pointed at one language edition's
/// `api.php` endpoint, with a fixed per-request rate-limit delay and a
/// shared embedder for converting article text to vectors.
pub struct WikipediaClient {
    client: Client,
    // e.g. "https://en.wikipedia.org/w/api.php"; derived from `language`.
    base_url: String,
    language: String,
    // Pause inserted between successive article fetches in `search`.
    rate_limit_delay: Duration,
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl WikipediaClient {
|
||||
/// Create a new Wikipedia client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `language` - Wikipedia language code (e.g., "en", "de", "fr")
|
||||
pub fn new(language: String) -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.user_agent("RuVector/1.0 (https://github.com/ruvnet/ruvector)")
|
||||
.build()
|
||||
.map_err(|e| FrameworkError::Network(e))?;
|
||||
|
||||
let base_url = format!("https://{}.wikipedia.org/w/api.php", language);
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
language,
|
||||
rate_limit_delay: Duration::from_millis(DEFAULT_RATE_LIMIT_DELAY_MS),
|
||||
embedder: Arc::new(SimpleEmbedder::new(256)), // Larger dimension for richer content
|
||||
})
|
||||
}
|
||||
|
||||
    /// Search Wikipedia articles
    ///
    /// Runs a full-text search, then fetches the full article for each hit
    /// (one extra request per result, throttled by `rate_limit_delay`).
    /// Hits whose article fetch fails are silently skipped (best-effort).
    ///
    /// # Arguments
    /// * `query` - Search query
    /// * `limit` - Maximum number of results (max 500)
    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<DataRecord>> {
        let url = format!(
            "{}?action=query&list=search&srsearch={}&srlimit={}&format=json",
            self.base_url,
            urlencoding::encode(query),
            // MediaWiki caps srlimit at 500; clamp locally to avoid API errors.
            limit.min(500)
        );

        let response = self.fetch_with_retry(&url).await?;
        let search_response: WikiSearchResponse = response.json().await?;

        let mut records = Vec::new();
        for result in search_response.query.search {
            // Get full article for each search result
            if let Ok(article) = self.get_article(&result.title).await {
                records.push(article);
                // Throttle between article fetches to stay within rate limits.
                sleep(self.rate_limit_delay).await;
            }
        }

        Ok(records)
    }
|
||||
|
||||
    /// Get a Wikipedia article by title
    ///
    /// Requests the intro extract (plain text) plus up to 50 categories and
    /// 50 outgoing links in a single API call, then converts the page into a
    /// `DataRecord` via `page_to_record`.
    ///
    /// # Arguments
    /// * `title` - Article title
    ///
    /// # Errors
    /// Returns `FrameworkError::Discovery` if the response contains no page.
    pub async fn get_article(&self, title: &str) -> Result<DataRecord> {
        let url = format!(
            "{}?action=query&prop=extracts|categories|links&titles={}&exintro=1&explaintext=1&format=json&cllimit=50&pllimit=50",
            self.base_url,
            urlencoding::encode(title)
        );

        let response = self.fetch_with_retry(&url).await?;
        let page_response: WikiPageResponse = response.json().await?;

        // Extract the page (should be only one; `pages` is keyed by page id,
        // so we just take the first value).
        let page = page_response
            .query
            .pages
            .values()
            .next()
            .ok_or_else(|| FrameworkError::Discovery("No page found".to_string()))?;

        self.page_to_record(page)
    }
|
||||
|
||||
/// Get categories for an article
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `title` - Article title
|
||||
pub async fn get_categories(&self, title: &str) -> Result<Vec<String>> {
|
||||
let url = format!(
|
||||
"{}?action=query&prop=categories&titles={}&cllimit=500&format=json",
|
||||
self.base_url,
|
||||
urlencoding::encode(title)
|
||||
);
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let page_response: WikiPageResponse = response.json().await?;
|
||||
|
||||
let categories = page_response
|
||||
.query
|
||||
.pages
|
||||
.values()
|
||||
.next()
|
||||
.map(|page| page.categories.iter().map(|c| c.title.clone()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(categories)
|
||||
}
|
||||
|
||||
/// Get links from an article
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `title` - Article title
|
||||
pub async fn get_links(&self, title: &str) -> Result<Vec<String>> {
|
||||
let url = format!(
|
||||
"{}?action=query&prop=links&titles={}&pllimit=500&format=json",
|
||||
self.base_url,
|
||||
urlencoding::encode(title)
|
||||
);
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let page_response: WikiPageResponse = response.json().await?;
|
||||
|
||||
let links = page_response
|
||||
.query
|
||||
.pages
|
||||
.values()
|
||||
.next()
|
||||
.map(|page| page.links.iter().map(|l| l.title.clone()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(links)
|
||||
}
|
||||
|
||||
/// Convert Wikipedia page to DataRecord
|
||||
fn page_to_record(&self, page: &WikiPage) -> Result<DataRecord> {
|
||||
// Create embedding from title and extract
|
||||
let text = format!("{} {}", page.title, page.extract);
|
||||
let embedding = self.embedder.embed_text(&text);
|
||||
|
||||
// Build relationships from categories
|
||||
let mut relationships = Vec::new();
|
||||
for category in &page.categories {
|
||||
relationships.push(Relationship {
|
||||
target_id: category.title.clone(),
|
||||
rel_type: "in_category".to_string(),
|
||||
weight: 1.0,
|
||||
properties: HashMap::new(),
|
||||
});
|
||||
}
|
||||
|
||||
// Build relationships from links (limit to first 20)
|
||||
for link in page.links.iter().take(20) {
|
||||
relationships.push(Relationship {
|
||||
target_id: link.title.clone(),
|
||||
rel_type: "links_to".to_string(),
|
||||
weight: 0.5,
|
||||
properties: HashMap::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut data_map = serde_json::Map::new();
|
||||
data_map.insert("title".to_string(), serde_json::json!(page.title));
|
||||
data_map.insert("extract".to_string(), serde_json::json!(page.extract));
|
||||
data_map.insert("pageid".to_string(), serde_json::json!(page.pageid));
|
||||
data_map.insert("language".to_string(), serde_json::json!(self.language));
|
||||
data_map.insert(
|
||||
"url".to_string(),
|
||||
serde_json::json!(format!(
|
||||
"https://{}.wikipedia.org/wiki/{}",
|
||||
self.language,
|
||||
urlencoding::encode(&page.title)
|
||||
)),
|
||||
);
|
||||
|
||||
Ok(DataRecord {
|
||||
id: format!("wikipedia_{}_{}", self.language, page.pageid),
|
||||
source: "wikipedia".to_string(),
|
||||
record_type: "article".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
data: serde_json::Value::Object(data_map),
|
||||
embedding: Some(embedding),
|
||||
relationships,
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
|
||||
{
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DataSource for WikipediaClient {
|
||||
fn source_id(&self) -> &str {
|
||||
"wikipedia"
|
||||
}
|
||||
|
||||
async fn fetch_batch(
|
||||
&self,
|
||||
cursor: Option<String>,
|
||||
batch_size: usize,
|
||||
) -> Result<(Vec<DataRecord>, Option<String>)> {
|
||||
// Default to searching for "machine learning" if no cursor provided
|
||||
let query = cursor.as_deref().unwrap_or("machine learning");
|
||||
let records = self.search(query, batch_size).await?;
|
||||
Ok((records, None))
|
||||
}
|
||||
|
||||
async fn total_count(&self) -> Result<Option<u64>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<bool> {
|
||||
let response = self.client.get(&self.base_url).send().await?;
|
||||
Ok(response.status().is_success())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Wikidata API Client
|
||||
// ============================================================================
|
||||
|
||||
/// Wikidata entity search response
///
/// Shape of the `wbsearchentities` API reply: a flat list of shallow matches.
#[derive(Debug, Deserialize)]
struct WikidataSearchResponse {
    search: Vec<WikidataSearchResult>,
}

/// One shallow search hit: identifier plus display text, no claims.
#[derive(Debug, Deserialize)]
struct WikidataSearchResult {
    id: String,
    label: String,
    // Absent for entities without an English description.
    description: Option<String>,
}

/// Wikidata entity response
///
/// Shape of the `wbgetentities` API reply, keyed by QID.
#[derive(Debug, Deserialize)]
struct WikidataEntityResponse {
    entities: HashMap<String, WikidataEntityData>,
}

/// Raw per-entity payload from `wbgetentities`.
///
/// NOTE(review): all maps are required by serde here (no `#[serde(default)]`);
/// entities genuinely missing e.g. `aliases` would fail to deserialize —
/// verify against real API responses.
#[derive(Debug, Deserialize)]
struct WikidataEntityData {
    id: String,
    // Language code -> label text.
    labels: HashMap<String, WikidataLabel>,
    // Language code -> description text.
    descriptions: HashMap<String, WikidataLabel>,
    // Language code -> alternative names.
    aliases: HashMap<String, Vec<WikidataLabel>>,
    // Property ID (e.g. "P31") -> claim statements.
    claims: HashMap<String, Vec<WikidataClaim>>,
}

/// A single localized text value (used for labels, descriptions, aliases).
#[derive(Debug, Deserialize)]
struct WikidataLabel {
    value: String,
}

/// One claim statement; only the main snak is retained.
#[derive(Debug, Deserialize)]
struct WikidataClaim {
    mainsnak: WikidataSnak,
}

/// The main snak of a claim; `datavalue` is absent for
/// "no value" / "unknown value" snaks.
#[derive(Debug, Deserialize)]
struct WikidataSnak {
    datavalue: Option<WikidataValue>,
}

/// A snak's datavalue, kept as raw JSON since its shape varies by datatype
/// (string, entity reference object, time object, quantity, ...).
#[derive(Debug, Deserialize)]
struct WikidataValue {
    value: serde_json::Value,
}

/// Wikidata SPARQL response
///
/// Standard SPARQL 1.1 JSON results envelope.
#[derive(Debug, Deserialize)]
struct WikidataSparqlResponse {
    results: WikidataSparqlResults,
}

/// Result rows: each binding maps a variable name to its bound value.
#[derive(Debug, Deserialize)]
struct WikidataSparqlResults {
    bindings: Vec<HashMap<String, WikidataSparqlBinding>>,
}

/// A single bound value; only the string form is retained
/// (type/datatype fields from the SPARQL JSON format are dropped).
#[derive(Debug, Deserialize)]
struct WikidataSparqlBinding {
    value: String,
}

/// Structured Wikidata entity
///
/// Normalized, English-preferring view of an entity, assembled from the
/// raw API structs above.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WikidataEntity {
    /// Wikidata Q-identifier
    pub qid: String,
    /// Primary label
    pub label: String,
    /// Description
    pub description: String,
    /// Alternative names
    pub aliases: Vec<String>,
    /// Property claims (property ID -> values)
    pub claims: HashMap<String, Vec<String>>,
}
|
||||
|
||||
/// Client for Wikidata API and SPARQL endpoint
pub struct WikidataClient {
    // Shared HTTP client (30s timeout, project user agent).
    client: Client,
    // MediaWiki action API endpoint (wbsearchentities / wbgetentities).
    api_url: String,
    // SPARQL query service endpoint.
    sparql_url: String,
    // Delay inserted between successive requests to stay polite.
    rate_limit_delay: Duration,
    // Text embedder used to vectorize labels/descriptions (dimension 256).
    embedder: Arc<SimpleEmbedder>,
}
|
||||
|
||||
impl WikidataClient {
|
||||
/// Create a new Wikidata client
|
||||
pub fn new() -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.user_agent("RuVector/1.0 (https://github.com/ruvnet/ruvector)")
|
||||
.build()
|
||||
.map_err(|e| FrameworkError::Network(e))?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
api_url: "https://www.wikidata.org/w/api.php".to_string(),
|
||||
sparql_url: "https://query.wikidata.org/sparql".to_string(),
|
||||
rate_limit_delay: Duration::from_millis(DEFAULT_RATE_LIMIT_DELAY_MS),
|
||||
embedder: Arc::new(SimpleEmbedder::new(256)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Search for Wikidata entities
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Search query
|
||||
pub async fn search_entities(&self, query: &str) -> Result<Vec<WikidataEntity>> {
|
||||
let url = format!(
|
||||
"{}?action=wbsearchentities&search={}&language=en&format=json&limit=50",
|
||||
self.api_url,
|
||||
urlencoding::encode(query)
|
||||
);
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let search_response: WikidataSearchResponse = response.json().await?;
|
||||
|
||||
let mut entities = Vec::new();
|
||||
for result in search_response.search {
|
||||
entities.push(WikidataEntity {
|
||||
qid: result.id,
|
||||
label: result.label,
|
||||
description: result.description.unwrap_or_default(),
|
||||
aliases: Vec::new(),
|
||||
claims: HashMap::new(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
|
||||
/// Get a Wikidata entity by QID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `qid` - Wikidata Q-identifier (e.g., "Q42" for Douglas Adams)
|
||||
pub async fn get_entity(&self, qid: &str) -> Result<WikidataEntity> {
|
||||
let url = format!(
|
||||
"{}?action=wbgetentities&ids={}&format=json",
|
||||
self.api_url, qid
|
||||
);
|
||||
|
||||
let response = self.fetch_with_retry(&url).await?;
|
||||
let entity_response: WikidataEntityResponse = response.json().await?;
|
||||
|
||||
let entity_data = entity_response
|
||||
.entities
|
||||
.get(qid)
|
||||
.ok_or_else(|| FrameworkError::Discovery(format!("Entity {} not found", qid)))?;
|
||||
|
||||
self.entity_data_to_entity(entity_data)
|
||||
}
|
||||
|
||||
/// Execute a SPARQL query
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - SPARQL query string
|
||||
pub async fn sparql_query(&self, query: &str) -> Result<Vec<HashMap<String, String>>> {
|
||||
let response = self
|
||||
.client
|
||||
.get(&self.sparql_url)
|
||||
.query(&[("query", query), ("format", "json")])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let sparql_response: WikidataSparqlResponse = response.json().await?;
|
||||
|
||||
let results = sparql_response
|
||||
.results
|
||||
.bindings
|
||||
.into_iter()
|
||||
.map(|binding| {
|
||||
binding
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.value))
|
||||
.collect::<HashMap<String, String>>()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Query climate change related entities
|
||||
pub async fn query_climate_entities(&self) -> Result<Vec<DataRecord>> {
|
||||
let query = r#"
|
||||
SELECT ?item ?itemLabel ?itemDescription WHERE {
|
||||
{
|
||||
?item wdt:P31 wd:Q125977. # climate change
|
||||
} UNION {
|
||||
?item wdt:P279* wd:Q125977. # subclass of climate change
|
||||
} UNION {
|
||||
?item wdt:P921 wd:Q125977. # main subject climate change
|
||||
}
|
||||
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
|
||||
}
|
||||
LIMIT 100
|
||||
"#;
|
||||
|
||||
self.sparql_to_records(query, "climate").await
|
||||
}
|
||||
|
||||
/// Query pharmaceutical companies
|
||||
pub async fn query_pharmaceutical_companies(&self) -> Result<Vec<DataRecord>> {
|
||||
let query = r#"
|
||||
SELECT ?item ?itemLabel ?itemDescription ?founded ?employees WHERE {
|
||||
?item wdt:P31/wdt:P279* wd:Q507443. # pharmaceutical company
|
||||
OPTIONAL { ?item wdt:P571 ?founded. }
|
||||
OPTIONAL { ?item wdt:P1128 ?employees. }
|
||||
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
|
||||
}
|
||||
LIMIT 100
|
||||
"#;
|
||||
|
||||
self.sparql_to_records(query, "pharma").await
|
||||
}
|
||||
|
||||
/// Query disease outbreaks
|
||||
pub async fn query_disease_outbreaks(&self) -> Result<Vec<DataRecord>> {
|
||||
let query = r#"
|
||||
SELECT ?item ?itemLabel ?itemDescription ?disease ?diseaseLabel ?startTime ?location ?locationLabel WHERE {
|
||||
?item wdt:P31 wd:Q3241045. # epidemic
|
||||
OPTIONAL { ?item wdt:P828 ?disease. }
|
||||
OPTIONAL { ?item wdt:P580 ?startTime. }
|
||||
OPTIONAL { ?item wdt:P276 ?location. }
|
||||
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
|
||||
}
|
||||
LIMIT 100
|
||||
"#;
|
||||
|
||||
self.sparql_to_records(query, "disease").await
|
||||
}
|
||||
|
||||
/// Convert SPARQL results to DataRecords
|
||||
async fn sparql_to_records(&self, query: &str, category: &str) -> Result<Vec<DataRecord>> {
|
||||
let results = self.sparql_query(query).await?;
|
||||
|
||||
let mut records = Vec::new();
|
||||
for result in results {
|
||||
// Extract QID from URI
|
||||
let item_uri = result.get("item").cloned().unwrap_or_default();
|
||||
let qid = item_uri
|
||||
.split('/')
|
||||
.last()
|
||||
.unwrap_or(&item_uri)
|
||||
.to_string();
|
||||
|
||||
let label = result
|
||||
.get("itemLabel")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| qid.clone());
|
||||
let description = result.get("itemDescription").cloned().unwrap_or_default();
|
||||
|
||||
// Create embedding from label and description
|
||||
let text = format!("{} {}", label, description);
|
||||
let embedding = self.embedder.embed_text(&text);
|
||||
|
||||
let mut data_map = serde_json::Map::new();
|
||||
data_map.insert("qid".to_string(), serde_json::json!(qid));
|
||||
data_map.insert("label".to_string(), serde_json::json!(label));
|
||||
data_map.insert("description".to_string(), serde_json::json!(description));
|
||||
data_map.insert("category".to_string(), serde_json::json!(category));
|
||||
|
||||
// Add all other SPARQL result fields
|
||||
for (key, value) in result.iter() {
|
||||
if !key.ends_with("Label") && key != "item" && key != "itemDescription" {
|
||||
data_map.insert(key.clone(), serde_json::json!(value));
|
||||
}
|
||||
}
|
||||
|
||||
records.push(DataRecord {
|
||||
id: format!("wikidata_{}", qid),
|
||||
source: "wikidata".to_string(),
|
||||
record_type: category.to_string(),
|
||||
timestamp: Utc::now(),
|
||||
data: serde_json::Value::Object(data_map),
|
||||
embedding: Some(embedding),
|
||||
relationships: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// Convert entity data to WikidataEntity
|
||||
fn entity_data_to_entity(&self, data: &WikidataEntityData) -> Result<WikidataEntity> {
|
||||
let label = data
|
||||
.labels
|
||||
.get("en")
|
||||
.map(|l| l.value.clone())
|
||||
.unwrap_or_else(|| data.id.clone());
|
||||
|
||||
let description = data
|
||||
.descriptions
|
||||
.get("en")
|
||||
.map(|d| d.value.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let aliases = data
|
||||
.aliases
|
||||
.get("en")
|
||||
.map(|aliases| aliases.iter().map(|a| a.value.clone()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut claims = HashMap::new();
|
||||
for (property, claim_list) in &data.claims {
|
||||
let values: Vec<String> = claim_list
|
||||
.iter()
|
||||
.filter_map(|claim| {
|
||||
claim
|
||||
.mainsnak
|
||||
.datavalue
|
||||
.as_ref()
|
||||
.map(|dv| dv.value.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !values.is_empty() {
|
||||
claims.insert(property.clone(), values);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(WikidataEntity {
|
||||
qid: data.id.clone(),
|
||||
label,
|
||||
description,
|
||||
aliases,
|
||||
claims,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert WikidataEntity to DataRecord
|
||||
fn entity_to_record(&self, entity: &WikidataEntity) -> Result<DataRecord> {
|
||||
// Create embedding from label, description, and aliases
|
||||
let text = format!(
|
||||
"{} {} {}",
|
||||
entity.label,
|
||||
entity.description,
|
||||
entity.aliases.join(" ")
|
||||
);
|
||||
let embedding = self.embedder.embed_text(&text);
|
||||
|
||||
// Build relationships from claims
|
||||
let mut relationships = Vec::new();
|
||||
for (property, values) in &entity.claims {
|
||||
for value in values {
|
||||
// Try to extract QID if value is an entity reference
|
||||
if let Some(qid) = value.strip_prefix("Q") {
|
||||
if qid.chars().all(|c| c.is_ascii_digit()) {
|
||||
relationships.push(Relationship {
|
||||
target_id: value.clone(),
|
||||
rel_type: property.clone(),
|
||||
weight: 1.0,
|
||||
properties: HashMap::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut data_map = serde_json::Map::new();
|
||||
data_map.insert("qid".to_string(), serde_json::json!(entity.qid));
|
||||
data_map.insert("label".to_string(), serde_json::json!(entity.label));
|
||||
data_map.insert(
|
||||
"description".to_string(),
|
||||
serde_json::json!(entity.description),
|
||||
);
|
||||
data_map.insert("aliases".to_string(), serde_json::json!(entity.aliases));
|
||||
data_map.insert(
|
||||
"url".to_string(),
|
||||
serde_json::json!(format!(
|
||||
"https://www.wikidata.org/wiki/{}",
|
||||
entity.qid
|
||||
)),
|
||||
);
|
||||
|
||||
// Add claims as structured data
|
||||
let claims_json: serde_json::Value = serde_json::to_value(&entity.claims)?;
|
||||
data_map.insert("claims".to_string(), claims_json);
|
||||
|
||||
Ok(DataRecord {
|
||||
id: format!("wikidata_{}", entity.qid),
|
||||
source: "wikidata".to_string(),
|
||||
record_type: "entity".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
data: serde_json::Value::Object(data_map),
|
||||
embedding: Some(embedding),
|
||||
relationships,
|
||||
})
|
||||
}
|
||||
|
||||
/// Fetch with retry logic
|
||||
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match self.client.get(url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
|
||||
{
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
continue;
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(_) if retries < MAX_RETRIES => {
|
||||
retries += 1;
|
||||
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
|
||||
}
|
||||
Err(e) => return Err(FrameworkError::Network(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WikidataClient {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("Failed to create WikidataClient")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DataSource for WikidataClient {
|
||||
fn source_id(&self) -> &str {
|
||||
"wikidata"
|
||||
}
|
||||
|
||||
async fn fetch_batch(
|
||||
&self,
|
||||
cursor: Option<String>,
|
||||
_batch_size: usize,
|
||||
) -> Result<(Vec<DataRecord>, Option<String>)> {
|
||||
// Use cursor to determine which query to run
|
||||
let records = match cursor.as_deref() {
|
||||
Some("climate") => self.query_climate_entities().await?,
|
||||
Some("pharma") => self.query_pharmaceutical_companies().await?,
|
||||
Some("disease") => self.query_disease_outbreaks().await?,
|
||||
_ => {
|
||||
// Default: search for "artificial intelligence"
|
||||
let entities = self.search_entities("artificial intelligence").await?;
|
||||
let mut records = Vec::new();
|
||||
for entity in entities.iter().take(20) {
|
||||
records.push(self.entity_to_record(entity)?);
|
||||
}
|
||||
records
|
||||
}
|
||||
};
|
||||
|
||||
Ok((records, None))
|
||||
}
|
||||
|
||||
async fn total_count(&self) -> Result<Option<u64>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<bool> {
|
||||
let response = self.client.get(&self.api_url).send().await?;
|
||||
Ok(response.status().is_success())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Example SPARQL Queries
|
||||
// ============================================================================
|
||||
|
||||
/// Pre-defined SPARQL query templates
///
/// Canned queries for the Wikidata SPARQL endpoint. Each constant is a
/// complete SELECT query capped at 100 rows that uses the wikibase label
/// service for English labels; results are consumed by
/// `WikidataClient::sparql_to_records`.
pub mod sparql_queries {
    /// Query for climate change related entities
    ///
    /// Matches items that are instances of, subclasses of, or have as
    /// their main subject climate change (Q125977).
    pub const CLIMATE_CHANGE: &str = r#"
    SELECT ?item ?itemLabel ?itemDescription WHERE {
      {
        ?item wdt:P31 wd:Q125977. # instance of climate change
      } UNION {
        ?item wdt:P279* wd:Q125977. # subclass of climate change
      } UNION {
        ?item wdt:P921 wd:Q125977. # main subject climate change
      }
      SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    LIMIT 100
    "#;

    /// Query for pharmaceutical companies
    ///
    /// Instances (or transitive subclasses) of pharmaceutical company
    /// (Q507443), with optional founding date (P571), employee count
    /// (P1128), and headquarters (P159), largest employers first.
    pub const PHARMACEUTICAL_COMPANIES: &str = r#"
    SELECT ?item ?itemLabel ?itemDescription ?founded ?employees ?headquarters ?headquartersLabel WHERE {
      ?item wdt:P31/wdt:P279* wd:Q507443. # pharmaceutical company
      OPTIONAL { ?item wdt:P571 ?founded. }
      OPTIONAL { ?item wdt:P1128 ?employees. }
      OPTIONAL { ?item wdt:P159 ?headquarters. }
      SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    ORDER BY DESC(?employees)
    LIMIT 100
    "#;

    /// Query for disease outbreaks
    ///
    /// Epidemics (Q3241045) with optional disease (P828), start/end time
    /// (P580/P582), location (P276), and death count (P1120), most
    /// recent first.
    pub const DISEASE_OUTBREAKS: &str = r#"
    SELECT ?item ?itemLabel ?itemDescription ?disease ?diseaseLabel ?startTime ?endTime ?location ?locationLabel ?deaths WHERE {
      ?item wdt:P31 wd:Q3241045. # epidemic
      OPTIONAL { ?item wdt:P828 ?disease. }
      OPTIONAL { ?item wdt:P580 ?startTime. }
      OPTIONAL { ?item wdt:P582 ?endTime. }
      OPTIONAL { ?item wdt:P276 ?location. }
      OPTIONAL { ?item wdt:P1120 ?deaths. }
      SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    ORDER BY DESC(?startTime)
    LIMIT 100
    "#;

    /// Query for scientific research institutions
    ///
    /// Instances (or transitive subclasses) of research institute
    /// (Q31855), with optional country (P17) and founding date (P571).
    pub const RESEARCH_INSTITUTIONS: &str = r#"
    SELECT ?item ?itemLabel ?itemDescription ?country ?countryLabel ?founded WHERE {
      ?item wdt:P31/wdt:P279* wd:Q31855. # research institute
      OPTIONAL { ?item wdt:P17 ?country. }
      OPTIONAL { ?item wdt:P571 ?founded. }
      SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    LIMIT 100
    "#;

    /// Query for Nobel Prize winners in specific field
    ///
    /// People holding an award (P166) that is a subclass of the Nobel
    /// Prize (Q7191).
    ///
    /// NOTE(review): the `?year` OPTIONAL reads P585 (point in time) off
    /// the award *item* rather than off the award statement's qualifiers,
    /// which likely never binds for most laureates — verify against the
    /// live endpoint and consider statement-qualifier syntax (pq:P585).
    pub const NOBEL_LAUREATES: &str = r#"
    SELECT ?item ?itemLabel ?itemDescription ?award ?awardLabel ?year ?field ?fieldLabel WHERE {
      ?item wdt:P166 ?award.
      ?award wdt:P279* wd:Q7191. # Nobel Prize
      OPTIONAL { ?item wdt:P166 ?award. ?award wdt:P585 ?year. }
      OPTIONAL { ?award wdt:P101 ?field. }
      SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    ORDER BY DESC(?year)
    LIMIT 100
    "#;
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Client constructors should succeed with default configuration.
    #[tokio::test]
    async fn test_wikipedia_client_creation() {
        assert!(WikipediaClient::new("en".to_string()).is_ok());
    }

    #[tokio::test]
    async fn test_wikidata_client_creation() {
        assert!(WikidataClient::new().is_ok());
    }

    // Round-trip an entity through JSON and confirm key fields survive.
    #[test]
    fn test_wikidata_entity_serialization() {
        let mut claims = HashMap::new();
        claims.insert("P31".to_string(), vec!["Q5".to_string()]);

        let original = WikidataEntity {
            qid: "Q42".to_string(),
            label: "Douglas Adams".to_string(),
            description: "English writer and humorist".to_string(),
            aliases: vec!["Douglas Noel Adams".to_string()],
            claims,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: WikidataEntity = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.qid, "Q42");
        assert_eq!(decoded.label, "Douglas Adams");
    }

    // Every bundled SPARQL template must be non-empty.
    #[test]
    fn test_sparql_query_templates() {
        let templates = [
            sparql_queries::CLIMATE_CHANGE,
            sparql_queries::PHARMACEUTICAL_COMPANIES,
            sparql_queries::DISEASE_OUTBREAKS,
        ];
        for template in templates {
            assert!(!template.is_empty());
        }
    }
}
|
||||
Reference in New Issue
Block a user