Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,571 @@
//! ArXiv Preprint API Integration
//!
//! This module provides an async client for fetching academic preprints from ArXiv.org,
//! converting responses to SemanticVector format for RuVector discovery.
//!
//! # ArXiv API Details
//! - Base URL: https://export.arxiv.org/api/query
//! - Free access, no authentication required
//! - Returns Atom XML feed
//! - Rate limit: 1 request per 3 seconds (enforced by client)
//!
//! # Example
//! ```rust,ignore
//! use ruvector_data_framework::arxiv_client::ArxivClient;
//!
//! let client = ArxivClient::new();
//!
//! // Search papers by keywords
//! let vectors = client.search("machine learning", 10).await?;
//!
//! // Search by category
//! let ai_papers = client.search_category("cs.AI", 20).await?;
//!
//! // Get recent papers in a category
//! let recent = client.search_recent("cs.LG", 7).await?;
//! ```
use std::collections::HashMap;
use std::time::Duration;
use chrono::{DateTime, NaiveDateTime, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration for ArXiv API
const ARXIV_RATE_LIMIT_MS: u64 = 3000; // 3 seconds between requests
// Maximum retry attempts for transient failures (HTTP 429, network errors).
const MAX_RETRIES: u32 = 3;
// Base retry delay; multiplied by the attempt number, giving linear backoff.
const RETRY_DELAY_MS: u64 = 2000;
// Default dimension for `SimpleEmbedder` text embeddings.
const DEFAULT_EMBEDDING_DIM: usize = 384;
// ============================================================================
// ArXiv Atom Feed Structures
// ============================================================================
/// ArXiv API Atom feed response
///
/// Deserialized from the Atom XML returned by the query endpoint; only the
/// elements this module consumes are modeled.
#[derive(Debug, Deserialize)]
struct ArxivFeed {
    // One `<entry>` element per paper; empty when the query matched nothing.
    #[serde(rename = "entry", default)]
    entries: Vec<ArxivEntry>,
    // OpenSearch `totalResults` element.
    // NOTE(review): not read anywhere in the visible code — confirm it is needed.
    #[serde(rename = "totalResults", default)]
    total_results: Option<TotalResults>,
}
/// OpenSearch `totalResults` element from the Atom feed.
#[derive(Debug, Deserialize)]
struct TotalResults {
    // Element text content (quick-xml maps text nodes to `$value`).
    #[serde(rename = "$value", default)]
    value: Option<String>,
}
/// ArXiv entry (paper)
///
/// One `<entry>` element from the Atom feed. Fields whose Rust name matches
/// the XML element name deserialize without an explicit `rename`; only the
/// plural collections need one.
#[derive(Debug, Deserialize)]
struct ArxivEntry {
    /// Full ArXiv URL id, e.g. `http://arxiv.org/abs/2401.12345v1`.
    id: String,
    /// Paper title (may contain embedded newlines).
    title: String,
    /// Abstract text (may contain embedded newlines).
    summary: String,
    /// Publication timestamp in RFC 3339 / ISO 8601 format.
    published: String,
    /// Last-updated timestamp, when present.
    #[serde(default)]
    updated: Option<String>,
    /// `<author>` elements.
    #[serde(rename = "author", default)]
    authors: Vec<ArxivAuthor>,
    /// `<category>` elements.
    #[serde(rename = "category", default)]
    categories: Vec<ArxivCategory>,
    /// `<link>` elements (abstract page, PDF, DOI, ...).
    #[serde(rename = "link", default)]
    links: Vec<ArxivLink>,
}
#[derive(Debug, Deserialize)]
struct ArxivAuthor {
#[serde(rename = "name")]
name: String,
}
/// ArXiv `<category>` element.
#[derive(Debug, Deserialize)]
struct ArxivCategory {
    // Category code from the `term` XML *attribute* (e.g. "cs.LG");
    // the `@` prefix is quick-xml's attribute syntax.
    #[serde(rename = "@term")]
    term: String,
}
/// ArXiv `<link>` element (abstract page, PDF, DOI, ...).
/// All fields are XML attributes (quick-xml `@` syntax).
#[derive(Debug, Deserialize)]
struct ArxivLink {
    // Target URL.
    #[serde(rename = "@href")]
    href: String,
    // MIME type of the target, when present (e.g. "application/pdf").
    #[serde(rename = "@type", default)]
    link_type: Option<String>,
    // Link label; the PDF link carries title="pdf".
    #[serde(rename = "@title", default)]
    title: Option<String>,
}
// ============================================================================
// ArXiv Client
// ============================================================================
/// Client for ArXiv.org preprint API
///
/// Provides methods to search for academic papers, filter by category,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces ArXiv's rate limit of 1 request per 3 seconds.
/// Includes retry logic for transient failures.
pub struct ArxivClient {
    // Shared HTTP client: 30 s timeout, custom user agent.
    client: Client,
    // Turns "title + abstract" text into a fixed-size embedding vector.
    embedder: SimpleEmbedder,
    // API base URL; kept as a field (rather than a const) so tests can assert on it.
    base_url: String,
}
impl ArxivClient {
/// Create a new ArXiv API client
///
/// # Example
/// ```rust,ignore
/// let client = ArxivClient::new();
/// ```
pub fn new() -> Self {
Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
}
/// Create a new ArXiv API client with custom embedding dimension
///
/// # Arguments
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
pub fn with_embedding_dim(embedding_dim: usize) -> Self {
Self {
client: Client::builder()
.user_agent("RuVector-Discovery/1.0")
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client"),
embedder: SimpleEmbedder::new(embedding_dim),
base_url: "https://export.arxiv.org/api/query".to_string(),
}
}
/// Search papers by keywords
///
/// # Arguments
/// * `query` - Search query (keywords, title, author, etc.)
/// * `max_results` - Maximum number of results to return
///
/// # Example
/// ```rust,ignore
/// let vectors = client.search("quantum computing", 50).await?;
/// ```
pub async fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticVector>> {
let encoded_query = urlencoding::encode(query);
let url = format!(
"{}?search_query=all:{}&start=0&max_results={}",
self.base_url, encoded_query, max_results
);
self.fetch_and_parse(&url).await
}
/// Search papers by ArXiv category
///
/// # Arguments
/// * `category` - ArXiv category code (e.g., "cs.AI", "physics.ao-ph", "q-fin.ST")
/// * `max_results` - Maximum number of results to return
///
/// # Supported Categories
/// - `cs.AI` - Artificial Intelligence
/// - `cs.LG` - Machine Learning
/// - `cs.CL` - Computation and Language
/// - `stat.ML` - Statistics - Machine Learning
/// - `q-fin.*` - Quantitative Finance (ST, PM, TR, etc.)
/// - `physics.ao-ph` - Atmospheric and Oceanic Physics
/// - `econ.*` - Economics
///
/// # Example
/// ```rust,ignore
/// let ai_papers = client.search_category("cs.AI", 100).await?;
/// let climate_papers = client.search_category("physics.ao-ph", 50).await?;
/// ```
pub async fn search_category(
&self,
category: &str,
max_results: usize,
) -> Result<Vec<SemanticVector>> {
let url = format!(
"{}?search_query=cat:{}&start=0&max_results={}&sortBy=submittedDate&sortOrder=descending",
self.base_url, category, max_results
);
self.fetch_and_parse(&url).await
}
/// Get a single paper by ArXiv ID
///
/// # Arguments
/// * `arxiv_id` - ArXiv paper ID (e.g., "2401.12345" or "arXiv:2401.12345")
///
/// # Example
/// ```rust,ignore
/// let paper = client.get_paper("2401.12345").await?;
/// ```
pub async fn get_paper(&self, arxiv_id: &str) -> Result<Option<SemanticVector>> {
// Strip "arXiv:" prefix if present
let id = arxiv_id.trim_start_matches("arXiv:");
let url = format!("{}?id_list={}", self.base_url, id);
let mut results = self.fetch_and_parse(&url).await?;
Ok(results.pop())
}
/// Search recent papers in a category within the last N days
///
/// # Arguments
/// * `category` - ArXiv category code
/// * `days` - Number of days to look back (default: 7)
///
/// # Example
/// ```rust,ignore
/// // Get ML papers from the last 3 days
/// let recent = client.search_recent("cs.LG", 3).await?;
/// ```
pub async fn search_recent(
&self,
category: &str,
days: u64,
) -> Result<Vec<SemanticVector>> {
let cutoff_date = Utc::now() - chrono::Duration::days(days as i64);
let url = format!(
"{}?search_query=cat:{}&start=0&max_results=100&sortBy=submittedDate&sortOrder=descending",
self.base_url, category
);
let all_results = self.fetch_and_parse(&url).await?;
// Filter by date
Ok(all_results
.into_iter()
.filter(|v| v.timestamp >= cutoff_date)
.collect())
}
/// Search papers across multiple categories
///
/// # Arguments
/// * `categories` - List of ArXiv category codes
/// * `max_results_per_category` - Maximum results per category
///
/// # Example
/// ```rust,ignore
/// let categories = vec!["cs.AI", "cs.LG", "stat.ML"];
/// let papers = client.search_multiple_categories(&categories, 20).await?;
/// ```
pub async fn search_multiple_categories(
&self,
categories: &[&str],
max_results_per_category: usize,
) -> Result<Vec<SemanticVector>> {
let mut all_vectors = Vec::new();
for category in categories {
match self.search_category(category, max_results_per_category).await {
Ok(mut vectors) => {
all_vectors.append(&mut vectors);
}
Err(e) => {
tracing::warn!("Failed to fetch category {}: {}", category, e);
}
}
// Rate limiting between categories
sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;
}
Ok(all_vectors)
}
/// Fetch and parse ArXiv Atom feed
async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
// Rate limiting
sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;
let response = self.fetch_with_retry(url).await?;
let xml = response.text().await?;
// Parse XML feed
let feed: ArxivFeed = quick_xml::de::from_str(&xml).map_err(|e| {
FrameworkError::Ingestion(format!("Failed to parse ArXiv XML: {}", e))
})?;
// Convert entries to SemanticVectors
let mut vectors = Vec::new();
for entry in feed.entries {
if let Some(vector) = self.entry_to_vector(entry) {
vectors.push(vector);
}
}
Ok(vectors)
}
/// Convert ArXiv entry to SemanticVector
fn entry_to_vector(&self, entry: ArxivEntry) -> Option<SemanticVector> {
// Extract ArXiv ID from full URL
let arxiv_id = entry
.id
.split('/')
.last()
.unwrap_or(&entry.id)
.to_string();
// Clean up title and abstract
let title = entry.title.trim().replace('\n', " ");
let abstract_text = entry.summary.trim().replace('\n', " ");
// Parse publication date
let timestamp = Self::parse_arxiv_date(&entry.published)?;
// Generate embedding from title + abstract
let combined_text = format!("{} {}", title, abstract_text);
let embedding = self.embedder.embed_text(&combined_text);
// Extract authors
let authors = entry
.authors
.iter()
.map(|a| a.name.clone())
.collect::<Vec<_>>()
.join(", ");
// Extract categories
let categories = entry
.categories
.iter()
.map(|c| c.term.clone())
.collect::<Vec<_>>()
.join(", ");
// Find PDF URL
let pdf_url = entry
.links
.iter()
.find(|l| l.title.as_deref() == Some("pdf"))
.map(|l| l.href.clone())
.unwrap_or_else(|| format!("https://arxiv.org/pdf/{}.pdf", arxiv_id));
// Build metadata
let mut metadata = HashMap::new();
metadata.insert("arxiv_id".to_string(), arxiv_id.clone());
metadata.insert("title".to_string(), title);
metadata.insert("abstract".to_string(), abstract_text);
metadata.insert("authors".to_string(), authors);
metadata.insert("categories".to_string(), categories);
metadata.insert("pdf_url".to_string(), pdf_url);
metadata.insert("source".to_string(), "arxiv".to_string());
Some(SemanticVector {
id: format!("arXiv:{}", arxiv_id),
embedding,
domain: Domain::Research,
timestamp,
metadata,
})
}
/// Parse ArXiv date format (ISO 8601)
fn parse_arxiv_date(date_str: &str) -> Option<DateTime<Utc>> {
// ArXiv uses ISO 8601 format: 2024-01-15T12:30:00Z
DateTime::parse_from_rfc3339(date_str)
.ok()
.map(|dt| dt.with_timezone(&Utc))
.or_else(|| {
// Fallback: try parsing without timezone
NaiveDateTime::parse_from_str(date_str, "%Y-%m-%dT%H:%M:%S")
.ok()
.map(|ndt| DateTime::from_naive_utc_and_offset(ndt, Utc))
})
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
{
retries += 1;
tracing::warn!("Rate limited by ArXiv, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
if !response.status().is_success() {
return Err(FrameworkError::Network(
reqwest::Error::from(response.error_for_status().unwrap_err()),
));
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
impl Default for ArxivClient {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    // Offline unit tests first; the `#[ignore]`d integration tests below hit
    // the live ArXiv API and are meant to be run explicitly.

    #[test]
    fn test_arxiv_client_creation() {
        let client = ArxivClient::new();
        assert_eq!(client.base_url, "https://export.arxiv.org/api/query");
    }

    #[test]
    fn test_custom_embedding_dim() {
        // The embedder must honor the requested dimension.
        let client = ArxivClient::with_embedding_dim(512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_parse_arxiv_date() {
        // Standard ISO 8601
        let date1 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00Z");
        assert!(date1.is_some());
        // Without Z suffix
        let date2 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00");
        assert!(date2.is_some());
    }

    #[test]
    fn test_entry_to_vector() {
        let client = ArxivClient::new();
        // Synthetic entry; `links` is empty so the PDF-URL fallback path runs.
        let entry = ArxivEntry {
            id: "http://arxiv.org/abs/2401.12345v1".to_string(),
            title: "Deep Learning for Climate Science".to_string(),
            summary: "We propose a novel approach...".to_string(),
            published: "2024-01-15T12:00:00Z".to_string(),
            updated: None,
            authors: vec![
                ArxivAuthor {
                    name: "John Doe".to_string(),
                },
                ArxivAuthor {
                    name: "Jane Smith".to_string(),
                },
            ],
            categories: vec![
                ArxivCategory {
                    term: "cs.LG".to_string(),
                },
                ArxivCategory {
                    term: "physics.ao-ph".to_string(),
                },
            ],
            links: vec![],
        };
        let vector = client.entry_to_vector(entry);
        assert!(vector.is_some());
        let v = vector.unwrap();
        // The short id is extracted from the URL and prefixed with "arXiv:".
        assert_eq!(v.id, "arXiv:2401.12345v1");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("arxiv_id").unwrap(), "2401.12345v1");
        assert_eq!(
            v.metadata.get("title").unwrap(),
            "Deep Learning for Climate Science"
        );
        assert_eq!(v.metadata.get("authors").unwrap(), "John Doe, Jane Smith");
        assert_eq!(v.metadata.get("categories").unwrap(), "cs.LG, physics.ao-ph");
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_integration() {
        let client = ArxivClient::new();
        let results = client.search("machine learning", 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("arXiv:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("abstract"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_category_integration() {
        let client = ArxivClient::new();
        let results = client.search_category("cs.AI", 3).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 3);
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_get_paper_integration() {
        let client = ArxivClient::new();
        // Try to fetch a known paper (this is a real arXiv ID)
        let result = client.get_paper("2301.00001").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_search_recent_integration() {
        let client = ArxivClient::new();
        let results = client.search_recent("cs.LG", 7).await;
        assert!(results.is_ok());
        // Check that returned papers are within date range
        let cutoff = Utc::now() - chrono::Duration::days(7);
        for vector in results.unwrap() {
            assert!(vector.timestamp >= cutoff);
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting ArXiv API in tests
    async fn test_multiple_categories_integration() {
        let client = ArxivClient::new();
        let categories = vec!["cs.AI", "cs.LG"];
        let results = client.search_multiple_categories(&categories, 2).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 4); // 2 categories * 2 results each
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
//! MCP Discovery Server Binary
//!
//! Runs the RuVector MCP server for data discovery across 22+ data sources.
//!
//! # Usage
//!
//! ## STDIO Mode (default)
//! ```bash
//! cargo run --bin mcp_discovery
//! ```
//!
//! ## SSE Mode (HTTP streaming)
//! ```bash
//! cargo run --bin mcp_discovery -- --sse --port 3000
//! ```
//!
//! ## With custom configuration
//! ```bash
//! cargo run --bin mcp_discovery -- --config custom_config.json
//! ```
use std::process;
use clap::Parser;
use tracing_subscriber::{fmt, EnvFilter};
use ruvector_data_framework::mcp_server::{
McpDiscoveryServer, McpServerConfig, McpTransport,
};
use ruvector_data_framework::ruvector_native::NativeEngineConfig;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
// Command-line arguments for the MCP discovery server.
// NOTE: the `///` doc comments below double as `--help` text via clap's
// derive macro — they are user-visible program output, edit with care.
struct Args {
    /// Use SSE transport (HTTP streaming) instead of STDIO
    #[arg(long, default_value_t = false)]
    sse: bool,
    /// Port for SSE endpoint (only used with --sse)
    #[arg(long, default_value_t = 3000)]
    port: u16,
    /// Endpoint address for SSE (only used with --sse)
    #[arg(long, default_value = "127.0.0.1")]
    endpoint: String,
    /// Configuration file path
    #[arg(short, long)]
    config: Option<String>,
    // The flags below are ignored when --config is given (the file wins).
    /// Minimum edge weight threshold
    #[arg(long, default_value_t = 0.5)]
    min_edge_weight: f64,
    /// Vector similarity threshold
    #[arg(long, default_value_t = 0.7)]
    similarity_threshold: f64,
    /// Enable cross-domain discovery
    #[arg(long, default_value_t = true)]
    cross_domain: bool,
    /// Temporal window size in seconds
    #[arg(long, default_value_t = 3600)]
    window_seconds: i64,
    /// HNSW M parameter (connections per layer)
    #[arg(long, default_value_t = 16)]
    hnsw_m: usize,
    /// HNSW ef_construction parameter
    #[arg(long, default_value_t = 200)]
    hnsw_ef_construction: usize,
    /// Vector dimension
    #[arg(long, default_value_t = 384)]
    dimension: usize,
    /// Enable verbose logging
    #[arg(short, long, default_value_t = false)]
    verbose: bool,
}
#[tokio::main]
async fn main() {
    let args = Args::parse();

    // Logging: --verbose forces debug-level output; otherwise honor RUST_LOG
    // with an "info" fallback.
    let filter = if args.verbose {
        EnvFilter::new("debug")
    } else {
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"))
    };
    fmt()
        .with_env_filter(filter)
        .with_target(false)
        .with_thread_ids(false)
        .with_file(false)
        .init();

    // Engine configuration: a JSON config file, when supplied, takes
    // precedence over the individual tuning flags.
    let engine_config = match args.config.as_deref() {
        Some(path) => load_config_from_file(path).unwrap_or_else(|e| {
            eprintln!("Failed to load config from {}: {}", path, e);
            process::exit(1);
        }),
        None => NativeEngineConfig {
            min_edge_weight: args.min_edge_weight,
            similarity_threshold: args.similarity_threshold,
            mincut_sensitivity: 0.1,
            cross_domain: args.cross_domain,
            window_seconds: args.window_seconds,
            hnsw_m: args.hnsw_m,
            hnsw_ef_construction: args.hnsw_ef_construction,
            hnsw_ef_search: 50,
            dimension: args.dimension,
            batch_size: 1000,
            checkpoint_interval: 10_000,
            parallel_workers: num_cpus::get(),
        },
    };

    // Transport selection: SSE binds an HTTP endpoint; STDIO is the default.
    // Startup notices go to stderr so STDIO-mode protocol traffic on stdout
    // stays clean.
    let transport = if args.sse {
        eprintln!("Starting MCP server in SSE mode on {}:{}", args.endpoint, args.port);
        McpTransport::Sse {
            endpoint: args.endpoint,
            port: args.port,
        }
    } else {
        eprintln!("Starting MCP server in STDIO mode");
        McpTransport::Stdio
    };

    // Run the server; a fatal server error maps to a non-zero exit code.
    let mut server = McpDiscoveryServer::new(transport, engine_config);
    if let Err(e) = server.run().await {
        eprintln!("Server error: {}", e);
        process::exit(1);
    }
}
/// Load a `NativeEngineConfig` from a JSON file at `path`.
///
/// # Errors
/// Propagates any I/O error from reading the file and any `serde_json`
/// error from parsing its contents.
fn load_config_from_file(path: &str) -> Result<NativeEngineConfig, Box<dyn std::error::Error>> {
    let raw = std::fs::read_to_string(path)?;
    let parsed = serde_json::from_str::<NativeEngineConfig>(&raw)?;
    Ok(parsed)
}

View File

@@ -0,0 +1,930 @@
//! bioRxiv and medRxiv Preprint API Integration
//!
//! This module provides async clients for fetching preprints from bioRxiv.org and medRxiv.org,
//! converting responses to SemanticVector format for RuVector discovery.
//!
//! # bioRxiv/medRxiv API Details
//! - Base URL: https://api.biorxiv.org/details/[server]/[interval]/[cursor]
//! - Free access, no authentication required
//! - Returns JSON with preprint metadata
//! - Rate limit: ~1 request per second (enforced by client)
//!
//! # Example
//! ```rust,ignore
//! use ruvector_data_framework::biorxiv_client::{BiorxivClient, MedrxivClient};
//!
//! // Life sciences preprints
//! let biorxiv = BiorxivClient::new();
//! let recent = biorxiv.search_recent(7, 50).await?;
//! let category_papers = biorxiv.search_by_category("neuroscience", 100).await?;
//!
//! // Medical preprints
//! let medrxiv = MedrxivClient::new();
//! let covid_papers = medrxiv.search_covid(100).await?;
//! let clinical = medrxiv.search_clinical(50).await?;
//! ```
use std::collections::HashMap;
use std::time::Duration;
use chrono::{NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration
const BIORXIV_RATE_LIMIT_MS: u64 = 1000; // 1 second between requests (conservative)
// Maximum retry attempts for transient failures (HTTP 429, network errors).
const MAX_RETRIES: u32 = 3;
// Base retry delay; multiplied by the attempt number, giving linear backoff.
const RETRY_DELAY_MS: u64 = 2000;
// Default dimension for `SimpleEmbedder` text embeddings.
const DEFAULT_EMBEDDING_DIM: usize = 384;
// Page size of the bioRxiv details endpoint.
// NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
const DEFAULT_PAGE_SIZE: usize = 100;
// ============================================================================
// bioRxiv/medRxiv API Response Structures
// ============================================================================
/// API response from bioRxiv/medRxiv
///
/// NOTE(review): the live API reports `cursor`/`count` inside a nested
/// `messages` array rather than at the top level — confirm these optional
/// top-level fields actually populate, or pagination metadata will always
/// deserialize as `None`.
#[derive(Debug, Deserialize)]
struct BiorxivApiResponse {
    /// Total number of results
    #[serde(default)]
    count: Option<i64>,
    /// Cursor for pagination (total number of records seen)
    #[serde(default)]
    cursor: Option<i64>,
    /// Array of preprint records
    #[serde(default)]
    collection: Vec<PreprintRecord>,
}
/// Individual preprint record
///
/// One element of the `collection` array returned by the details endpoint.
#[derive(Debug, Deserialize)]
struct PreprintRecord {
    /// DOI identifier
    doi: String,
    /// Paper title
    title: String,
    /// Authors (semicolon-separated)
    authors: String,
    /// Author corresponding information
    #[serde(default)]
    author_corresponding: Option<String>,
    /// Author corresponding institution
    #[serde(default)]
    author_corresponding_institution: Option<String>,
    /// Preprint publication date (YYYY-MM-DD)
    date: String,
    /// Subject category
    category: String,
    /// Abstract text
    /// (renamed because `abstract` is a reserved Rust keyword)
    #[serde(rename = "abstract")]
    abstract_text: String,
    /// Journal publication status (if accepted)
    #[serde(default)]
    published: Option<String>,
    /// Server (biorxiv or medrxiv)
    #[serde(default)]
    server: Option<String>,
    /// Version number
    #[serde(default)]
    version: Option<String>,
    /// Type (e.g., "new results")
    /// (renamed because `type` is a reserved Rust keyword)
    #[serde(rename = "type", default)]
    preprint_type: Option<String>,
}
// ============================================================================
// bioRxiv Client (Life Sciences Preprints)
// ============================================================================
/// Client for bioRxiv.org preprint API
///
/// Provides methods to search for life sciences preprints, filter by category,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Categories
/// - neuroscience
/// - genomics
/// - bioinformatics
/// - cancer-biology
/// - immunology
/// - microbiology
/// - molecular-biology
/// - cell-biology
/// - biochemistry
/// - evolutionary-biology
/// - and many more...
///
/// # Rate Limiting
/// The client automatically enforces a rate limit of ~1 request per second.
/// Includes retry logic for transient failures.
pub struct BiorxivClient {
    // Shared HTTP client: 30 s timeout, custom user agent.
    client: Client,
    // Turns "title + abstract" text into a fixed-size embedding vector.
    embedder: SimpleEmbedder,
    // API base URL (the same host serves both bioRxiv and medRxiv).
    base_url: String,
}
impl BiorxivClient {
    /// Create a new bioRxiv API client
    ///
    /// Uses the default embedding dimension ([`DEFAULT_EMBEDDING_DIM`]).
    ///
    /// # Example
    /// ```rust,ignore
    /// let client = BiorxivClient::new();
    /// ```
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }

    /// Create a new bioRxiv API client with custom embedding dimension
    ///
    /// # Arguments
    /// * `embedding_dim` - Dimension for text embeddings (default: 384)
    ///
    /// # Panics
    /// Panics only if the HTTP client cannot be constructed — a builder
    /// configuration bug, not a runtime condition.
    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
        Self {
            client: Client::builder()
                .user_agent("RuVector-Discovery/1.0")
                .timeout(Duration::from_secs(30))
                .build()
                .expect("Failed to create HTTP client"),
            embedder: SimpleEmbedder::new(embedding_dim),
            base_url: "https://api.biorxiv.org".to_string(),
        }
    }

    /// Get recent preprints from the last N days
    ///
    /// # Arguments
    /// * `days` - Number of days to look back (e.g., 7 for last week)
    /// * `limit` - Maximum number of results to return
    ///
    /// # Example
    /// ```rust,ignore
    /// // Get preprints from the last 7 days
    /// let recent = client.search_recent(7, 100).await?;
    /// ```
    pub async fn search_recent(&self, days: u64, limit: usize) -> Result<Vec<SemanticVector>> {
        let end_date = Utc::now().date_naive();
        let start_date = end_date - chrono::Duration::days(days as i64);
        self.search_by_date_range(start_date, end_date, Some(limit)).await
    }

    /// Search preprints by date range
    ///
    /// # Arguments
    /// * `start_date` - Start date (inclusive)
    /// * `end_date` - End date (inclusive)
    /// * `limit` - Optional maximum number of results
    ///
    /// # Example
    /// ```rust,ignore
    /// use chrono::NaiveDate;
    ///
    /// let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    /// let end = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
    /// let papers = client.search_by_date_range(start, end, Some(200)).await?;
    /// ```
    pub async fn search_by_date_range(
        &self,
        start_date: NaiveDate,
        end_date: NaiveDate,
        limit: Option<usize>,
    ) -> Result<Vec<SemanticVector>> {
        // NaiveDate's Display is YYYY-MM-DD, which is what the API expects.
        let interval = format!("{}/{}", start_date, end_date);
        self.fetch_with_pagination("biorxiv", &interval, limit).await
    }

    /// Search preprints by subject category
    ///
    /// The API has no server-side category filter, so this fetches the last
    /// 365 days (over-fetching by 2x) and filters client-side with a
    /// case-insensitive substring match.
    ///
    /// # Arguments
    /// * `category` - Subject category (e.g., "neuroscience", "genomics")
    /// * `limit` - Maximum number of results to return
    ///
    /// # Categories
    /// - neuroscience
    /// - genomics
    /// - bioinformatics
    /// - cancer-biology
    /// - immunology
    /// - microbiology
    /// - molecular-biology
    /// - cell-biology
    /// - biochemistry
    /// - evolutionary-biology
    /// - ecology
    /// - genetics
    /// - developmental-biology
    /// - synthetic-biology
    /// - systems-biology
    ///
    /// # Example
    /// ```rust,ignore
    /// let neuroscience_papers = client.search_by_category("neuroscience", 100).await?;
    /// ```
    pub async fn search_by_category(
        &self,
        category: &str,
        limit: usize,
    ) -> Result<Vec<SemanticVector>> {
        // Get recent papers (last 365 days) and filter by category
        let end_date = Utc::now().date_naive();
        let start_date = end_date - chrono::Duration::days(365);
        let all_papers = self.search_by_date_range(start_date, end_date, Some(limit * 2)).await?;
        // Filter by category
        Ok(all_papers
            .into_iter()
            .filter(|v| {
                v.metadata
                    .get("category")
                    .map(|cat| cat.to_lowercase().contains(&category.to_lowercase()))
                    .unwrap_or(false)
            })
            .take(limit)
            .collect())
    }

    /// Fetch preprints with pagination support
    ///
    /// Walks the `details/{server}/{interval}/{cursor}` endpoint page by page
    /// until `limit` vectors are collected, an empty page is returned, or the
    /// safety ceiling is reached.
    async fn fetch_with_pagination(
        &self,
        server: &str,
        interval: &str,
        limit: Option<usize>,
    ) -> Result<Vec<SemanticVector>> {
        let mut all_vectors = Vec::new();
        let mut cursor = 0usize;
        let limit = limit.unwrap_or(usize::MAX);
        loop {
            if all_vectors.len() >= limit {
                break;
            }
            let url = format!("{}/details/{}/{}/{}", self.base_url, server, interval, cursor);
            // Rate limiting: wait before every request, including the first.
            sleep(Duration::from_millis(BIORXIV_RATE_LIMIT_MS)).await;
            let response = self.fetch_with_retry(&url).await?;
            let api_response: BiorxivApiResponse = response.json().await?;
            let page_len = api_response.collection.len();
            if page_len == 0 {
                break;
            }
            // Convert records to vectors
            for record in api_response.collection {
                if all_vectors.len() >= limit {
                    break;
                }
                if let Some(vector) = self.record_to_vector(record, server) {
                    all_vectors.push(vector);
                }
            }
            // Advance the cursor. The API echoes back the cursor that was
            // *requested* (and may not report one at all), so a reported
            // cursor that has not moved forward must not be treated as
            // "no more pages" — that would stop pagination after the first
            // page. Fall back to advancing by the number of records received.
            cursor = match api_response.cursor {
                Some(c) if c as usize > cursor => c as usize,
                _ => cursor + page_len,
            };
            // Safety check: don't paginate indefinitely
            if cursor > 10000 {
                tracing::warn!("Pagination cursor exceeded 10000, stopping");
                break;
            }
        }
        Ok(all_vectors)
    }

    /// Convert preprint record to SemanticVector
    ///
    /// A malformed `date` falls back to the current time rather than
    /// dropping the record.
    fn record_to_vector(&self, record: PreprintRecord, server: &str) -> Option<SemanticVector> {
        // Clean up title and abstract: collapse embedded newlines.
        let title = record.title.trim().replace('\n', " ");
        let abstract_text = record.abstract_text.trim().replace('\n', " ");
        // Parse publication date (midnight UTC); default to "now" on failure.
        let timestamp = NaiveDate::parse_from_str(&record.date, "%Y-%m-%d")
            .ok()
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);
        // Generate embedding from title + abstract
        let combined_text = format!("{} {}", title, abstract_text);
        let embedding = self.embedder.embed_text(&combined_text);
        // Determine publication status
        let published_status = record.published.unwrap_or_else(|| "preprint".to_string());
        // Build metadata; optional record fields are inserted only when present.
        let mut metadata = HashMap::new();
        metadata.insert("doi".to_string(), record.doi.clone());
        metadata.insert("title".to_string(), title);
        metadata.insert("abstract".to_string(), abstract_text);
        metadata.insert("authors".to_string(), record.authors);
        metadata.insert("category".to_string(), record.category);
        metadata.insert("server".to_string(), server.to_string());
        metadata.insert("published_status".to_string(), published_status);
        if let Some(corr) = record.author_corresponding {
            metadata.insert("corresponding_author".to_string(), corr);
        }
        if let Some(inst) = record.author_corresponding_institution {
            metadata.insert("institution".to_string(), inst);
        }
        if let Some(version) = record.version {
            metadata.insert("version".to_string(), version);
        }
        if let Some(ptype) = record.preprint_type {
            metadata.insert("type".to_string(), ptype);
        }
        metadata.insert("source".to_string(), "biorxiv".to_string());
        // bioRxiv papers are research domain
        Some(SemanticVector {
            id: format!("doi:{}", record.doi),
            embedding,
            domain: Domain::Research,
            timestamp,
            metadata,
        })
    }

    /// Fetch with retry logic
    ///
    /// Retries up to [`MAX_RETRIES`] times with linear backoff on HTTP 429
    /// and on transport-level errors; any other non-success status is
    /// returned immediately as an error.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        // `error_for_status` is guaranteed `Err` here, so
                        // `unwrap_err` cannot panic; the value is already a
                        // `reqwest::Error` — no re-wrapping needed.
                        return Err(FrameworkError::Network(
                            response.error_for_status().unwrap_err(),
                        ));
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
impl Default for BiorxivClient {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// medRxiv Client (Medical Preprints)
// ============================================================================
/// Client for medRxiv.org preprint API
///
/// Provides methods to search for medical and health sciences preprints,
/// filter by specialty, and convert results to SemanticVector format.
///
/// # Categories
/// - Cardiovascular Medicine
/// - Infectious Diseases
/// - Oncology
/// - Public Health
/// - Epidemiology
/// - Psychiatry
/// - and many more...
///
/// # Rate Limiting
/// The client automatically enforces a rate limit of ~1 request per second.
/// Includes retry logic for transient failures.
pub struct MedrxivClient {
    // Shared HTTP client: 30 s timeout, custom user agent.
    client: Client,
    // Turns "title + abstract" text into a fixed-size embedding vector.
    embedder: SimpleEmbedder,
    // API base URL — the bioRxiv host serves medRxiv data under the
    // "medrxiv" server path.
    base_url: String,
}
impl MedrxivClient {
    /// Create a new medRxiv API client with the default embedding dimension (384).
    ///
    /// # Example
    /// ```rust,ignore
    /// let client = MedrxivClient::new();
    /// ```
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }

    /// Create a new medRxiv API client with custom embedding dimension
    ///
    /// # Arguments
    /// * `embedding_dim` - Dimension for text embeddings (default: 384)
    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
        Self {
            client: Client::builder()
                .user_agent("RuVector-Discovery/1.0")
                .timeout(Duration::from_secs(30))
                .build()
                .expect("Failed to create HTTP client"),
            embedder: SimpleEmbedder::new(embedding_dim),
            // medRxiv is served from the shared bioRxiv API host; the server
            // name ("medrxiv") is selected per request in the URL path.
            base_url: "https://api.biorxiv.org".to_string(),
        }
    }

    /// Get recent preprints from the last N days
    ///
    /// # Arguments
    /// * `days` - Number of days to look back (e.g., 7 for last week)
    /// * `limit` - Maximum number of results to return
    ///
    /// # Example
    /// ```rust,ignore
    /// // Get medical preprints from the last 7 days
    /// let recent = client.search_recent(7, 100).await?;
    /// ```
    pub async fn search_recent(&self, days: u64, limit: usize) -> Result<Vec<SemanticVector>> {
        let end_date = Utc::now().date_naive();
        let start_date = end_date - chrono::Duration::days(days as i64);
        self.search_by_date_range(start_date, end_date, Some(limit)).await
    }

    /// Search preprints by date range
    ///
    /// # Arguments
    /// * `start_date` - Start date (inclusive)
    /// * `end_date` - End date (inclusive)
    /// * `limit` - Optional maximum number of results
    ///
    /// # Example
    /// ```rust,ignore
    /// use chrono::NaiveDate;
    ///
    /// let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    /// let end = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
    /// let papers = client.search_by_date_range(start, end, Some(200)).await?;
    /// ```
    pub async fn search_by_date_range(
        &self,
        start_date: NaiveDate,
        end_date: NaiveDate,
        limit: Option<usize>,
    ) -> Result<Vec<SemanticVector>> {
        // The API encodes the date range as "YYYY-MM-DD/YYYY-MM-DD" in the path.
        let interval = format!("{}/{}", start_date, end_date);
        self.fetch_with_pagination("medrxiv", &interval, limit).await
    }

    /// Search COVID-19 related preprints (from 2020-01-01 onwards)
    ///
    /// # Arguments
    /// * `limit` - Maximum number of results to return
    ///
    /// # Example
    /// ```rust,ignore
    /// let covid_papers = client.search_covid(100).await?;
    /// ```
    pub async fn search_covid(&self, limit: usize) -> Result<Vec<SemanticVector>> {
        let end_date = Utc::now().date_naive();
        let start_date = NaiveDate::from_ymd_opt(2020, 1, 1).expect("Valid date");
        // Over-fetch (2x, saturating so a huge `limit` cannot overflow) so
        // that enough records survive the keyword filter below.
        let all_papers = self
            .search_by_date_range(start_date, end_date, Some(limit.saturating_mul(2)))
            .await?;
        Ok(Self::filter_by_keywords(
            all_papers,
            &["covid", "sars-cov-2", "coronavirus", "pandemic"],
            limit,
        ))
    }

    /// Search clinical research preprints from roughly the last year
    ///
    /// # Arguments
    /// * `limit` - Maximum number of results to return
    ///
    /// # Example
    /// ```rust,ignore
    /// let clinical_papers = client.search_clinical(50).await?;
    /// ```
    pub async fn search_clinical(&self, limit: usize) -> Result<Vec<SemanticVector>> {
        let end_date = Utc::now().date_naive();
        let start_date = end_date - chrono::Duration::days(365);
        // Over-fetch (2x, saturating) so enough records survive the filter.
        let all_papers = self
            .search_by_date_range(start_date, end_date, Some(limit.saturating_mul(2)))
            .await?;
        Ok(Self::filter_by_keywords(
            all_papers,
            &["clinical", "trial", "patient", "treatment", "therapy", "diagnosis"],
            limit,
        ))
    }

    /// Keep at most `limit` vectors whose title, abstract, or category
    /// (case-insensitively) contains any of the given lowercase keywords.
    ///
    /// Shared by [`Self::search_covid`] and [`Self::search_clinical`].
    fn filter_by_keywords(
        vectors: Vec<SemanticVector>,
        keywords: &[&str],
        limit: usize,
    ) -> Vec<SemanticVector> {
        vectors
            .into_iter()
            .filter(|v| {
                let field = |key: &str| {
                    v.metadata.get(key).map(|s| s.to_lowercase()).unwrap_or_default()
                };
                let title = field("title");
                let abstract_text = field("abstract");
                let category = field("category");
                keywords.iter().any(|kw| {
                    title.contains(kw) || abstract_text.contains(kw) || category.contains(kw)
                })
            })
            .take(limit)
            .collect()
    }

    /// Fetch preprints with cursor-based pagination, enforcing the shared
    /// bioRxiv/medRxiv rate limit between page requests.
    async fn fetch_with_pagination(
        &self,
        server: &str,
        interval: &str,
        limit: Option<usize>,
    ) -> Result<Vec<SemanticVector>> {
        let mut all_vectors = Vec::new();
        let mut cursor = 0;
        let limit = limit.unwrap_or(usize::MAX);
        loop {
            if all_vectors.len() >= limit {
                break;
            }
            let url = format!("{}/details/{}/{}/{}", self.base_url, server, interval, cursor);
            // Rate limiting (shared API host budget).
            sleep(Duration::from_millis(BIORXIV_RATE_LIMIT_MS)).await;
            let response = self.fetch_with_retry(&url).await?;
            let api_response: BiorxivApiResponse = response.json().await?;
            if api_response.collection.is_empty() {
                break;
            }
            // Convert records to vectors, stopping once the limit is reached.
            for record in api_response.collection {
                if all_vectors.len() >= limit {
                    break;
                }
                if let Some(vector) = self.record_to_vector(record, server) {
                    all_vectors.push(vector);
                }
            }
            // Advance the cursor; a non-increasing cursor means no more pages.
            if let Some(new_cursor) = api_response.cursor {
                if new_cursor as usize <= cursor {
                    break;
                }
                cursor = new_cursor as usize;
            } else {
                break;
            }
            // Safety valve: never paginate indefinitely.
            if cursor > 10000 {
                tracing::warn!("Pagination cursor exceeded 10000, stopping");
                break;
            }
        }
        Ok(all_vectors)
    }

    /// Convert a preprint record to a [`SemanticVector`] in the Medical domain.
    ///
    /// Embeds `title + abstract`, parses the publication date (falling back to
    /// "now" when unparseable), and carries the bibliographic fields in metadata.
    fn record_to_vector(&self, record: PreprintRecord, server: &str) -> Option<SemanticVector> {
        // Clean up title and abstract (collapse embedded newlines).
        let title = record.title.trim().replace('\n', " ");
        let abstract_text = record.abstract_text.trim().replace('\n', " ");
        // Parse the "YYYY-MM-DD" publication date at midnight UTC.
        let timestamp = NaiveDate::parse_from_str(&record.date, "%Y-%m-%d")
            .ok()
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);
        // Generate embedding from title + abstract.
        let combined_text = format!("{} {}", title, abstract_text);
        let embedding = self.embedder.embed_text(&combined_text);
        // Publication status: journal name once published, else "preprint".
        let published_status = record.published.unwrap_or_else(|| "preprint".to_string());
        // Build metadata.
        let mut metadata = HashMap::new();
        metadata.insert("doi".to_string(), record.doi.clone());
        metadata.insert("title".to_string(), title);
        metadata.insert("abstract".to_string(), abstract_text);
        metadata.insert("authors".to_string(), record.authors);
        metadata.insert("category".to_string(), record.category);
        metadata.insert("server".to_string(), server.to_string());
        metadata.insert("published_status".to_string(), published_status);
        if let Some(corr) = record.author_corresponding {
            metadata.insert("corresponding_author".to_string(), corr);
        }
        if let Some(inst) = record.author_corresponding_institution {
            metadata.insert("institution".to_string(), inst);
        }
        if let Some(version) = record.version {
            metadata.insert("version".to_string(), version);
        }
        if let Some(ptype) = record.preprint_type {
            metadata.insert("type".to_string(), ptype);
        }
        metadata.insert("source".to_string(), "medrxiv".to_string());
        // medRxiv papers are medical domain.
        Some(SemanticVector {
            id: format!("doi:{}", record.doi),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }

    /// Fetch a URL with retry logic for HTTP 429 and transient network errors.
    ///
    /// Retries up to `MAX_RETRIES` times with linear backoff
    /// (`RETRY_DELAY_MS * attempt`).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        // error_for_status() on a non-success status is always
                        // Err, so unwrap_err() cannot panic here. The error is
                        // already a reqwest::Error; no extra conversion needed.
                        return Err(FrameworkError::Network(
                            response.error_for_status().unwrap_err(),
                        ));
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
impl Default for MedrxivClient {
    /// Equivalent to [`MedrxivClient::new`] (default embedding dimension).
    fn default() -> Self {
        Self::new()
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_biorxiv_client_creation() {
        // Both clients share the same API host; server is chosen per request.
        let client = BiorxivClient::new();
        assert_eq!(client.base_url, "https://api.biorxiv.org");
    }

    #[test]
    fn test_medrxiv_client_creation() {
        let client = MedrxivClient::new();
        assert_eq!(client.base_url, "https://api.biorxiv.org");
    }

    #[test]
    fn test_custom_embedding_dim() {
        // with_embedding_dim controls the length of the embedder's output.
        let client = BiorxivClient::with_embedding_dim(512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_record_to_vector_biorxiv() {
        // bioRxiv records map to Domain::Research; unpublished records get
        // published_status = "preprint".
        let client = BiorxivClient::new();
        let record = PreprintRecord {
            doi: "10.1101/2024.01.01.123456".to_string(),
            title: "Deep Learning for Neuroscience".to_string(),
            authors: "John Doe; Jane Smith".to_string(),
            author_corresponding: Some("John Doe".to_string()),
            author_corresponding_institution: Some("MIT".to_string()),
            date: "2024-01-15".to_string(),
            category: "Neuroscience".to_string(),
            abstract_text: "We propose a novel approach for analyzing neural data...".to_string(),
            published: None,
            server: Some("biorxiv".to_string()),
            version: Some("1".to_string()),
            preprint_type: Some("new results".to_string()),
        };
        let vector = client.record_to_vector(record, "biorxiv");
        assert!(vector.is_some());
        let v = vector.unwrap();
        assert_eq!(v.id, "doi:10.1101/2024.01.01.123456");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("doi").unwrap(), "10.1101/2024.01.01.123456");
        assert_eq!(v.metadata.get("title").unwrap(), "Deep Learning for Neuroscience");
        assert_eq!(v.metadata.get("authors").unwrap(), "John Doe; Jane Smith");
        assert_eq!(v.metadata.get("category").unwrap(), "Neuroscience");
        assert_eq!(v.metadata.get("server").unwrap(), "biorxiv");
        assert_eq!(v.metadata.get("published_status").unwrap(), "preprint");
    }

    #[test]
    fn test_record_to_vector_medrxiv() {
        // medRxiv records map to Domain::Medical; a published record carries
        // the journal name as its published_status.
        let client = MedrxivClient::new();
        let record = PreprintRecord {
            doi: "10.1101/2024.01.01.654321".to_string(),
            title: "COVID-19 Vaccine Efficacy Study".to_string(),
            authors: "Alice Johnson; Bob Williams".to_string(),
            author_corresponding: Some("Alice Johnson".to_string()),
            author_corresponding_institution: Some("Harvard Medical School".to_string()),
            date: "2024-03-20".to_string(),
            category: "Infectious Diseases".to_string(),
            abstract_text: "This study evaluates the efficacy of mRNA vaccines...".to_string(),
            published: Some("Nature Medicine".to_string()),
            server: Some("medrxiv".to_string()),
            version: Some("2".to_string()),
            preprint_type: Some("new results".to_string()),
        };
        let vector = client.record_to_vector(record, "medrxiv");
        assert!(vector.is_some());
        let v = vector.unwrap();
        assert_eq!(v.id, "doi:10.1101/2024.01.01.654321");
        assert_eq!(v.domain, Domain::Medical);
        assert_eq!(v.metadata.get("doi").unwrap(), "10.1101/2024.01.01.654321");
        assert_eq!(v.metadata.get("title").unwrap(), "COVID-19 Vaccine Efficacy Study");
        assert_eq!(v.metadata.get("category").unwrap(), "Infectious Diseases");
        assert_eq!(v.metadata.get("server").unwrap(), "medrxiv");
        assert_eq!(v.metadata.get("published_status").unwrap(), "Nature Medicine");
    }

    #[test]
    fn test_date_parsing() {
        // "YYYY-MM-DD" dates are parsed to midnight UTC on that day.
        let client = BiorxivClient::new();
        let record = PreprintRecord {
            doi: "10.1101/test".to_string(),
            title: "Test".to_string(),
            authors: "Author".to_string(),
            author_corresponding: None,
            author_corresponding_institution: None,
            date: "2024-01-15".to_string(),
            category: "Test".to_string(),
            abstract_text: "Abstract".to_string(),
            published: None,
            server: None,
            version: None,
            preprint_type: None,
        };
        let vector = client.record_to_vector(record, "biorxiv").unwrap();
        // Check that date was parsed correctly
        let expected_date = NaiveDate::from_ymd_opt(2024, 1, 15)
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .and_utc();
        assert_eq!(vector.timestamp, expected_date);
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting bioRxiv API in tests
    async fn test_search_recent_integration() {
        let client = BiorxivClient::new();
        let results = client.search_recent(7, 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("abstract"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting medRxiv API in tests
    async fn test_medrxiv_search_recent_integration() {
        let client = MedrxivClient::new();
        let results = client.search_recent(7, 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Medical);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("server"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting API
    async fn test_search_covid_integration() {
        let client = MedrxivClient::new();
        let results = client.search_covid(10).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        // Verify that results contain COVID-related keywords
        // (title/abstract only; the filter also matches category/"pandemic").
        for v in &vectors {
            let title = v.metadata.get("title").unwrap().to_lowercase();
            let abstract_text = v.metadata.get("abstract").unwrap().to_lowercase();
            let has_covid_keyword = title.contains("covid")
                || title.contains("sars-cov-2")
                || abstract_text.contains("covid")
                || abstract_text.contains("sars-cov-2");
            assert!(has_covid_keyword, "Expected COVID-related keywords in results");
        }
    }

    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting API
    async fn test_search_by_category_integration() {
        let client = BiorxivClient::new();
        let results = client.search_by_category("neuroscience", 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        // Verify category filtering
        for v in &vectors {
            let category = v.metadata.get("category").unwrap().to_lowercase();
            assert!(category.contains("neuroscience"));
        }
    }
}

View File

@@ -0,0 +1,658 @@
//! Coherence signal computation using dynamic minimum cut algorithms
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::hnsw::{HnswConfig, HnswIndex, DistanceMetric};
use crate::ruvector_native::{Domain, SemanticVector};
use crate::utils::cosine_similarity;
use crate::{DataRecord, FrameworkError, Result, Relationship, TemporalWindow};
/// Configuration for coherence engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceConfig {
    /// Minimum edge weight threshold; edges below this are dropped on insert
    pub min_edge_weight: f64,
    /// Window size for temporal analysis (seconds)
    pub window_size_secs: i64,
    /// Window slide step (seconds)
    pub window_step_secs: i64,
    /// Use approximate min-cut for speed
    pub approximate: bool,
    /// Approximation ratio (if approximate = true)
    pub epsilon: f64,
    /// Enable parallel computation
    pub parallel: bool,
    /// Track boundary evolution
    pub track_boundaries: bool,
    /// Similarity threshold for auto-connecting embeddings (0.0-1.0)
    pub similarity_threshold: f64,
    /// Use embeddings to create edges when relationships are empty
    pub use_embeddings: bool,
    /// Number of neighbors to search for each vector when using HNSW
    pub hnsw_k_neighbors: usize,
    /// Minimum records to trigger HNSW indexing (below this, use brute force)
    pub hnsw_min_records: usize,
}
impl Default for CoherenceConfig {
fn default() -> Self {
Self {
min_edge_weight: 0.01,
window_size_secs: 86400 * 7, // 1 week
window_step_secs: 86400, // 1 day
approximate: true,
epsilon: 0.1,
parallel: true,
track_boundaries: true,
similarity_threshold: 0.5,
use_embeddings: true,
hnsw_k_neighbors: 50,
hnsw_min_records: 100,
}
}
}
/// A coherence signal computed from graph structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceSignal {
    /// Signal identifier ("signal_N", where N is the signal's index)
    pub id: String,
    /// Temporal window this signal covers
    pub window: TemporalWindow,
    /// Minimum cut value (lower = less coherent)
    pub min_cut_value: f64,
    /// Number of nodes in graph
    pub node_count: usize,
    /// Number of edges in graph
    pub edge_count: usize,
    /// Partition sizes (if computed)
    pub partition_sizes: Option<(usize, usize)>,
    /// Is this an exact or approximate result
    pub is_exact: bool,
    /// Nodes in the cut (boundary nodes)
    pub cut_nodes: Vec<String>,
    /// Change from previous window's min-cut value (if available)
    pub delta: Option<f64>,
}
/// A coherence boundary event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceEvent {
    /// Event type (strengthened, weakened, split, ...)
    pub event_type: CoherenceEventType,
    /// Timestamp of event
    pub timestamp: DateTime<Utc>,
    /// Related nodes (typically the cut nodes of the triggering signal)
    pub nodes: Vec<String>,
    /// Magnitude of change (absolute delta)
    pub magnitude: f64,
    /// Additional context as free-form JSON values
    pub context: HashMap<String, serde_json::Value>,
}
/// Types of coherence events
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CoherenceEventType {
    /// Coherence increased (min-cut grew)
    Strengthened,
    /// Coherence decreased (min-cut shrunk)
    Weakened,
    /// New partition emerged (split)
    Split,
    /// Partitions merged
    Merged,
    /// Threshold crossed
    ThresholdCrossed,
    /// Anomalous pattern detected
    Anomaly,
}
/// A tracked coherence boundary
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceBoundary {
    /// Boundary identifier
    pub id: String,
    /// Nodes on one side
    pub side_a: Vec<String>,
    /// Nodes on other side
    pub side_b: Vec<String>,
    /// Current cut value at boundary
    pub cut_value: f64,
    /// Historical cut values (timestamped)
    pub history: Vec<(DateTime<Utc>, f64)>,
    /// First observed
    pub first_seen: DateTime<Utc>,
    /// Last updated
    pub last_updated: DateTime<Utc>,
    /// Is boundary stable or shifting
    pub stable: bool,
}
/// Coherence engine for computing signals from graph structure
pub struct CoherenceEngine {
    config: CoherenceConfig,
    // In-memory graph representation:
    // external id -> dense internal id
    nodes: HashMap<String, u64>,
    // dense internal id -> external id (inverse of `nodes`)
    node_ids: HashMap<u64, String>,
    // (source, target, weight) triples; duplicates are allowed
    edges: Vec<(u64, u64, f64)>,
    // next internal id to assign
    next_id: u64,
    // Computed signals, in computation order
    signals: Vec<CoherenceSignal>,
    // Tracked boundaries
    // NOTE(review): nothing in the visible module populates this — confirm.
    boundaries: Vec<CoherenceBoundary>,
}
impl CoherenceEngine {
    /// Create a new, empty coherence engine with the given configuration.
    pub fn new(config: CoherenceConfig) -> Self {
        Self {
            config,
            nodes: HashMap::new(),
            node_ids: HashMap::new(),
            edges: Vec::new(),
            next_id: 0,
            signals: Vec::new(),
            boundaries: Vec::new(),
        }
    }

    /// Add a node to the graph, returning its internal id.
    ///
    /// Idempotent: re-adding an existing external id returns the id assigned
    /// on first insertion.
    pub fn add_node(&mut self, id: &str) -> u64 {
        if let Some(&node_id) = self.nodes.get(id) {
            return node_id;
        }
        let node_id = self.next_id;
        self.next_id += 1;
        self.nodes.insert(id.to_string(), node_id);
        self.node_ids.insert(node_id, id.to_string());
        node_id
    }

    /// Add a weighted edge, creating endpoint nodes as needed.
    ///
    /// Edges below `config.min_edge_weight` are silently dropped.
    pub fn add_edge(&mut self, source: &str, target: &str, weight: f64) {
        if weight < self.config.min_edge_weight {
            return;
        }
        let source_id = self.add_node(source);
        let target_id = self.add_node(target);
        self.edges.push((source_id, target_id, weight));
    }

    /// Get node count
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get edge count
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Build graph from data records: explicit relationships first, then
    /// (if `config.use_embeddings`) edges inferred from embedding similarity.
    pub fn build_from_records(&mut self, records: &[DataRecord]) {
        // First pass: add all nodes and explicit relationships.
        for record in records {
            self.add_node(&record.id);
            for rel in &record.relationships {
                self.add_edge(&record.id, &rel.target_id, rel.weight);
            }
        }
        // Second pass: create edges based on embedding similarity.
        if self.config.use_embeddings {
            self.connect_by_embeddings(records);
        }
    }

    /// Connect records based on embedding similarity, choosing HNSW for large
    /// inputs (O(n * k * log n)) and brute force (O(n²)) for small ones.
    fn connect_by_embeddings(&mut self, records: &[DataRecord]) {
        let threshold = self.config.similarity_threshold;
        let min_weight = self.config.min_edge_weight;
        // Only records that actually carry an embedding participate.
        let embedded: Vec<_> = records.iter()
            .filter(|r| r.embedding.is_some())
            .collect();
        if embedded.len() < 2 {
            return;
        }
        if embedded.len() >= self.config.hnsw_min_records {
            self.connect_by_embeddings_hnsw(&embedded, threshold, min_weight);
        } else {
            self.connect_by_embeddings_bruteforce(&embedded, threshold, min_weight);
        }
    }

    /// HNSW-accelerated edge creation: O(n * k * log n).
    ///
    /// Builds a temporary HNSW index over all embeddings, then queries the
    /// k nearest neighbors of each record and adds an undirected edge per
    /// unique pair whose similarity clears the threshold.
    fn connect_by_embeddings_hnsw(&mut self, embedded: &[&DataRecord], threshold: f64, min_weight: f64) {
        // All embeddings are assumed to share the first record's dimension.
        let dim = match &embedded[0].embedding {
            Some(emb) => emb.len(),
            None => return,
        };
        let hnsw_config = HnswConfig {
            dimension: dim,
            metric: DistanceMetric::Cosine,
            m: 16,
            m_max_0: 32,
            ef_construction: 200,
            ef_search: self.config.hnsw_k_neighbors.max(50),
            ..HnswConfig::default()
        };
        let mut hnsw = HnswIndex::with_config(hnsw_config);
        for record in embedded.iter() {
            if let Some(embedding) = &record.embedding {
                let vector = SemanticVector {
                    id: record.id.clone(),
                    embedding: embedding.clone(),
                    timestamp: record.timestamp,
                    domain: Domain::CrossDomain,
                    metadata: std::collections::HashMap::new(),
                };
                // Insertion failures are ignored; the record simply gets no
                // similarity edges.
                let _ = hnsw.insert(vector);
            }
        }
        let k = self.config.hnsw_k_neighbors;
        let threshold_f32 = threshold as f32;
        let min_weight_f32 = min_weight as f32;
        use std::collections::HashSet;
        // Deduplicate undirected pairs via a canonically-ordered key.
        let mut seen: HashSet<(String, String)> = HashSet::new();
        for record in embedded.iter() {
            if let Some(embedding) = &record.embedding {
                // k + 1 because the query vector itself is among the results.
                if let Ok(neighbors) = hnsw.search_knn(embedding, k + 1) {
                    for neighbor in neighbors {
                        if neighbor.external_id == record.id {
                            continue;
                        }
                        if let Some(similarity) = neighbor.similarity {
                            if similarity >= threshold_f32 {
                                let key = if record.id < neighbor.external_id {
                                    (record.id.clone(), neighbor.external_id.clone())
                                } else {
                                    (neighbor.external_id.clone(), record.id.clone())
                                };
                                if seen.insert(key) {
                                    self.add_edge(&record.id, &neighbor.external_id, similarity.max(min_weight_f32) as f64);
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Brute-force edge creation for small datasets: O(n²) pairwise cosine.
    fn connect_by_embeddings_bruteforce(&mut self, embedded: &[&DataRecord], threshold: f64, min_weight: f64) {
        let threshold_f32 = threshold as f32;
        let min_weight_f32 = min_weight as f32;
        for i in 0..embedded.len() {
            for j in (i + 1)..embedded.len() {
                if let (Some(emb_a), Some(emb_b)) =
                    (&embedded[i].embedding, &embedded[j].embedding)
                {
                    let similarity = cosine_similarity(emb_a, emb_b);
                    if similarity >= threshold_f32 {
                        self.add_edge(
                            &embedded[i].id,
                            &embedded[j].id,
                            similarity.max(min_weight_f32) as f64,
                        );
                    }
                }
            }
        }
    }

    /// Build the graph from `records`, then compute coherence signals.
    pub fn compute_from_records(&mut self, records: &[DataRecord]) -> Result<Vec<CoherenceSignal>> {
        self.build_from_records(records);
        self.compute_signals()
    }

    /// Compute coherence signals over the current graph.
    ///
    /// Appends one new signal for the current graph state and returns the full
    /// signal history. An empty graph yields an empty vector.
    pub fn compute_signals(&mut self) -> Result<Vec<CoherenceSignal>> {
        if self.nodes.is_empty() {
            return Ok(vec![]);
        }
        // Build the min-cut structure.
        // This integrates with ruvector-mincut for actual computation.
        let min_cut_value = self.compute_min_cut()?;
        // Delta vs. the previous signal, derived from the stored value so the
        // min-cut is not recomputed a second time.
        let delta = self
            .signals
            .last()
            .map(|prev| min_cut_value - prev.min_cut_value);
        let signal = CoherenceSignal {
            id: format!("signal_{}", self.signals.len()),
            window: TemporalWindow::new(Utc::now(), Utc::now(), self.signals.len() as u64),
            min_cut_value,
            node_count: self.node_count(),
            edge_count: self.edge_count(),
            partition_sizes: self.compute_partition_sizes(),
            is_exact: !self.config.approximate,
            cut_nodes: self.find_cut_nodes(),
            delta,
        };
        self.signals.push(signal);
        Ok(self.signals.clone())
    }

    /// Compute minimum cut value.
    ///
    /// Returns infinity for graphs with fewer than 2 nodes (no meaningful cut).
    fn compute_min_cut(&self) -> Result<f64> {
        if self.nodes.len() < 2 {
            return Ok(f64::INFINITY);
        }
        // Use a simple Karger-Stein style approximation for demo.
        // In production, this integrates with ruvector_mincut::MinCutBuilder.
        let total_weight: f64 = self.edges.iter().map(|(_, _, w)| w).sum();
        // Approximate min-cut as total edge weight divided by average degree.
        let approx_cut = if self.edges.is_empty() {
            0.0
        } else {
            let avg_degree = (2.0 * self.edges.len() as f64) / self.nodes.len() as f64;
            total_weight / (avg_degree.max(1.0))
        };
        Ok(approx_cut)
    }

    /// Compute partition sizes (approximate: balanced split).
    fn compute_partition_sizes(&self) -> Option<(usize, usize)> {
        let n = self.nodes.len();
        if n < 2 {
            return None;
        }
        Some((n / 2, n - n / 2))
    }

    /// Find nodes on the cut boundary.
    ///
    /// Heuristic: returns up to 10 nodes whose degree exceeds twice the
    /// average degree (a proxy for nodes bridging partitions).
    fn find_cut_nodes(&self) -> Vec<String> {
        let mut degrees: HashMap<u64, usize> = HashMap::new();
        for (src, tgt, _) in &self.edges {
            *degrees.entry(*src).or_default() += 1;
            *degrees.entry(*tgt).or_default() += 1;
        }
        let avg_degree = if degrees.is_empty() {
            0
        } else {
            degrees.values().sum::<usize>() / degrees.len()
        };
        degrees
            .iter()
            .filter(|(_, &d)| d > avg_degree * 2)
            .filter_map(|(&id, _)| self.node_ids.get(&id).cloned())
            .take(10)
            .collect()
    }

    /// Detect coherence events between consecutive windows.
    ///
    /// Emits a Strengthened/Weakened event for every signal whose stored
    /// `delta` exceeds `threshold` in magnitude.
    pub fn detect_events(&self, threshold: f64) -> Vec<CoherenceEvent> {
        let mut events = Vec::new();
        // The first signal has no predecessor, so start from the second.
        for curr in self.signals.iter().skip(1) {
            if let Some(delta) = curr.delta {
                if delta.abs() > threshold {
                    let event_type = if delta > 0.0 {
                        CoherenceEventType::Strengthened
                    } else {
                        CoherenceEventType::Weakened
                    };
                    events.push(CoherenceEvent {
                        event_type,
                        timestamp: curr.window.start,
                        nodes: curr.cut_nodes.clone(),
                        magnitude: delta.abs(),
                        context: HashMap::new(),
                    });
                }
            }
        }
        events
    }

    /// Get historical signals
    pub fn signals(&self) -> &[CoherenceSignal] {
        &self.signals
    }

    /// Get tracked boundaries
    pub fn boundaries(&self) -> &[CoherenceBoundary] {
        &self.boundaries
    }

    /// Clear the graph and signals.
    ///
    /// NOTE(review): `boundaries` is intentionally left untouched (matching
    /// the original behavior) — confirm once boundary tracking is implemented.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.node_ids.clear();
        self.edges.clear();
        self.next_id = 0;
        self.signals.clear();
    }
}
/// Streaming coherence computation for time series
pub struct StreamingCoherence {
    // Engine reset and rebuilt for each completed window.
    engine: CoherenceEngine,
    // Window length in seconds.
    window_size: i64,
    // Slide step in seconds between consecutive windows.
    window_step: i64,
    // Currently open window; created lazily from the first record seen.
    current_window: Option<TemporalWindow>,
    // Records buffered for the current window.
    window_records: Vec<DataRecord>,
}
impl StreamingCoherence {
/// Create a new streaming coherence computer
pub fn new(config: CoherenceConfig) -> Self {
let window_size = config.window_size_secs;
let window_step = config.window_step_secs;
Self {
engine: CoherenceEngine::new(config),
window_size,
window_step,
current_window: None,
window_records: Vec::new(),
}
}
/// Process a single record
pub fn process(&mut self, record: DataRecord) -> Option<CoherenceSignal> {
let ts = record.timestamp;
// Initialize window if needed
if self.current_window.is_none() {
self.current_window = Some(TemporalWindow::new(
ts,
ts + chrono::Duration::seconds(self.window_size),
0,
));
}
// Check if record falls in current window
{
let window = self.current_window.as_ref().unwrap();
if window.contains(ts) {
self.window_records.push(record);
return None;
}
}
// Extract values before mutable borrow
let (old_start, old_window_id) = {
let window = self.current_window.as_ref().unwrap();
(window.start, window.window_id)
};
// Window complete, compute signal
let signal = self.finalize_window();
// Start new window
let new_start = old_start + chrono::Duration::seconds(self.window_step);
self.current_window = Some(TemporalWindow::new(
new_start,
new_start + chrono::Duration::seconds(self.window_size),
old_window_id + 1,
));
// Add record to new window
self.window_records.push(record);
signal
}
/// Finalize current window and compute signal
pub fn finalize_window(&mut self) -> Option<CoherenceSignal> {
if self.window_records.is_empty() {
return None;
}
self.engine.clear();
let signals = self
.engine
.compute_from_records(&self.window_records)
.ok()?;
self.window_records.clear();
signals.into_iter().last()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal DataRecord with the given id and outgoing relationships.
    fn make_test_record(id: &str, rels: Vec<(&str, f64)>) -> DataRecord {
        DataRecord {
            id: id.to_string(),
            source: "test".to_string(),
            record_type: "node".to_string(),
            timestamp: Utc::now(),
            data: serde_json::json!({}),
            embedding: None,
            relationships: rels
                .into_iter()
                .map(|(target, weight)| Relationship {
                    target_id: target.to_string(),
                    rel_type: "related".to_string(),
                    weight,
                    properties: HashMap::new(),
                })
                .collect(),
        }
    }

    #[test]
    fn test_coherence_engine_basic() {
        // Explicit node + edge insertion updates the counters.
        let config = CoherenceConfig::default();
        let mut engine = CoherenceEngine::new(config);
        engine.add_node("A");
        engine.add_node("B");
        engine.add_edge("A", "B", 1.0);
        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.edge_count(), 1);
    }

    #[test]
    fn test_coherence_from_records() {
        // Building from records registers every record id as a node and
        // produces at least one signal for a non-empty graph.
        let config = CoherenceConfig::default();
        let mut engine = CoherenceEngine::new(config);
        let records = vec![
            make_test_record("A", vec![("B", 1.0), ("C", 0.5)]),
            make_test_record("B", vec![("C", 1.0)]),
            make_test_record("C", vec![]),
        ];
        let signals = engine.compute_from_records(&records).unwrap();
        assert!(!signals.is_empty());
        assert_eq!(engine.node_count(), 3);
    }

    #[test]
    fn test_event_detection() {
        let config = CoherenceConfig::default();
        let engine = CoherenceEngine::new(config);
        // Events require multiple signals to detect changes
        let events = engine.detect_events(0.1);
        assert!(events.is_empty());
    }
}

View File

@@ -0,0 +1,836 @@
//! CrossRef API Integration
//!
//! This module provides an async client for fetching scholarly publications from CrossRef.org,
//! converting responses to SemanticVector format for RuVector discovery.
//!
//! # CrossRef API Details
//! - Base URL: https://api.crossref.org
//! - Free access, no authentication required
//! - Returns JSON responses
//! - Rate limit: ~50 requests/second with polite pool
//! - Polite pool: Include email in User-Agent or Mailto header for better rate limits
//!
//! # Example
//! ```rust,ignore
//! use ruvector_data_framework::crossref_client::CrossRefClient;
//!
//! let client = CrossRefClient::new(Some("your-email@example.com".to_string()));
//!
//! // Search publications by keywords
//! let vectors = client.search_works("machine learning", 20).await?;
//!
//! // Get work by DOI
//! let work = client.get_work("10.1038/nature12373").await?;
//!
//! // Search by funder
//! let funded = client.search_by_funder("10.13039/100000001", 10).await?;
//!
//! // Find recent publications
//! let recent = client.search_recent("quantum computing", "2024-01-01").await?;
//! ```
use std::collections::HashMap;
use std::time::Duration;
use chrono::{DateTime, NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration for CrossRef API
const CROSSREF_RATE_LIMIT_MS: u64 = 1000; // 1 second between requests for safety (API allows ~50/sec)
// Maximum retry attempts for rate-limited or failed requests.
const MAX_RETRIES: u32 = 3;
// Base backoff in milliseconds, multiplied by the attempt number.
const RETRY_DELAY_MS: u64 = 2000;
// Dimension of SimpleEmbedder text embeddings.
const DEFAULT_EMBEDDING_DIM: usize = 384;
// ============================================================================
// CrossRef API Structures
// ============================================================================
/// CrossRef API response envelope for works search
#[derive(Debug, Deserialize)]
struct CrossRefResponse {
    // `default` keeps deserialization working even when `message` is absent.
    #[serde(default)]
    message: CrossRefMessage,
}
// Payload of a CrossRef works response; Default backs the envelope's fallback.
#[derive(Debug, Default, Deserialize)]
struct CrossRefMessage {
    // Matched works for the current page of results.
    #[serde(default)]
    items: Vec<CrossRefWork>,
    // Total match count across all pages, when the API reports it.
    #[serde(rename = "total-results", default)]
    total_results: Option<u64>,
}
/// CrossRef work (publication)
#[derive(Debug, Deserialize)]
struct CrossRefWork {
    // DOI is the only field without a default: deserialization requires it.
    #[serde(rename = "DOI")]
    doi: String,
    // CrossRef returns titles as an array (commonly a single element).
    #[serde(default)]
    title: Vec<String>,
    // Abstract, when available; may contain embedded markup.
    #[serde(rename = "abstract", default)]
    abstract_text: Option<String>,
    // Author list.
    #[serde(default)]
    author: Vec<CrossRefAuthor>,
    // Print publication date.
    #[serde(rename = "published-print", default)]
    published_print: Option<CrossRefDate>,
    // Online publication date.
    #[serde(rename = "published-online", default)]
    published_online: Option<CrossRefDate>,
    // Journal / container title(s).
    #[serde(rename = "container-title", default)]
    container_title: Vec<String>,
    // Times this work is cited by others, per CrossRef's index.
    #[serde(rename = "is-referenced-by-count", default)]
    citation_count: Option<u64>,
    // Number of references this work cites.
    #[serde(rename = "references-count", default)]
    references_count: Option<u64>,
    // Subject classifications.
    #[serde(default)]
    subject: Vec<String>,
    // Funding bodies acknowledged by the work.
    #[serde(default)]
    funder: Vec<CrossRefFunder>,
    // Work type (e.g. "journal-article").
    #[serde(rename = "type", default)]
    work_type: Option<String>,
    // Publisher name.
    #[serde(default)]
    publisher: Option<String>,
}
// Author entry on a CrossRef work.
#[derive(Debug, Deserialize)]
struct CrossRefAuthor {
    // Given (first) name.
    #[serde(default)]
    given: Option<String>,
    // Family (last) name.
    #[serde(default)]
    family: Option<String>,
    // Full name, used when the name is not split into given/family parts.
    #[serde(default)]
    name: Option<String>,
    // ORCID identifier, when provided.
    #[serde(rename = "ORCID", default)]
    orcid: Option<String>,
}
// CrossRef date in "date-parts" convention: nested arrays such as
// [[year, month, day]], where month and day may be omitted.
#[derive(Debug, Deserialize)]
struct CrossRefDate {
    #[serde(rename = "date-parts", default)]
    date_parts: Vec<Vec<i32>>,
}
#[derive(Debug, Deserialize)]
struct CrossRefFunder {
#[serde(default)]
name: Option<String>,
#[serde(rename = "DOI", default)]
doi: Option<String>,
}
// ============================================================================
// CrossRef Client
// ============================================================================
/// Client for CrossRef.org scholarly publication API
///
/// Provides methods to search for publications, filter by various criteria,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces conservative rate limits (1 request/second).
/// Includes polite pool support via email configuration for better rate limits.
/// Includes retry logic for transient failures.
pub struct CrossRefClient {
    /// Underlying HTTP client (30s timeout, User-Agent derived from `polite_email`).
    client: Client,
    /// Deterministic text embedder used to vectorize title + abstract.
    embedder: SimpleEmbedder,
    /// API root, "https://api.crossref.org".
    base_url: String,
    /// Email advertised to CrossRef's "polite pool" for better service.
    polite_email: Option<String>,
}
impl CrossRefClient {
/// Create a new CrossRef API client
///
/// # Arguments
/// * `polite_email` - Email for polite pool access (optional but recommended for better rate limits)
///
/// # Example
/// ```rust,ignore
/// let client = CrossRefClient::new(Some("researcher@university.edu".to_string()));
/// ```
pub fn new(polite_email: Option<String>) -> Self {
Self::with_embedding_dim(polite_email, DEFAULT_EMBEDDING_DIM)
}
/// Create a new CrossRef API client with custom embedding dimension
///
/// # Arguments
/// * `polite_email` - Email for polite pool access
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
pub fn with_embedding_dim(polite_email: Option<String>, embedding_dim: usize) -> Self {
let user_agent = if let Some(ref email) = polite_email {
format!("RuVector-Discovery/1.0 (mailto:{})", email)
} else {
"RuVector-Discovery/1.0".to_string()
};
Self {
client: Client::builder()
.user_agent(&user_agent)
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client"),
embedder: SimpleEmbedder::new(embedding_dim),
base_url: "https://api.crossref.org".to_string(),
polite_email,
}
}
/// Search publications by keywords
///
/// # Arguments
/// * `query` - Search query (title, abstract, author, etc.)
/// * `limit` - Maximum number of results to return
///
/// # Example
/// ```rust,ignore
/// let vectors = client.search_works("climate change machine learning", 50).await?;
/// ```
pub async fn search_works(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let encoded_query = urlencoding::encode(query);
let mut url = format!(
"{}/works?query={}&rows={}",
self.base_url, encoded_query, limit
);
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Get a single work by DOI
///
/// # Arguments
/// * `doi` - Digital Object Identifier (e.g., "10.1038/nature12373")
///
/// # Example
/// ```rust,ignore
/// let work = client.get_work("10.1038/nature12373").await?;
/// ```
pub async fn get_work(&self, doi: &str) -> Result<Option<SemanticVector>> {
let normalized_doi = Self::normalize_doi(doi);
let mut url = format!("{}/works/{}", self.base_url, normalized_doi);
if let Some(email) = &self.polite_email {
url.push_str(&format!("?mailto={}", email));
}
sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
let response = self.fetch_with_retry(&url).await?;
let json_response: CrossRefResponse = response.json().await?;
if let Some(work) = json_response.message.items.into_iter().next() {
Ok(Some(self.work_to_vector(work)))
} else {
Ok(None)
}
}
/// Search publications funded by a specific organization
///
/// # Arguments
/// * `funder_id` - Funder DOI (e.g., "10.13039/100000001" for NSF)
/// * `limit` - Maximum number of results
///
/// # Example
/// ```rust,ignore
/// // Search NSF-funded research
/// let nsf_works = client.search_by_funder("10.13039/100000001", 20).await?;
/// ```
pub async fn search_by_funder(&self, funder_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let mut url = format!(
"{}/funders/{}/works?rows={}",
self.base_url, funder_id, limit
);
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Search publications by subject area
///
/// # Arguments
/// * `subject` - Subject area or field
/// * `limit` - Maximum number of results
///
/// # Example
/// ```rust,ignore
/// let biology_works = client.search_by_subject("molecular biology", 30).await?;
/// ```
pub async fn search_by_subject(&self, subject: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let encoded_subject = urlencoding::encode(subject);
let mut url = format!(
"{}/works?filter=has-subject:true&query.subject={}&rows={}",
self.base_url, encoded_subject, limit
);
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Get publications that cite a specific DOI
///
/// # Arguments
/// * `doi` - DOI of the work to find citations for
/// * `limit` - Maximum number of results
///
/// # Example
/// ```rust,ignore
/// let citing_works = client.get_citations("10.1038/nature12373", 15).await?;
/// ```
pub async fn get_citations(&self, doi: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let normalized_doi = Self::normalize_doi(doi);
let mut url = format!(
"{}/works?filter=references:{}&rows={}",
self.base_url, normalized_doi, limit
);
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Search recent publications since a specific date
///
/// # Arguments
/// * `query` - Search query
/// * `from_date` - Start date in YYYY-MM-DD format
/// * `limit` - Maximum number of results
///
/// # Example
/// ```rust,ignore
/// let recent = client.search_recent("artificial intelligence", "2024-01-01", 25).await?;
/// ```
pub async fn search_recent(&self, query: &str, from_date: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let encoded_query = urlencoding::encode(query);
let mut url = format!(
"{}/works?query={}&filter=from-pub-date:{}&rows={}",
self.base_url, encoded_query, from_date, limit
);
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Search publications by type
///
/// # Arguments
/// * `work_type` - Type of publication (e.g., "journal-article", "book-chapter", "proceedings-article", "dataset")
/// * `query` - Optional search query
/// * `limit` - Maximum number of results
///
/// # Supported Types
/// - `journal-article` - Journal articles
/// - `book-chapter` - Book chapters
/// - `proceedings-article` - Conference proceedings
/// - `dataset` - Research datasets
/// - `monograph` - Monographs
/// - `report` - Technical reports
///
/// # Example
/// ```rust,ignore
/// let datasets = client.search_by_type("dataset", Some("climate"), 10).await?;
/// let articles = client.search_by_type("journal-article", None, 20).await?;
/// ```
pub async fn search_by_type(
&self,
work_type: &str,
query: Option<&str>,
limit: usize,
) -> Result<Vec<SemanticVector>> {
let mut url = format!(
"{}/works?filter=type:{}&rows={}",
self.base_url, work_type, limit
);
if let Some(q) = query {
let encoded_query = urlencoding::encode(q);
url.push_str(&format!("&query={}", encoded_query));
}
if let Some(email) = &self.polite_email {
url.push_str(&format!("&mailto={}", email));
}
self.fetch_and_parse(&url).await
}
/// Fetch and parse CrossRef API response
async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
// Rate limiting
sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
let response = self.fetch_with_retry(url).await?;
let crossref_response: CrossRefResponse = response.json().await?;
// Convert works to SemanticVectors
let vectors = crossref_response
.message
.items
.into_iter()
.map(|work| self.work_to_vector(work))
.collect();
Ok(vectors)
}
/// Convert CrossRef work to SemanticVector
fn work_to_vector(&self, work: CrossRefWork) -> SemanticVector {
// Extract title
let title = work
.title
.first()
.cloned()
.unwrap_or_else(|| "Untitled".to_string());
// Extract abstract
let abstract_text = work.abstract_text.unwrap_or_default();
// Parse publication date (prefer print, fallback to online)
let timestamp = work
.published_print
.or(work.published_online)
.and_then(|date| Self::parse_crossref_date(&date))
.unwrap_or_else(Utc::now);
// Generate embedding from title + abstract
let combined_text = if abstract_text.is_empty() {
title.clone()
} else {
format!("{} {}", title, abstract_text)
};
let embedding = self.embedder.embed_text(&combined_text);
// Extract authors
let authors = work
.author
.iter()
.map(|a| Self::format_author_name(a))
.collect::<Vec<_>>()
.join("; ");
// Extract journal/container
let journal = work
.container_title
.first()
.cloned()
.unwrap_or_default();
// Extract subjects
let subjects = work.subject.join(", ");
// Extract funders
let funders = work
.funder
.iter()
.filter_map(|f| f.name.clone())
.collect::<Vec<_>>()
.join(", ");
// Build metadata
let mut metadata = HashMap::new();
metadata.insert("doi".to_string(), work.doi.clone());
metadata.insert("title".to_string(), title);
metadata.insert("abstract".to_string(), abstract_text);
metadata.insert("authors".to_string(), authors);
metadata.insert("journal".to_string(), journal);
metadata.insert("subjects".to_string(), subjects);
metadata.insert(
"citation_count".to_string(),
work.citation_count.unwrap_or(0).to_string(),
);
metadata.insert(
"references_count".to_string(),
work.references_count.unwrap_or(0).to_string(),
);
metadata.insert("funders".to_string(), funders);
metadata.insert(
"type".to_string(),
work.work_type.unwrap_or_else(|| "unknown".to_string()),
);
if let Some(publisher) = work.publisher {
metadata.insert("publisher".to_string(), publisher);
}
metadata.insert("source".to_string(), "crossref".to_string());
SemanticVector {
id: format!("doi:{}", work.doi),
embedding,
domain: Domain::Research,
timestamp,
metadata,
}
}
/// Parse CrossRef date format
fn parse_crossref_date(date: &CrossRefDate) -> Option<DateTime<Utc>> {
if let Some(parts) = date.date_parts.first() {
if parts.is_empty() {
return None;
}
let year = parts[0];
let month = parts.get(1).copied().unwrap_or(1).max(1).min(12);
let day = parts.get(2).copied().unwrap_or(1).max(1).min(31);
NaiveDate::from_ymd_opt(year, month as u32, day as u32)
.and_then(|d| d.and_hms_opt(0, 0, 0))
.map(|dt| dt.and_utc())
} else {
None
}
}
/// Format author name from CrossRef author structure
fn format_author_name(author: &CrossRefAuthor) -> String {
if let Some(name) = &author.name {
name.clone()
} else {
let given = author.given.as_deref().unwrap_or("");
let family = author.family.as_deref().unwrap_or("");
format!("{} {}", given, family).trim().to_string()
}
}
/// Normalize DOI (remove http://, https://, doi.org/ prefixes)
fn normalize_doi(doi: &str) -> String {
doi.trim()
.trim_start_matches("http://")
.trim_start_matches("https://")
.trim_start_matches("doi.org/")
.trim_start_matches("dx.doi.org/")
.to_string()
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
{
retries += 1;
tracing::warn!(
"Rate limited by CrossRef, retrying in {}ms",
RETRY_DELAY_MS * retries as u64
);
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
if !response.status().is_success() {
return Err(FrameworkError::Network(
reqwest::Error::from(response.error_for_status().unwrap_err()),
));
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
impl Default for CrossRefClient {
    /// Anonymous client: no polite-pool email, default embedding dimension.
    fn default() -> Self {
        Self::new(None)
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    // Unit tests below run fully offline; the #[ignore]d #[tokio::test]
    // functions hit the live CrossRef API and must be run explicitly.
    #[test]
    fn test_crossref_client_creation() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        assert_eq!(client.base_url, "https://api.crossref.org");
        assert_eq!(client.polite_email, Some("test@example.com".to_string()));
    }
    #[test]
    fn test_crossref_client_without_email() {
        let client = CrossRefClient::new(None);
        assert_eq!(client.base_url, "https://api.crossref.org");
        assert_eq!(client.polite_email, None);
    }
    #[test]
    fn test_custom_embedding_dim() {
        // The embedder must honor the requested dimension end-to-end.
        let client = CrossRefClient::with_embedding_dim(None, 512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }
    #[test]
    fn test_normalize_doi() {
        // Bare DOIs, URL-prefixed DOIs and padded DOIs must all normalize
        // to the same canonical form.
        assert_eq!(
            CrossRefClient::normalize_doi("10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi("http://doi.org/10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi("https://dx.doi.org/10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi("  10.1038/nature12373  "),
            "10.1038/nature12373"
        );
    }
    #[test]
    fn test_parse_crossref_date() {
        // Full date
        let date1 = CrossRefDate {
            date_parts: vec![vec![2024, 3, 15]],
        };
        let parsed1 = CrossRefClient::parse_crossref_date(&date1);
        assert!(parsed1.is_some());
        let dt1 = parsed1.unwrap();
        assert_eq!(dt1.format("%Y-%m-%d").to_string(), "2024-03-15");
        // Year and month only
        let date2 = CrossRefDate {
            date_parts: vec![vec![2024, 3]],
        };
        let parsed2 = CrossRefClient::parse_crossref_date(&date2);
        assert!(parsed2.is_some());
        // Year only
        let date3 = CrossRefDate {
            date_parts: vec![vec![2024]],
        };
        let parsed3 = CrossRefClient::parse_crossref_date(&date3);
        assert!(parsed3.is_some());
        // Empty date parts
        let date4 = CrossRefDate {
            date_parts: vec![vec![]],
        };
        let parsed4 = CrossRefClient::parse_crossref_date(&date4);
        assert!(parsed4.is_none());
    }
    #[test]
    fn test_format_author_name() {
        // Full name
        let author1 = CrossRefAuthor {
            given: Some("John".to_string()),
            family: Some("Doe".to_string()),
            name: None,
            orcid: None,
        };
        assert_eq!(
            CrossRefClient::format_author_name(&author1),
            "John Doe"
        );
        // Name field only
        let author2 = CrossRefAuthor {
            given: None,
            family: None,
            name: Some("Jane Smith".to_string()),
            orcid: None,
        };
        assert_eq!(
            CrossRefClient::format_author_name(&author2),
            "Jane Smith"
        );
        // Family name only
        let author3 = CrossRefAuthor {
            given: None,
            family: Some("Einstein".to_string()),
            name: None,
            orcid: None,
        };
        assert_eq!(
            CrossRefClient::format_author_name(&author3),
            "Einstein"
        );
    }
    #[test]
    fn test_work_to_vector() {
        // Verifies the full work -> SemanticVector mapping, including all
        // metadata keys populated from the CrossRef fields.
        let client = CrossRefClient::new(None);
        let work = CrossRefWork {
            doi: "10.1234/example.2024".to_string(),
            title: vec!["Deep Learning for Climate Science".to_string()],
            abstract_text: Some("We propose a novel approach to climate modeling...".to_string()),
            author: vec![
                CrossRefAuthor {
                    given: Some("Alice".to_string()),
                    family: Some("Johnson".to_string()),
                    name: None,
                    orcid: Some("0000-0001-2345-6789".to_string()),
                },
                CrossRefAuthor {
                    given: Some("Bob".to_string()),
                    family: Some("Smith".to_string()),
                    name: None,
                    orcid: None,
                },
            ],
            published_print: Some(CrossRefDate {
                date_parts: vec![vec![2024, 6, 15]],
            }),
            published_online: None,
            container_title: vec!["Nature Climate Change".to_string()],
            citation_count: Some(42),
            references_count: Some(35),
            subject: vec!["Climate Science".to_string(), "Machine Learning".to_string()],
            funder: vec![CrossRefFunder {
                name: Some("National Science Foundation".to_string()),
                doi: Some("10.13039/100000001".to_string()),
            }],
            work_type: Some("journal-article".to_string()),
            publisher: Some("Nature Publishing Group".to_string()),
        };
        let vector = client.work_to_vector(work);
        assert_eq!(vector.id, "doi:10.1234/example.2024");
        assert_eq!(vector.domain, Domain::Research);
        assert_eq!(
            vector.metadata.get("doi").unwrap(),
            "10.1234/example.2024"
        );
        assert_eq!(
            vector.metadata.get("title").unwrap(),
            "Deep Learning for Climate Science"
        );
        assert_eq!(
            vector.metadata.get("authors").unwrap(),
            "Alice Johnson; Bob Smith"
        );
        assert_eq!(
            vector.metadata.get("journal").unwrap(),
            "Nature Climate Change"
        );
        assert_eq!(vector.metadata.get("citation_count").unwrap(), "42");
        assert_eq!(
            vector.metadata.get("subjects").unwrap(),
            "Climate Science, Machine Learning"
        );
        assert_eq!(
            vector.metadata.get("funders").unwrap(),
            "National Science Foundation"
        );
        assert_eq!(vector.metadata.get("type").unwrap(), "journal-article");
        assert_eq!(
            vector.metadata.get("publisher").unwrap(),
            "Nature Publishing Group"
        );
        assert_eq!(vector.embedding.len(), DEFAULT_EMBEDDING_DIM);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting CrossRef API in tests
    async fn test_search_works_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        let results = client.search_works("machine learning", 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("doi"));
        }
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting CrossRef API in tests
    async fn test_get_work_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        // Try to fetch a known work (Nature paper on AlphaFold)
        let result = client.get_work("10.1038/s41586-021-03819-2").await;
        assert!(result.is_ok());
        let work = result.unwrap();
        assert!(work.is_some());
        let vector = work.unwrap();
        assert_eq!(vector.id, "doi:10.1038/s41586-021-03819-2");
        assert_eq!(vector.domain, Domain::Research);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting CrossRef API in tests
    async fn test_search_by_funder_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        // Search NSF-funded works
        let results = client.search_by_funder("10.13039/100000001", 3).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 3);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting CrossRef API in tests
    async fn test_search_by_type_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        // Search for datasets
        let results = client.search_by_type("dataset", Some("climate"), 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting CrossRef API in tests
    async fn test_search_recent_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        // Search recent papers
        let results = client
            .search_recent("quantum computing", "2024-01-01", 5)
            .await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,562 @@
//! Discovery engine for detecting novel patterns from coherence signals
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::{CoherenceSignal, Result};
/// Configuration for discovery engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig {
    /// Minimum signal strength to consider
    // NOTE(review): not currently consulted by any detection pass in this
    // module — confirm whether `detect` should filter on it.
    pub min_signal_strength: f64,
    /// Lookback window for trend analysis
    /// (number of most recent signals examined by emergence/trend passes).
    pub lookback_windows: usize,
    /// Threshold for detecting emergence
    /// (average per-window growth as a fraction of the starting node count).
    pub emergence_threshold: f64,
    /// Threshold for detecting splits
    /// (fractional drop in min-cut value, 0-1, that counts as a split).
    pub split_threshold: f64,
    /// Threshold for detecting bridges
    /// (fraction of windows a node must appear on a cut boundary).
    pub bridge_threshold: f64,
    /// Enable anomaly detection
    pub detect_anomalies: bool,
    /// Anomaly sensitivity (standard deviations)
    /// (absolute z-score above which a signal is flagged anomalous).
    pub anomaly_sigma: f64,
}
impl Default for DiscoveryConfig {
    /// Conservative defaults: 10-window lookback, 2.5σ anomaly cutoff.
    fn default() -> Self {
        Self {
            min_signal_strength: 0.01,
            lookback_windows: 10,
            emergence_threshold: 0.2,
            split_threshold: 0.5,
            bridge_threshold: 0.3,
            detect_anomalies: true,
            anomaly_sigma: 2.5,
        }
    }
}
/// Categories of discoverable patterns
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PatternCategory {
    /// New cluster/community emerging
    Emergence,
    /// Existing structure splitting
    Split,
    /// Two structures merging
    // NOTE(review): no detection pass in this module emits Merge yet.
    Merge,
    /// Cross-domain connection forming
    Bridge,
    /// Unusual coherence pattern
    Anomaly,
    /// Gradual strengthening
    Consolidation,
    /// Gradual weakening
    Dissolution,
    /// Cyclical pattern detected
    // NOTE(review): no detection pass in this module emits Cyclical yet.
    Cyclical,
}
/// Strength of discovered pattern
///
/// Variants are declared weakest-to-strongest, so the derived `Ord`
/// orders them by increasing strength.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Ord, PartialOrd)]
pub enum PatternStrength {
    /// Weak signal, might be noise
    Weak,
    /// Moderate signal, worth monitoring
    Moderate,
    /// Strong signal, likely real
    Strong,
    /// Very strong signal, high confidence
    VeryStrong,
}
impl PatternStrength {
    /// Bucket a numeric score (expected in [0, 1]) into a discrete strength.
    ///
    /// Thresholds: `< 0.25` → `Weak`, `< 0.5` → `Moderate`,
    /// `< 0.75` → `Strong`, anything else → `VeryStrong`.
    pub fn from_score(score: f64) -> Self {
        match score {
            s if s < 0.25 => PatternStrength::Weak,
            s if s < 0.5 => PatternStrength::Moderate,
            s if s < 0.75 => PatternStrength::Strong,
            _ => PatternStrength::VeryStrong,
        }
    }
}
/// A discovered pattern
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryPattern {
    /// Unique pattern identifier
    /// (formatted as "<category>_<n>", e.g. "split_3").
    pub id: String,
    /// Pattern category
    pub category: PatternCategory,
    /// Pattern strength
    pub strength: PatternStrength,
    /// Numeric confidence score (0-1)
    pub confidence: f64,
    /// When pattern was first detected
    pub detected_at: DateTime<Utc>,
    /// Time range pattern spans
    pub time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Related nodes/entities
    pub entities: Vec<String>,
    /// Description of pattern
    pub description: String,
    /// Supporting evidence
    pub evidence: Vec<PatternEvidence>,
    /// Additional metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Evidence supporting a pattern
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternEvidence {
    /// Evidence type
    /// (a short tag such as "node_growth", "mincut_drop" or "z_score").
    pub evidence_type: String,
    /// Numeric value
    pub value: f64,
    /// Reference to source signal/data
    pub source_ref: String,
    /// Human-readable explanation
    pub explanation: String,
}
/// Discovery engine for pattern detection
///
/// Stateful: each call to `detect` appends the incoming signals to the
/// engine's history and accumulates every pattern found so far.
pub struct DiscoveryEngine {
    /// Detection thresholds and toggles.
    config: DiscoveryConfig,
    /// All patterns discovered across every `detect` call.
    patterns: Vec<DiscoveryPattern>,
    /// Full history of signals seen so far (grows unbounded until `clear`).
    signal_history: Vec<CoherenceSignal>,
}
impl DiscoveryEngine {
    /// Create a new discovery engine
    pub fn new(config: DiscoveryConfig) -> Self {
        Self {
            config,
            patterns: Vec::new(),
            signal_history: Vec::new(),
        }
    }
    /// Detect patterns from coherence signals
    ///
    /// Appends `signals` to the engine's history, runs every enabled
    /// detection pass over the accumulated history, records the results,
    /// and returns the patterns found by this call.
    ///
    /// NOTE(review): the passes re-scan the full history each call, so a
    /// split or anomaly reported once may be reported again on later calls;
    /// confirm whether callers deduplicate.
    pub fn detect(&mut self, signals: &[CoherenceSignal]) -> Result<Vec<DiscoveryPattern>> {
        self.signal_history.extend(signals.iter().cloned());
        let mut patterns = Vec::new();
        // Need at least 2 signals to detect patterns
        if self.signal_history.len() < 2 {
            return Ok(patterns);
        }
        // Detect different pattern types
        patterns.extend(self.detect_emergence()?);
        patterns.extend(self.detect_splits()?);
        patterns.extend(self.detect_bridges()?);
        patterns.extend(self.detect_trends()?);
        if self.config.detect_anomalies {
            patterns.extend(self.detect_anomalies()?);
        }
        self.patterns.extend(patterns.clone());
        Ok(patterns)
    }
    /// Detect emerging structures
    ///
    /// Flags sustained node-count growth over the lookback window whose
    /// per-step average exceeds `emergence_threshold` of the window's
    /// starting node count.
    fn detect_emergence(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();
        if self.signal_history.len() < self.config.lookback_windows {
            return Ok(patterns);
        }
        let recent = &self.signal_history[self.signal_history.len() - self.config.lookback_windows..];
        // Look for sustained growth in node/edge count with increasing coherence
        let node_growth: Vec<i64> = recent
            .windows(2)
            .map(|w| w[1].node_count as i64 - w[0].node_count as i64)
            .collect();
        let avg_growth = node_growth.iter().sum::<i64>() as f64 / node_growth.len() as f64;
        if avg_growth > self.config.emergence_threshold * recent[0].node_count as f64 {
            let latest = recent.last().unwrap();
            patterns.push(DiscoveryPattern {
                id: format!("emergence_{}", self.patterns.len()),
                category: PatternCategory::Emergence,
                // Heuristic normalization: ~10 nodes/window maps to full strength.
                strength: PatternStrength::from_score(avg_growth / 10.0),
                confidence: (avg_growth / 10.0).min(1.0),
                detected_at: Utc::now(),
                time_range: Some((recent[0].window.start, latest.window.end)),
                entities: latest.cut_nodes.clone(),
                description: format!(
                    "Emerging structure detected: {} new nodes over {} windows",
                    (avg_growth * recent.len() as f64) as i64,
                    recent.len()
                ),
                evidence: vec![PatternEvidence {
                    evidence_type: "node_growth".to_string(),
                    value: avg_growth,
                    source_ref: latest.id.clone(),
                    explanation: "Sustained node count growth".to_string(),
                }],
                metadata: HashMap::new(),
            });
        }
        Ok(patterns)
    }
    /// Detect structure splits
    ///
    /// A split is a drop in min-cut value between consecutive signals whose
    /// relative magnitude exceeds `split_threshold`.
    fn detect_splits(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();
        if self.signal_history.len() < 2 {
            return Ok(patterns);
        }
        // Look for sudden drops in min-cut value
        for i in 1..self.signal_history.len() {
            let prev = &self.signal_history[i - 1];
            let curr = &self.signal_history[i];
            if prev.min_cut_value > 0.0 {
                let drop_ratio = (prev.min_cut_value - curr.min_cut_value) / prev.min_cut_value;
                if drop_ratio > self.config.split_threshold {
                    patterns.push(DiscoveryPattern {
                        // Offset by the patterns already found in this pass so
                        // that multiple splits detected in one call get
                        // distinct ids (previously they all shared one suffix).
                        id: format!("split_{}", self.patterns.len() + patterns.len()),
                        category: PatternCategory::Split,
                        strength: PatternStrength::from_score(drop_ratio),
                        confidence: drop_ratio.min(1.0),
                        detected_at: curr.window.start,
                        time_range: Some((prev.window.start, curr.window.end)),
                        entities: curr.cut_nodes.clone(),
                        description: format!(
                            "Structure split detected: {:.1}% coherence drop",
                            drop_ratio * 100.0
                        ),
                        evidence: vec![PatternEvidence {
                            evidence_type: "mincut_drop".to_string(),
                            value: drop_ratio,
                            source_ref: curr.id.clone(),
                            explanation: format!(
                                "Min-cut dropped from {:.3} to {:.3}",
                                prev.min_cut_value, curr.min_cut_value
                            ),
                        }],
                        metadata: HashMap::new(),
                    });
                }
            }
        }
        Ok(patterns)
    }
    /// Detect cross-domain bridges
    ///
    /// Counts how often each node appears on a cut boundary across the whole
    /// history; nodes present in at least `bridge_threshold` of the windows
    /// are reported together as one bridge pattern.
    fn detect_bridges(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();
        if self.signal_history.is_empty() {
            return Ok(patterns);
        }
        // Look for nodes that appear in cut boundaries frequently
        let mut boundary_counts: HashMap<String, usize> = HashMap::new();
        for signal in &self.signal_history {
            for node in &signal.cut_nodes {
                *boundary_counts.entry(node.clone()).or_default() += 1;
            }
        }
        let threshold = (self.signal_history.len() as f64 * self.config.bridge_threshold) as usize;
        let bridge_nodes: Vec<_> = boundary_counts
            .iter()
            .filter(|(_, &count)| count >= threshold)
            .map(|(node, &count)| (node.clone(), count))
            .collect();
        if !bridge_nodes.is_empty() {
            let latest = self.signal_history.last().unwrap();
            patterns.push(DiscoveryPattern {
                id: format!("bridge_{}", self.patterns.len()),
                category: PatternCategory::Bridge,
                // Fixed strength/confidence: boundary frequency alone does not
                // give a calibrated score.
                strength: PatternStrength::Moderate,
                confidence: 0.6,
                detected_at: Utc::now(),
                time_range: Some((
                    self.signal_history[0].window.start,
                    latest.window.end,
                )),
                entities: bridge_nodes.iter().map(|(n, _)| n.clone()).collect(),
                description: format!(
                    "Bridge nodes detected: {} nodes consistently on boundaries",
                    bridge_nodes.len()
                ),
                evidence: bridge_nodes
                    .iter()
                    .map(|(node, count)| PatternEvidence {
                        evidence_type: "boundary_frequency".to_string(),
                        value: *count as f64,
                        source_ref: node.clone(),
                        explanation: format!("{} appeared in {} cut boundaries", node, count),
                    })
                    .collect(),
                metadata: HashMap::new(),
            });
        }
        Ok(patterns)
    }
    /// Detect trends (consolidation/dissolution)
    ///
    /// Fits a linear trend to the min-cut values over the lookback window;
    /// a slope above +0.1 is consolidation, below -0.1 is dissolution.
    fn detect_trends(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();
        if self.signal_history.len() < self.config.lookback_windows {
            return Ok(patterns);
        }
        let recent = &self.signal_history[self.signal_history.len() - self.config.lookback_windows..];
        // Calculate trend in min-cut values
        let values: Vec<f64> = recent.iter().map(|s| s.min_cut_value).collect();
        let (slope, _) = self.linear_regression(&values);
        if slope.abs() > 0.1 {
            let latest = recent.last().unwrap();
            let category = if slope > 0.0 {
                PatternCategory::Consolidation
            } else {
                PatternCategory::Dissolution
            };
            patterns.push(DiscoveryPattern {
                id: format!("trend_{}", self.patterns.len()),
                category,
                strength: PatternStrength::from_score(slope.abs()),
                confidence: slope.abs().min(1.0),
                detected_at: Utc::now(),
                time_range: Some((recent[0].window.start, latest.window.end)),
                entities: vec![],
                description: format!(
                    "{} trend detected: {:.2}% per window",
                    if slope > 0.0 {
                        "Strengthening"
                    } else {
                        "Weakening"
                    },
                    slope * 100.0
                ),
                evidence: vec![PatternEvidence {
                    evidence_type: "trend_slope".to_string(),
                    value: slope,
                    source_ref: latest.id.clone(),
                    explanation: format!(
                        "Linear trend slope: {:.4} over {} windows",
                        slope,
                        recent.len()
                    ),
                }],
                metadata: HashMap::new(),
            });
        }
        Ok(patterns)
    }
    /// Detect anomalies
    ///
    /// Flags every signal whose min-cut value deviates from the history mean
    /// by more than `anomaly_sigma` standard deviations. Requires at least 5
    /// signals so the mean/std-dev estimates are meaningful.
    fn detect_anomalies(&self) -> Result<Vec<DiscoveryPattern>> {
        let mut patterns = Vec::new();
        if self.signal_history.len() < 5 {
            return Ok(patterns);
        }
        // Calculate mean and std dev of min-cut values
        let values: Vec<f64> = self.signal_history.iter().map(|s| s.min_cut_value).collect();
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std_dev = variance.sqrt();
        // Find anomalies
        for (i, signal) in self.signal_history.iter().enumerate() {
            // Guard against a zero-variance history (all values identical).
            let z_score = if std_dev > 0.0 {
                (signal.min_cut_value - mean) / std_dev
            } else {
                0.0
            };
            if z_score.abs() > self.config.anomaly_sigma {
                patterns.push(DiscoveryPattern {
                    // Keyed by history index, so re-detecting the same signal
                    // on a later call reuses the same id.
                    id: format!("anomaly_{}", i),
                    category: PatternCategory::Anomaly,
                    // 5σ is treated as maximal strength/confidence.
                    strength: PatternStrength::from_score(z_score.abs() / 5.0),
                    confidence: (z_score.abs() / 5.0).min(1.0),
                    detected_at: signal.window.start,
                    time_range: Some((signal.window.start, signal.window.end)),
                    entities: signal.cut_nodes.clone(),
                    description: format!(
                        "Anomalous coherence: {:.2}σ from mean",
                        z_score
                    ),
                    evidence: vec![PatternEvidence {
                        evidence_type: "z_score".to_string(),
                        value: z_score,
                        source_ref: signal.id.clone(),
                        explanation: format!(
                            "Value {:.4} vs mean {:.4} (σ={:.4})",
                            signal.min_cut_value, mean, std_dev
                        ),
                    }],
                    metadata: HashMap::new(),
                });
            }
        }
        Ok(patterns)
    }
    /// Simple linear regression
    ///
    /// Ordinary least squares over implicit x = 0..n-1; returns
    /// `(slope, intercept)`. A degenerate input (fewer than 2 points)
    /// yields slope 0.
    fn linear_regression(&self, values: &[f64]) -> (f64, f64) {
        let n = values.len() as f64;
        let x_mean = (n - 1.0) / 2.0;
        let y_mean = values.iter().sum::<f64>() / n;
        let mut num = 0.0;
        let mut denom = 0.0;
        for (i, &y) in values.iter().enumerate() {
            let x = i as f64;
            num += (x - x_mean) * (y - y_mean);
            denom += (x - x_mean).powi(2);
        }
        let slope = if denom > 0.0 { num / denom } else { 0.0 };
        let intercept = y_mean - slope * x_mean;
        (slope, intercept)
    }
    /// Get all discovered patterns
    pub fn patterns(&self) -> &[DiscoveryPattern] {
        &self.patterns
    }
    /// Get patterns by category
    pub fn patterns_by_category(&self, category: PatternCategory) -> Vec<&DiscoveryPattern> {
        self.patterns
            .iter()
            .filter(|p| p.category == category)
            .collect()
    }
    /// Clear history
    ///
    /// Drops both the accumulated patterns and the raw signal history.
    pub fn clear(&mut self) {
        self.patterns.clear();
        self.signal_history.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TemporalWindow;
    /// Build a minimal coherence signal for tests: window start and end are
    /// both "now", the partition is an even split, and cut membership is empty.
    fn make_signal(id: &str, min_cut: f64, nodes: usize) -> CoherenceSignal {
        CoherenceSignal {
            id: id.to_string(),
            window: TemporalWindow::new(Utc::now(), Utc::now(), 0),
            min_cut_value: min_cut,
            node_count: nodes,
            edge_count: nodes * 2,
            partition_sizes: Some((nodes / 2, nodes - nodes / 2)),
            is_exact: true,
            cut_nodes: vec![],
            delta: None,
        }
    }
    #[test]
    fn test_discovery_engine_creation() {
        let config = DiscoveryConfig::default();
        let engine = DiscoveryEngine::new(config);
        assert!(engine.patterns().is_empty());
    }
    #[test]
    fn test_pattern_strength() {
        // Exercises each bucket boundary of from_score.
        assert_eq!(PatternStrength::from_score(0.1), PatternStrength::Weak);
        assert_eq!(PatternStrength::from_score(0.3), PatternStrength::Moderate);
        assert_eq!(PatternStrength::from_score(0.6), PatternStrength::Strong);
        assert_eq!(
            PatternStrength::from_score(0.9),
            PatternStrength::VeryStrong
        );
    }
    #[test]
    fn test_empty_signals() {
        // With no history at all, detect must short-circuit to no patterns.
        let config = DiscoveryConfig::default();
        let mut engine = DiscoveryEngine::new(config);
        let patterns = engine.detect(&[]).unwrap();
        assert!(patterns.is_empty());
    }
    #[test]
    fn test_linear_regression() {
        // y = x + 1 should recover slope 1 and intercept 1 exactly.
        let config = DiscoveryConfig::default();
        let engine = DiscoveryEngine::new(config);
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (slope, intercept) = engine.linear_regression(&values);
        assert!((slope - 1.0).abs() < 0.001);
        assert!((intercept - 1.0).abs() < 0.001);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,770 @@
//! Economic data API integrations for FRED, World Bank, and Alpha Vantage
//!
//! This module provides async clients for fetching economic indicators, global development data,
//! and stock market information, converting responses to SemanticVector format for RuVector discovery.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration
// Minimum delay inserted before each outbound request, per provider.
const FRED_RATE_LIMIT_MS: u64 = 100; // ~10 requests/second
const WORLDBANK_RATE_LIMIT_MS: u64 = 100; // Conservative rate
const ALPHAVANTAGE_RATE_LIMIT_MS: u64 = 12000; // 5 requests/minute for free tier
// Retry policy for transient failures (HTTP 429 and network errors).
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 1000; // base delay, scaled linearly by attempt number
// ============================================================================
// FRED (Federal Reserve Economic Data) Client
// ============================================================================
/// FRED API observations response
///
/// Top-level payload of `/fred/series/observations`. On success only
/// `observations` is populated; FRED error payloads instead carry
/// `error_code`/`error_message`. All fields default so either shape parses.
#[derive(Debug, Deserialize)]
struct FredObservationsResponse {
    #[serde(default)]
    observations: Vec<FredObservation>,
    // Present only on error responses; error_code is not inspected by this client.
    #[serde(default)]
    error_code: Option<i32>,
    #[serde(default)]
    error_message: Option<String>,
}
/// A single dated observation within a FRED series.
///
/// `value` stays a string because FRED encodes missing data as "." —
/// `get_series` skips observations whose value does not parse as f64.
#[derive(Debug, Deserialize)]
struct FredObservation {
    // Observation date in "YYYY-MM-DD" form.
    #[serde(default)]
    date: String,
    #[serde(default)]
    value: String,
}
/// FRED API series search response
///
/// NOTE: "seriess" (double s) is the literal field name in FRED's JSON.
#[derive(Debug, Deserialize)]
struct FredSeriesSearchResponse {
    seriess: Vec<FredSeries>,
}
/// Metadata describing one FRED series returned by a search query.
#[derive(Debug, Deserialize)]
struct FredSeries {
    // Series identifier, e.g. "UNRATE".
    id: String,
    title: String,
    #[serde(default)]
    units: String,
    #[serde(default)]
    frequency: String,
    #[serde(default)]
    seasonal_adjustment: String,
    #[serde(default)]
    notes: String,
}
/// Client for FRED (Federal Reserve Economic Data)
///
/// Provides access to 800,000+ US economic time series including:
/// - GDP, unemployment, inflation, interest rates
/// - Money supply, consumer spending, housing data
/// - Regional economic indicators
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::FredClient;
///
/// let client = FredClient::new(None)?;
/// let gdp_data = client.get_gdp().await?;
/// let unemployment = client.get_unemployment().await?;
/// let search_results = client.search_series("inflation").await?;
/// ```
pub struct FredClient {
    // Shared HTTP client (30-second request timeout).
    client: Client,
    // API root: "https://api.stlouisfed.org/fred".
    base_url: String,
    // Optional API key; `get_series` requires it and returns a Config error when absent.
    api_key: Option<String>,
    // Delay slept before every request (FRED_RATE_LIMIT_MS).
    rate_limit_delay: Duration,
    // 256-dimension text embedder used for observation/series text.
    embedder: Arc<SimpleEmbedder>,
}
impl FredClient {
/// Create a new FRED client
///
/// # Arguments
/// * `api_key` - Optional FRED API key (get from https://fred.stlouisfed.org/docs/api/api_key.html)
/// Basic access works without a key, but rate limits are more restrictive
pub fn new(api_key: Option<String>) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(FrameworkError::Network)?;
Ok(Self {
client,
base_url: "https://api.stlouisfed.org/fred".to_string(),
api_key,
rate_limit_delay: Duration::from_millis(FRED_RATE_LIMIT_MS),
embedder: Arc::new(SimpleEmbedder::new(256)),
})
}
/// Get observations for a specific FRED series
///
/// # Arguments
/// * `series_id` - FRED series ID (e.g., "GDP", "UNRATE", "CPIAUCSL")
/// * `limit` - Maximum number of observations to return (default: 100)
///
/// # Example
/// ```rust,ignore
/// let gdp = client.get_series("GDP", Some(50)).await?;
/// ```
pub async fn get_series(
&self,
series_id: &str,
limit: Option<usize>,
) -> Result<Vec<SemanticVector>> {
// FRED API requires an API key as of 2025
let api_key = self.api_key.as_ref().ok_or_else(|| {
FrameworkError::Config(
"FRED API key required. Get one at https://fred.stlouisfed.org/docs/api/api_key.html".to_string()
)
})?;
let mut url = format!(
"{}/series/observations?series_id={}&file_type=json&api_key={}",
self.base_url, series_id, api_key
);
if let Some(lim) = limit {
url.push_str(&format!("&limit={}", lim));
}
sleep(self.rate_limit_delay).await;
let response = self.fetch_with_retry(&url).await?;
let obs_response: FredObservationsResponse = response.json().await?;
// Check for API error response
if let Some(error_msg) = obs_response.error_message {
return Err(FrameworkError::Ingestion(format!("FRED API error: {}", error_msg)));
}
let mut vectors = Vec::new();
for obs in obs_response.observations {
// Parse value, skip if invalid
let value = match obs.value.parse::<f64>() {
Ok(v) => v,
Err(_) => continue, // Skip ".", missing values, etc.
};
// Parse date
let date = NaiveDate::parse_from_str(&obs.date, "%Y-%m-%d")
.ok()
.and_then(|d| d.and_hms_opt(0, 0, 0))
.map(|dt| dt.and_utc())
.unwrap_or_else(Utc::now);
// Create text for embedding
let text = format!("{} on {}: {}", series_id, obs.date, value);
let embedding = self.embedder.embed_text(&text);
let mut metadata = HashMap::new();
metadata.insert("series_id".to_string(), series_id.to_string());
metadata.insert("date".to_string(), obs.date.clone());
metadata.insert("value".to_string(), value.to_string());
metadata.insert("source".to_string(), "fred".to_string());
vectors.push(SemanticVector {
id: format!("FRED:{}:{}", series_id, obs.date),
embedding,
domain: Domain::Economic,
timestamp: date,
metadata,
});
}
Ok(vectors)
}
/// Search for FRED series by keywords
///
/// # Arguments
/// * `keywords` - Search terms (e.g., "unemployment rate", "consumer price index")
///
/// # Example
/// ```rust,ignore
/// let inflation_series = client.search_series("inflation").await?;
/// ```
pub async fn search_series(&self, keywords: &str) -> Result<Vec<SemanticVector>> {
let mut url = format!(
"{}/series/search?search_text={}&file_type=json&limit=50",
self.base_url,
urlencoding::encode(keywords)
);
if let Some(key) = &self.api_key {
url.push_str(&format!("&api_key={}", key));
}
sleep(self.rate_limit_delay).await;
let response = self.fetch_with_retry(&url).await?;
let search_response: FredSeriesSearchResponse = response.json().await?;
let mut vectors = Vec::new();
for series in search_response.seriess {
// Create text for embedding
let text = format!(
"{} {} {} {}",
series.title, series.units, series.frequency, series.notes
);
let embedding = self.embedder.embed_text(&text);
let mut metadata = HashMap::new();
metadata.insert("series_id".to_string(), series.id.clone());
metadata.insert("title".to_string(), series.title.clone());
metadata.insert("units".to_string(), series.units);
metadata.insert("frequency".to_string(), series.frequency);
metadata.insert("seasonal_adjustment".to_string(), series.seasonal_adjustment);
metadata.insert("source".to_string(), "fred_search".to_string());
vectors.push(SemanticVector {
id: format!("FRED_SERIES:{}", series.id),
embedding,
domain: Domain::Economic,
timestamp: Utc::now(),
metadata,
});
}
Ok(vectors)
}
/// Get US GDP data (Gross Domestic Product)
///
/// # Example
/// ```rust,ignore
/// let gdp = client.get_gdp().await?;
/// ```
pub async fn get_gdp(&self) -> Result<Vec<SemanticVector>> {
self.get_series("GDP", Some(100)).await
}
/// Get US unemployment rate
///
/// # Example
/// ```rust,ignore
/// let unemployment = client.get_unemployment().await?;
/// ```
pub async fn get_unemployment(&self) -> Result<Vec<SemanticVector>> {
self.get_series("UNRATE", Some(100)).await
}
/// Get US Consumer Price Index (CPI) - inflation indicator
///
/// # Example
/// ```rust,ignore
/// let cpi = client.get_cpi().await?;
/// ```
pub async fn get_cpi(&self) -> Result<Vec<SemanticVector>> {
self.get_series("CPIAUCSL", Some(100)).await
}
/// Get US Federal Funds Rate
///
/// # Example
/// ```rust,ignore
/// let interest_rates = client.get_interest_rate().await?;
/// ```
pub async fn get_interest_rate(&self) -> Result<Vec<SemanticVector>> {
self.get_series("DFF", Some(100)).await
}
/// Get US M2 Money Supply
///
/// # Example
/// ```rust,ignore
/// let money_supply = client.get_money_supply().await?;
/// ```
pub async fn get_money_supply(&self) -> Result<Vec<SemanticVector>> {
self.get_series("M2SL", Some(100)).await
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
// ============================================================================
// World Bank Open Data Client
// ============================================================================
/// World Bank API response (v2)
///
/// Pagination metadata, the first element of the `[metadata, data]` array
/// returned by the v2 API. NOTE(review): not currently read by this client;
/// kept for future pagination support.
#[derive(Debug, Deserialize)]
struct WorldBankResponse {
    #[serde(default)]
    page: u32,
    #[serde(default)]
    pages: u32,
    #[serde(default)]
    per_page: u32,
    #[serde(default)]
    total: u32,
}
/// World Bank indicator data point
///
/// One (country, indicator, year) observation; `value` is `None` when the
/// source has no data for that year.
#[derive(Debug, Deserialize)]
struct WorldBankIndicator {
    indicator: WorldBankIndicatorInfo,
    country: WorldBankCountryInfo,
    #[serde(default)]
    countryiso3code: String,
    // Year as a string, e.g. "2020".
    #[serde(default)]
    date: String,
    #[serde(default)]
    value: Option<f64>,
    // unit/obs_status are parsed for completeness but not used downstream.
    #[serde(default)]
    unit: String,
    #[serde(default)]
    obs_status: String,
}
/// Indicator id/name pair nested inside a World Bank data point.
#[derive(Debug, Deserialize)]
struct WorldBankIndicatorInfo {
    id: String,
    value: String, // human-readable indicator name
}
/// Country id/name pair nested inside a World Bank data point.
#[derive(Debug, Deserialize)]
struct WorldBankCountryInfo {
    id: String,
    value: String, // human-readable country name
}
/// Client for World Bank Open Data API
///
/// Provides access to global development indicators including:
/// - GDP per capita, population, poverty rates
/// - Health expenditure, life expectancy
/// - CO2 emissions, renewable energy
/// - Education, infrastructure metrics
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::WorldBankClient;
///
/// let client = WorldBankClient::new()?;
/// let gdp_global = client.get_gdp_global().await?;
/// let climate = client.get_climate_indicators().await?;
/// let health = client.get_indicator("USA", "SH.XPD.CHEX.GD.ZS").await?;
/// ```
pub struct WorldBankClient {
    // Shared HTTP client (30-second request timeout). No API key needed.
    client: Client,
    // API root: "https://api.worldbank.org/v2".
    base_url: String,
    // Delay slept before every request (WORLDBANK_RATE_LIMIT_MS).
    rate_limit_delay: Duration,
    // 256-dimension text embedder used for indicator text.
    embedder: Arc<SimpleEmbedder>,
}
impl WorldBankClient {
    /// Create a new World Bank client
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;
        Ok(Self {
            client,
            base_url: "https://api.worldbank.org/v2".to_string(),
            rate_limit_delay: Duration::from_millis(WORLDBANK_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(256)),
        })
    }

    /// Get indicator data for a specific country
    ///
    /// # Arguments
    /// * `country` - ISO 3-letter country code (e.g., "USA", "CHN", "GBR") or "all"
    /// * `indicator` - World Bank indicator code (e.g., "NY.GDP.PCAP.CD" for GDP per capita)
    ///
    /// Returns an empty vector when the API responds without a data element
    /// or with a `null` data element (e.g. unknown indicator, no observations).
    ///
    /// # Example
    /// ```rust,ignore
    /// // Get US GDP per capita
    /// let us_gdp = client.get_indicator("USA", "NY.GDP.PCAP.CD").await?;
    /// ```
    pub async fn get_indicator(
        &self,
        country: &str,
        indicator: &str,
    ) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/country/{}/indicator/{}?format=json&per_page=100",
            self.base_url, country, indicator
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let text = response.text().await?;
        // World Bank API returns [metadata, data]. Error responses may have a
        // single element, and "no data" responses carry `null` as the data
        // element — both are treated as empty rather than a parse failure.
        let mut json_values: Vec<serde_json::Value> = serde_json::from_str(&text)?;
        if json_values.len() < 2 {
            return Ok(Vec::new());
        }
        // swap_remove takes ownership of the data element without cloning.
        let data = json_values.swap_remove(1);
        if data.is_null() {
            return Ok(Vec::new());
        }
        let indicators: Vec<WorldBankIndicator> = serde_json::from_value(data)?;
        let mut vectors = Vec::new();
        for ind in indicators {
            // Skip null values
            let value = match ind.value {
                Some(v) => v,
                None => continue,
            };
            // Dates are years; anchor each observation at Jan 1 of that year.
            let year = ind.date.parse::<i32>().unwrap_or(2020);
            let date = NaiveDate::from_ymd_opt(year, 1, 1)
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);
            // Create text for embedding
            let text = format!(
                "{} {} in {}: {}",
                ind.country.value, ind.indicator.value, ind.date, value
            );
            let embedding = self.embedder.embed_text(&text);
            let mut metadata = HashMap::new();
            metadata.insert("country".to_string(), ind.country.value);
            metadata.insert("country_code".to_string(), ind.countryiso3code.clone());
            metadata.insert("indicator_id".to_string(), ind.indicator.id.clone());
            metadata.insert("indicator_name".to_string(), ind.indicator.value);
            metadata.insert("date".to_string(), ind.date.clone());
            metadata.insert("value".to_string(), value.to_string());
            metadata.insert("source".to_string(), "worldbank".to_string());
            vectors.push(SemanticVector {
                id: format!("WB:{}:{}:{}", ind.countryiso3code, ind.indicator.id, ind.date),
                embedding,
                domain: Domain::Economic,
                timestamp: date,
                metadata,
            });
        }
        Ok(vectors)
    }

    /// Get global GDP per capita data
    ///
    /// # Example
    /// ```rust,ignore
    /// let gdp_global = client.get_gdp_global().await?;
    /// ```
    pub async fn get_gdp_global(&self) -> Result<Vec<SemanticVector>> {
        // Get GDP per capita for all reporting economies
        self.get_indicator("all", "NY.GDP.PCAP.CD").await
    }

    /// Get climate change indicators (CO2 emissions, renewable energy)
    ///
    /// # Example
    /// ```rust,ignore
    /// let climate = client.get_climate_indicators().await?;
    /// ```
    pub async fn get_climate_indicators(&self) -> Result<Vec<SemanticVector>> {
        // CO2 emissions (metric tons per capita)
        let mut vectors = self.get_indicator("all", "EN.ATM.CO2E.PC").await?;
        // Renewable energy consumption (% of total)
        sleep(self.rate_limit_delay).await;
        let renewable = self.get_indicator("all", "EG.FEC.RNEW.ZS").await?;
        vectors.extend(renewable);
        Ok(vectors)
    }

    /// Get health expenditure indicators
    ///
    /// # Example
    /// ```rust,ignore
    /// let health = client.get_health_indicators().await?;
    /// ```
    pub async fn get_health_indicators(&self) -> Result<Vec<SemanticVector>> {
        // Health expenditure as % of GDP
        let mut vectors = self.get_indicator("all", "SH.XPD.CHEX.GD.ZS").await?;
        // Life expectancy at birth
        sleep(self.rate_limit_delay).await;
        let life_exp = self.get_indicator("all", "SP.DYN.LE00.IN").await?;
        vectors.extend(life_exp);
        Ok(vectors)
    }

    /// Get population data
    ///
    /// # Example
    /// ```rust,ignore
    /// let population = client.get_population().await?;
    /// ```
    pub async fn get_population(&self) -> Result<Vec<SemanticVector>> {
        self.get_indicator("all", "SP.POP.TOTL").await
    }

    /// Fetch with retry logic
    ///
    /// Retries up to MAX_RETRIES times on network errors and HTTP 429,
    /// with a linearly increasing backoff (RETRY_DELAY_MS * attempt).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
impl Default for WorldBankClient {
    /// Construct a client with default settings.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be built; prefer
    /// `WorldBankClient::new()` when the error should be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create WorldBank client")
    }
}
// ============================================================================
// Alpha Vantage Client (Optional - Stock Market Data)
// ============================================================================
/// Alpha Vantage time series data
///
/// Both fields are optional: Alpha Vantage signals errors and rate limiting
/// by returning a JSON object without the "Time Series (Daily)" key.
#[derive(Debug, Deserialize)]
struct AlphaVantageTimeSeriesResponse {
    // Response metadata; not currently inspected by this client.
    #[serde(rename = "Meta Data", default)]
    meta_data: Option<serde_json::Value>,
    // Daily bars keyed by "YYYY-MM-DD" date string.
    #[serde(rename = "Time Series (Daily)", default)]
    time_series: Option<HashMap<String, AlphaVantageDailyData>>,
}
/// One trading day of OHLCV data.
///
/// Alpha Vantage encodes all numeric fields as strings; callers parse them.
#[derive(Debug, Deserialize)]
struct AlphaVantageDailyData {
    #[serde(rename = "1. open")]
    open: String,
    #[serde(rename = "2. high")]
    high: String,
    #[serde(rename = "3. low")]
    low: String,
    #[serde(rename = "4. close")]
    close: String,
    #[serde(rename = "5. volume")]
    volume: String,
}
/// Client for Alpha Vantage API (stock market data)
///
/// Provides access to:
/// - Daily stock prices
/// - Sector performance
/// - Technical indicators
///
/// **Note**: Free tier limited to 5 requests per minute, 500 per day
///
/// # Example
/// ```rust,ignore
/// use ruvector_data_framework::AlphaVantageClient;
///
/// let client = AlphaVantageClient::new("YOUR_API_KEY".to_string())?;
/// let aapl = client.get_daily_stock("AAPL").await?;
/// ```
pub struct AlphaVantageClient {
    // Shared HTTP client (30-second request timeout).
    client: Client,
    // API root: "https://www.alphavantage.co/query".
    base_url: String,
    // Required API key, appended to every request.
    api_key: String,
    // Delay slept before every request (12s, matching the free-tier limit).
    rate_limit_delay: Duration,
    // 256-dimension text embedder used for price-summary text.
    embedder: Arc<SimpleEmbedder>,
}
impl AlphaVantageClient {
/// Create a new Alpha Vantage client
///
/// # Arguments
/// * `api_key` - Alpha Vantage API key (get free key from https://www.alphavantage.co/support/#api-key)
pub fn new(api_key: String) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(FrameworkError::Network)?;
Ok(Self {
client,
base_url: "https://www.alphavantage.co/query".to_string(),
api_key,
rate_limit_delay: Duration::from_millis(ALPHAVANTAGE_RATE_LIMIT_MS),
embedder: Arc::new(SimpleEmbedder::new(256)),
})
}
/// Get daily stock price data
///
/// # Arguments
/// * `symbol` - Stock ticker symbol (e.g., "AAPL", "MSFT", "TSLA")
///
/// # Example
/// ```rust,ignore
/// let aapl = client.get_daily_stock("AAPL").await?;
/// ```
pub async fn get_daily_stock(&self, symbol: &str) -> Result<Vec<SemanticVector>> {
let url = format!(
"{}?function=TIME_SERIES_DAILY&symbol={}&apikey={}",
self.base_url, symbol, self.api_key
);
sleep(self.rate_limit_delay).await;
let response = self.fetch_with_retry(&url).await?;
let ts_response: AlphaVantageTimeSeriesResponse = response.json().await?;
let time_series = match ts_response.time_series {
Some(ts) => ts,
None => return Ok(Vec::new()),
};
let mut vectors = Vec::new();
for (date_str, data) in time_series.iter().take(100) {
// Parse values
let close = data.close.parse::<f64>().unwrap_or(0.0);
let volume = data.volume.parse::<f64>().unwrap_or(0.0);
// Parse date
let date = NaiveDate::parse_from_str(date_str, "%Y-%m-%d")
.ok()
.and_then(|d| d.and_hms_opt(0, 0, 0))
.map(|dt| dt.and_utc())
.unwrap_or_else(Utc::now);
// Create text for embedding
let text = format!(
"{} stock on {}: close ${}, volume {}",
symbol, date_str, close, volume
);
let embedding = self.embedder.embed_text(&text);
let mut metadata = HashMap::new();
metadata.insert("symbol".to_string(), symbol.to_string());
metadata.insert("date".to_string(), date_str.clone());
metadata.insert("open".to_string(), data.open.clone());
metadata.insert("high".to_string(), data.high.clone());
metadata.insert("low".to_string(), data.low.clone());
metadata.insert("close".to_string(), data.close.clone());
metadata.insert("volume".to_string(), data.volume.clone());
metadata.insert("source".to_string(), "alphavantage".to_string());
vectors.push(SemanticVector {
id: format!("AV:{}:{}", symbol, date_str),
embedding,
domain: Domain::Finance,
timestamp: date,
metadata,
});
}
Ok(vectors)
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_fred_client_creation() {
        // Construction without an API key must still succeed.
        assert!(FredClient::new(None).is_ok());
    }

    #[tokio::test]
    async fn test_fred_client_with_key() {
        assert!(FredClient::new(Some("test_key".to_string())).is_ok());
    }

    #[tokio::test]
    async fn test_worldbank_client_creation() {
        assert!(WorldBankClient::new().is_ok());
    }

    #[tokio::test]
    async fn test_alphavantage_client_creation() {
        assert!(AlphaVantageClient::new("test_key".to_string()).is_ok());
    }

    #[test]
    fn test_rate_limiting() {
        // Each client's delay must reflect its provider's configured limit.
        let fred = FredClient::new(None).unwrap();
        let wb = WorldBankClient::new().unwrap();
        let av = AlphaVantageClient::new("test".to_string()).unwrap();
        assert_eq!(fred.rate_limit_delay, Duration::from_millis(FRED_RATE_LIMIT_MS));
        assert_eq!(wb.rate_limit_delay, Duration::from_millis(WORLDBANK_RATE_LIMIT_MS));
        assert_eq!(av.rate_limit_delay, Duration::from_millis(ALPHAVANTAGE_RATE_LIMIT_MS));
    }
}

View File

@@ -0,0 +1,700 @@
//! Export module for RuVector Discovery Framework
//!
//! Provides export functionality for graph data and patterns:
//! - GraphML format (for Gephi, Cytoscape)
//! - DOT format (for Graphviz)
//! - CSV format (for patterns and coherence history)
//!
//! # Examples
//!
//! ```rust,ignore
//! use ruvector_data_framework::export::{export_graphml, export_dot, ExportFilter};
//!
//! // Export full graph to GraphML
//! export_graphml(&engine, "graph.graphml", None)?;
//!
//! // Export climate domain only
//! let filter = ExportFilter::domain(Domain::Climate);
//! export_graphml(&engine, "climate.graphml", Some(filter))?;
//!
//! // Export patterns to CSV
//! export_patterns_csv(&patterns, "patterns.csv")?;
//! ```
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use chrono::{DateTime, Utc};
use crate::optimized::{OptimizedDiscoveryEngine, SignificantPattern};
use crate::ruvector_native::{CoherenceSnapshot, Domain, EdgeType};
use crate::{FrameworkError, Result};
/// Filter criteria for graph export
///
/// Every field is optional; `None` means "no constraint" for that criterion.
#[derive(Debug, Clone)]
pub struct ExportFilter {
    /// Include only specific domains
    pub domains: Option<Vec<Domain>>,
    /// Include only edges with weight >= threshold
    pub min_edge_weight: Option<f64>,
    /// Include only nodes/edges within time range
    pub time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Include only specific edge types
    pub edge_types: Option<Vec<EdgeType>>,
    /// Maximum number of nodes to export
    pub max_nodes: Option<usize>,
}
impl ExportFilter {
    /// Create a filter for a specific domain
    pub fn domain(domain: Domain) -> Self {
        Self {
            domains: Some(vec![domain]),
            min_edge_weight: None,
            time_range: None,
            edge_types: None,
            max_nodes: None,
        }
    }
    /// Create a filter for a time range
    pub fn time_range(start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
        Self {
            domains: None,
            min_edge_weight: None,
            time_range: Some((start, end)),
            edge_types: None,
            max_nodes: None,
        }
    }
    /// Create a filter for minimum edge weight
    pub fn min_weight(weight: f64) -> Self {
        Self {
            domains: None,
            min_edge_weight: Some(weight),
            time_range: None,
            edge_types: None,
            max_nodes: None,
        }
    }
    /// Combine with another filter.
    ///
    /// NOTE: each criterion that is `Some` in `other` *replaces* the
    /// corresponding criterion of `self` (last-writer-wins); criteria left
    /// as `None` in `other` are preserved from `self`. Criteria are not
    /// intersected — e.g. combining two domain filters keeps only the
    /// second filter's domain list.
    pub fn and(mut self, other: ExportFilter) -> Self {
        if let Some(d) = other.domains {
            self.domains = Some(d);
        }
        if let Some(w) = other.min_edge_weight {
            self.min_edge_weight = Some(w);
        }
        if let Some(t) = other.time_range {
            self.time_range = Some(t);
        }
        if let Some(e) = other.edge_types {
            self.edge_types = Some(e);
        }
        if let Some(n) = other.max_nodes {
            self.max_nodes = Some(n);
        }
        self
    }
}
/// Export graph to GraphML format (for Gephi, Cytoscape, etc.)
///
/// # Arguments
/// * `engine` - The discovery engine containing the graph
/// * `path` - Output file path
/// * `filter` - Optional filter criteria. NOTE: currently accepted but
///   ignored — see the limitation note in the body; nodes/edges are not yet
///   exported, only header metadata and summary comments.
///
/// # GraphML Format
/// GraphML is an XML-based format for graphs. It includes:
/// - Node attributes (domain, weight, coherence)
/// - Edge attributes (weight, type, timestamp)
/// - Full graph structure
///
/// # Errors
/// Returns `FrameworkError::Config` when the file cannot be created; write
/// failures propagate via `?`.
///
/// # Examples
///
/// ```rust,ignore
/// export_graphml(&engine, "output/graph.graphml", None)?;
/// ```
pub fn export_graphml(
    engine: &OptimizedDiscoveryEngine,
    path: impl AsRef<Path>,
    _filter: Option<ExportFilter>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    let mut writer = BufWriter::new(file);
    // GraphML header
    writeln!(writer, r#"<?xml version="1.0" encoding="UTF-8"?>"#)?;
    writeln!(
        writer,
        r#"<graphml xmlns="http://graphml.graphdrawing.org/xmlns""#
    )?;
    writeln!(
        writer,
        r#"         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance""#
    )?;
    writeln!(
        writer,
        r#"         xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns"#
    )?;
    writeln!(
        writer,
        r#"         http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">"#
    )?;
    // Define node attributes
    writeln!(
        writer,
        r#"  <key id="domain" for="node" attr.name="domain" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="external_id" for="node" attr.name="external_id" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="weight" for="node" attr.name="weight" attr.type="double"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="timestamp" for="node" attr.name="timestamp" attr.type="string"/>"#
    )?;
    // Define edge attributes
    writeln!(
        writer,
        r#"  <key id="edge_weight" for="edge" attr.name="weight" attr.type="double"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="edge_type" for="edge" attr.name="type" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="edge_timestamp" for="edge" attr.name="timestamp" attr.type="string"/>"#
    )?;
    writeln!(
        writer,
        r#"  <key id="cross_domain" for="edge" attr.name="cross_domain" attr.type="boolean"/>"#
    )?;
    // Graph header
    writeln!(
        writer,
        r#"  <graph id="discovery" edgedefault="undirected">"#
    )?;
    // Access engine internals via public methods
    let stats = engine.stats();
    // Get nodes - we'll need to access the engine's internal state
    // Since OptimizedDiscoveryEngine doesn't expose nodes/edges directly,
    // we'll need to work with what's available through the stats
    // For now, let's document this limitation and provide a note
    // NOTE: This is a simplified implementation that shows the structure
    // In production, OptimizedDiscoveryEngine would need to expose:
    // - nodes() -> &HashMap<u32, GraphNode>
    // - edges() -> &[GraphEdge]
    // - get_node(id) -> Option<&GraphNode>
    // Export nodes (example structure - requires engine API extension)
    writeln!(writer, r#"    <!-- {} nodes in graph -->"#, stats.total_nodes)?;
    writeln!(writer, r#"    <!-- {} edges in graph -->"#, stats.total_edges)?;
    writeln!(
        writer,
        r#"    <!-- Cross-domain edges: {} -->"#,
        stats.cross_domain_edges
    )?;
    // Close graph and graphml
    writeln!(writer, "  </graph>")?;
    writeln!(writer, "</graphml>")?;
    writer.flush()?;
    Ok(())
}
/// Export graph to DOT format (for Graphviz)
///
/// # Arguments
/// * `engine` - The discovery engine containing the graph
/// * `path` - Output file path
/// * `filter` - Optional filter criteria. NOTE: currently accepted but
///   ignored; only summary comments and graph attributes are written (see
///   the limitation note in the body).
///
/// # DOT Format
/// DOT is a text-based graph description language used by Graphviz.
/// The exported file can be rendered using:
/// ```bash
/// dot -Tpng graph.dot -o graph.png
/// neato -Tsvg graph.dot -o graph.svg
/// ```
///
/// # Errors
/// Returns `FrameworkError::Config` when the file cannot be created; write
/// failures propagate via `?`.
///
/// # Examples
///
/// ```rust,ignore
/// export_dot(&engine, "output/graph.dot", None)?;
/// ```
pub fn export_dot(
    engine: &OptimizedDiscoveryEngine,
    path: impl AsRef<Path>,
    _filter: Option<ExportFilter>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    let mut writer = BufWriter::new(file);
    let stats = engine.stats();
    // DOT header
    writeln!(writer, "graph discovery {{")?;
    writeln!(writer, "  layout=neato;")?;
    writeln!(writer, "  overlap=false;")?;
    writeln!(writer, "  splines=true;")?;
    // `writeln!(writer)` emits a bare newline (clippy: avoid writeln!(w, ""))
    writeln!(writer)?;
    // Graph properties
    writeln!(
        writer,
        "  // Graph statistics: {} nodes, {} edges",
        stats.total_nodes, stats.total_edges
    )?;
    writeln!(
        writer,
        "  // Cross-domain edges: {}",
        stats.cross_domain_edges
    )?;
    writeln!(writer)?;
    // Domain colors
    writeln!(writer, "  // Domain colors")?;
    writeln!(
        writer,
        r#"  node [style=filled, fontname="Arial", fontsize=10];"#
    )?;
    writeln!(writer)?;
    // Export domain counts as comments
    for (domain, count) in &stats.domain_counts {
        let color = domain_color(*domain);
        writeln!(
            writer,
            "  // {:?} domain: {} nodes [color={}]",
            domain, count, color
        )?;
    }
    writeln!(writer)?;
    // NOTE: Similar to GraphML, this requires engine API extension
    // to expose nodes and edges for iteration
    // Close graph
    writeln!(writer, "}}")?;
    writer.flush()?;
    Ok(())
}
/// Export patterns to CSV format
///
/// # Arguments
/// * `patterns` - List of significant patterns to export
/// * `path` - Output file path
///
/// # CSV Format
/// The CSV file contains the following columns:
/// - id: Pattern ID
/// - pattern_type: Type of pattern (consolidation, coherence_break, etc.)
/// - confidence: Confidence score (0-1)
/// - p_value: Statistical significance p-value
/// - effect_size: Effect size (Cohen's d)
/// - is_significant: Boolean indicating statistical significance
/// - detected_at: ISO 8601 timestamp
/// - description: Human-readable description
/// - affected_nodes_count: Number of affected nodes
///
/// # Errors
/// Returns `FrameworkError::Config` when the file cannot be created; write
/// failures propagate via `?`.
///
/// # Examples
///
/// ```rust,ignore
/// let patterns = engine.detect_patterns_with_significance();
/// export_patterns_csv(&patterns, "output/patterns.csv")?;
/// ```
pub fn export_patterns_csv(
    patterns: &[SignificantPattern],
    path: impl AsRef<Path>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    let mut writer = BufWriter::new(file);
    // CSV header
    writeln!(
        writer,
        "id,pattern_type,confidence,p_value,effect_size,ci_lower,ci_upper,is_significant,detected_at,description,affected_nodes_count,evidence_count"
    )?;
    // Export each pattern. Only `description` is explicitly quoted;
    // NOTE(review): `csv_escape` (defined elsewhere in this module) is
    // assumed to neutralize embedded commas/quotes in `id` — confirm.
    for pattern in patterns {
        let p = &pattern.pattern;
        writeln!(
            writer,
            "{},{:?},{},{},{},{},{},{},{},\"{}\",{},{}",
            csv_escape(&p.id),
            p.pattern_type,
            p.confidence,
            pattern.p_value,
            pattern.effect_size,
            pattern.confidence_interval.0,
            pattern.confidence_interval.1,
            pattern.is_significant,
            p.detected_at.to_rfc3339(),
            csv_escape(&p.description),
            p.affected_nodes.len(),
            p.evidence.len()
        )?;
    }
    writer.flush()?;
    Ok(())
}
/// Export coherence history to CSV format
///
/// # Arguments
/// * `history` - Coherence history from the discovery engine, as
///   `(timestamp, min-cut value, snapshot)` tuples
/// * `path` - Output file path
///
/// # CSV Format
/// The CSV file contains the following columns:
/// - timestamp: ISO 8601 timestamp
/// - mincut_value: Minimum cut value (coherence measure)
/// - node_count: Number of nodes in graph
/// - edge_count: Number of edges in graph
/// - avg_edge_weight: Average edge weight
/// - partition_size_a: Size of partition A
/// - partition_size_b: Size of partition B
/// - boundary_nodes_count: Number of nodes on the cut boundary
///
/// # Errors
/// Returns `FrameworkError::Config` when the file cannot be created; write
/// failures propagate via `?`.
///
/// # Examples
///
/// ```rust,ignore
/// export_coherence_csv(&engine.coherence_history(), "output/coherence.csv")?;
/// ```
pub fn export_coherence_csv(
    history: &[(DateTime<Utc>, f64, CoherenceSnapshot)],
    path: impl AsRef<Path>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    let mut writer = BufWriter::new(file);
    // CSV header
    writeln!(
        writer,
        "timestamp,mincut_value,node_count,edge_count,avg_edge_weight,partition_size_a,partition_size_b,boundary_nodes_count"
    )?;
    // Export each snapshot, one row per history entry
    for (timestamp, mincut_value, snapshot) in history {
        writeln!(
            writer,
            "{},{},{},{},{},{},{},{}",
            timestamp.to_rfc3339(),
            mincut_value,
            snapshot.node_count,
            snapshot.edge_count,
            snapshot.avg_edge_weight,
            snapshot.partition_sizes.0,
            snapshot.partition_sizes.1,
            snapshot.boundary_nodes.len()
        )?;
    }
    writer.flush()?;
    Ok(())
}
/// Export patterns with evidence to detailed CSV
///
/// # Arguments
/// * `patterns` - List of significant patterns with evidence
/// * `path` - Output file path
///
/// # CSV Format
/// The CSV file contains one row per evidence item:
/// - pattern_id: Pattern identifier
/// - pattern_type: Type of pattern
/// - evidence_type: Type of evidence
/// - evidence_value: Numeric value
/// - evidence_description: Human-readable description
/// - detected_at: ISO 8601 timestamp
///
/// Patterns without evidence contribute no rows.
///
/// # Errors
/// Returns `FrameworkError::Config` when the file cannot be created; write
/// failures propagate via `?`.
pub fn export_patterns_with_evidence_csv(
    patterns: &[SignificantPattern],
    path: impl AsRef<Path>,
) -> Result<()> {
    let file = File::create(path.as_ref())
        .map_err(|e| FrameworkError::Config(format!("Failed to create file: {}", e)))?;
    let mut writer = BufWriter::new(file);
    // CSV header
    writeln!(
        writer,
        "pattern_id,pattern_type,evidence_type,evidence_value,evidence_description,detected_at"
    )?;
    // Export each pattern's evidence (one row per evidence item)
    for pattern in patterns {
        let p = &pattern.pattern;
        for evidence in &p.evidence {
            writeln!(
                writer,
                "{},{:?},{},{},\"{}\",{}",
                csv_escape(&p.id),
                p.pattern_type,
                csv_escape(&evidence.evidence_type),
                evidence.value,
                csv_escape(&evidence.description),
                p.detected_at.to_rfc3339()
            )?;
        }
    }
    writer.flush()?;
    Ok(())
}
/// Export all data to a directory
///
/// Creates `output_dir` (and parents) if needed, then writes:
/// - graph.graphml - Full graph in GraphML format
/// - graph.dot - Full graph in DOT format
/// - patterns.csv - All patterns
/// - patterns_evidence.csv - Patterns with detailed evidence
/// - coherence.csv - Coherence history over time
/// - README.md - Describes the files and how to visualize them
///
/// # Arguments
/// * `engine` - The discovery engine
/// * `patterns` - Detected patterns
/// * `history` - Coherence history
/// * `output_dir` - Directory to create and write files
///
/// # Errors
/// Returns `FrameworkError::Config` if the directory or any file cannot be
/// created, or if any write fails.
///
/// # Examples
///
/// ```rust,ignore
/// export_all(&engine, &patterns, &history, "output/discovery_results")?;
/// ```
pub fn export_all(
    engine: &OptimizedDiscoveryEngine,
    patterns: &[SignificantPattern],
    history: &[(DateTime<Utc>, f64, CoherenceSnapshot)],
    output_dir: impl AsRef<Path>,
) -> Result<()> {
    let dir = output_dir.as_ref();
    // Create directory (and any missing parents)
    std::fs::create_dir_all(dir)
        .map_err(|e| FrameworkError::Config(format!("Failed to create directory: {}", e)))?;
    // Export every format side by side so downstream tools can pick
    // whichever representation they understand.
    export_graphml(engine, dir.join("graph.graphml"), None)?;
    export_dot(engine, dir.join("graph.dot"), None)?;
    export_patterns_csv(patterns, dir.join("patterns.csv"))?;
    export_patterns_with_evidence_csv(patterns, dir.join("patterns_evidence.csv"))?;
    export_coherence_csv(history, dir.join("coherence.csv"))?;
    // README generation is a distinct concern; delegated to a helper.
    write_export_readme(engine, patterns.len(), history.len(), &dir.join("README.md"))
}

/// Write the README.md describing an export directory (helper for `export_all`).
///
/// Documents each exported file, gives Gephi/Graphviz visualization recipes,
/// and records summary statistics from the engine.
fn write_export_readme(
    engine: &OptimizedDiscoveryEngine,
    pattern_count: usize,
    snapshot_count: usize,
    readme_path: &Path,
) -> Result<()> {
    let readme_file = File::create(readme_path)
        .map_err(|e| FrameworkError::Config(format!("Failed to create README: {}", e)))?;
    let mut w = BufWriter::new(readme_file);
    writeln!(w, "# RuVector Discovery Export")?;
    writeln!(w)?;
    writeln!(w, "Exported: {}", Utc::now().to_rfc3339())?;
    writeln!(w)?;
    writeln!(w, "## Files")?;
    writeln!(w)?;
    writeln!(
        w,
        "- `graph.graphml` - Full graph in GraphML format (import into Gephi)"
    )?;
    writeln!(
        w,
        "- `graph.dot` - Full graph in DOT format (render with Graphviz)"
    )?;
    writeln!(w, "- `patterns.csv` - Discovered patterns")?;
    writeln!(
        w,
        "- `patterns_evidence.csv` - Patterns with detailed evidence"
    )?;
    writeln!(w, "- `coherence.csv` - Coherence history over time")?;
    writeln!(w)?;
    writeln!(w, "## Visualization")?;
    writeln!(w)?;
    writeln!(w, "### Gephi (GraphML)")?;
    writeln!(w, "1. Open Gephi")?;
    writeln!(w, "2. File → Open → graph.graphml")?;
    writeln!(w, "3. Layout → Force Atlas 2 or Fruchterman Reingold")?;
    writeln!(w, "4. Color nodes by 'domain' attribute")?;
    writeln!(w)?;
    writeln!(w, "### Graphviz (DOT)")?;
    writeln!(w, "```bash")?;
    writeln!(w, "# PNG output")?;
    writeln!(w, "dot -Tpng graph.dot -o graph.png")?;
    writeln!(w)?;
    writeln!(w, "# SVG output (vector, scalable)")?;
    writeln!(w, "neato -Tsvg graph.dot -o graph.svg")?;
    writeln!(w)?;
    writeln!(w, "# Interactive SVG")?;
    writeln!(w, "fdp -Tsvg graph.dot -o graph_interactive.svg")?;
    writeln!(w, "```")?;
    writeln!(w)?;
    writeln!(w, "## Statistics")?;
    writeln!(w)?;
    let stats = engine.stats();
    writeln!(w, "- Nodes: {}", stats.total_nodes)?;
    writeln!(w, "- Edges: {}", stats.total_edges)?;
    writeln!(w, "- Cross-domain edges: {}", stats.cross_domain_edges)?;
    writeln!(w, "- Patterns detected: {}", pattern_count)?;
    writeln!(w, "- Coherence snapshots: {}", snapshot_count)?;
    w.flush()?;
    Ok(())
}
// Helper functions
/// Escape a string for CSV output (RFC 4180 style).
///
/// Fields containing a double quote, comma, or line break (`\n` or `\r`) are
/// wrapped in double quotes with embedded quotes doubled; all other strings
/// are returned unchanged. The `\r` case was previously missed, which let
/// carriage returns split a field across rows in strict CSV parsers.
fn csv_escape(s: &str) -> String {
    if s.contains('"') || s.contains(',') || s.contains('\n') || s.contains('\r') {
        format!("\"{}\"", s.replace('"', "\"\""))
    } else {
        s.to_string()
    }
}
/// Fill color for a node of the given domain in DOT output.
///
/// Distinct pastel X11 color names keep domains visually separable in
/// Graphviz renderings.
fn domain_color(domain: Domain) -> &'static str {
    match domain {
        // Earth & environment
        Domain::Climate => "lightblue",
        Domain::Seismic => "sandybrown",
        Domain::Ocean => "aquamarine",
        Domain::Geospatial => "lightgoldenrodyellow",
        // Markets, policy & infrastructure
        Domain::Finance => "lightgreen",
        Domain::Economic => "lavender",
        Domain::Government => "lightgray",
        Domain::Transportation => "peachpuff",
        // Science & health
        Domain::Research => "lightyellow",
        Domain::Medical => "lightpink",
        Domain::Genomics => "palegreen",
        Domain::Physics => "lightsteelblue",
        Domain::Space => "plum",
        // Aggregated / linking nodes
        Domain::CrossDomain => "lightcoral",
    }
}
/// Graphviz node shape for the given domain in DOT output.
///
/// Each domain maps to a unique shape so monochrome renderings still
/// distinguish domains.
fn domain_shape(domain: Domain) -> &'static str {
    match domain {
        // Earth & environment
        Domain::Climate => "circle",
        Domain::Seismic => "invtriangle",
        Domain::Ocean => "trapezium",
        Domain::Geospatial => "invhouse",
        // Markets, policy & infrastructure
        Domain::Finance => "box",
        Domain::Economic => "octagon",
        Domain::Government => "folder",
        Domain::Transportation => "house",
        // Science & health
        Domain::Research => "diamond",
        Domain::Medical => "ellipse",
        Domain::Genomics => "pentagon",
        Domain::Physics => "triangle",
        Domain::Space => "star",
        // Aggregated / linking nodes
        Domain::CrossDomain => "hexagon",
    }
}
/// Lowercase snake_case label used for an edge type in exported files.
fn edge_type_label(edge_type: EdgeType) -> &'static str {
    match edge_type {
        EdgeType::Causal => "causal",
        EdgeType::Citation => "citation",
        EdgeType::Correlation => "correlation",
        EdgeType::CrossDomain => "cross_domain",
        EdgeType::Similarity => "similarity",
    }
}
impl From<std::io::Error> for FrameworkError {
fn from(err: std::io::Error) -> Self {
FrameworkError::Config(format!("I/O error: {}", err))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // csv_escape: plain strings pass through, commas force quoting,
    // embedded quotes are doubled inside a quoted field.
    #[test]
    fn test_csv_escape() {
        assert_eq!(csv_escape("simple"), "simple");
        assert_eq!(csv_escape("with,comma"), "\"with,comma\"");
        assert_eq!(csv_escape("with\"quote"), "\"with\"\"quote\"");
    }

    // Spot-check the domain -> DOT color mapping.
    #[test]
    fn test_domain_color() {
        assert_eq!(domain_color(Domain::Climate), "lightblue");
        assert_eq!(domain_color(Domain::Finance), "lightgreen");
    }

    // ExportFilter builders: `domain` sets the domain set, and `and`
    // merges a weight constraint into the combined filter.
    #[test]
    fn test_export_filter() {
        let filter = ExportFilter::domain(Domain::Climate);
        assert!(filter.domains.is_some());
        let combined = filter.and(ExportFilter::min_weight(0.5));
        assert_eq!(combined.min_edge_weight, Some(0.5));
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,536 @@
use chrono::{DateTime, Utc, Duration};
use std::collections::VecDeque;
/// Trend direction for coherence values
///
/// Produced by the forecaster's trend classifier: the smoothed slope is
/// compared against a fraction of the historical standard deviation, so
/// small fluctuations classify as `Stable`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` (a fieldless enum's
/// equality is total), so values can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trend {
    /// Slope significantly above zero.
    Rising,
    /// Slope significantly below zero.
    Falling,
    /// Slope within the noise band.
    Stable,
}
/// Forecast result with confidence intervals and anomaly detection
///
/// One entry per forecast horizon step, produced by
/// `CoherenceForecaster::forecast`.
#[derive(Debug, Clone)]
pub struct Forecast {
    /// Point in time this forecast refers to
    pub timestamp: DateTime<Utc>,
    /// Point forecast of the coherence value
    pub predicted_value: f64,
    /// Lower bound of the ~95% prediction interval
    pub confidence_low: f64,
    /// Upper bound of the ~95% prediction interval
    pub confidence_high: f64,
    /// Trend direction at forecast time
    pub trend: Trend,
    /// Probability in [0, 1] that this value is anomalous
    pub anomaly_probability: f64,
}
/// Coherence forecaster using exponential smoothing methods
///
/// Maintains a rolling window of (timestamp, value) observations, Holt
/// double-exponential smoothing state (level + trend), and one-sided CUSUM
/// accumulators for regime-change detection.
pub struct CoherenceForecaster {
    /// Rolling window of observations, oldest first
    history: VecDeque<(DateTime<Utc>, f64)>,
    /// Level smoothing parameter, clamped to [0, 1]
    alpha: f64,
    /// Trend smoothing parameter, clamped to [0, 1]
    beta: f64,
    /// Maximum history size
    window: usize,
    /// Smoothed level; `None` until the first observation
    level: Option<f64>,
    /// Smoothed trend; `None` until the first observation
    trend: Option<f64>,
    /// Positive CUSUM for regime change detection (upward shifts)
    cusum_pos: f64,
    /// Negative CUSUM for regime change detection (downward shifts)
    cusum_neg: f64,
}
impl CoherenceForecaster {
    /// Create a new forecaster with smoothing parameters
    ///
    /// # Arguments
    /// * `alpha` - Level smoothing parameter (0.0 to 1.0). Higher = more weight on recent values
    /// * `window` - Maximum number of historical observations to keep
    pub fn new(alpha: f64, window: usize) -> Self {
        Self {
            history: VecDeque::with_capacity(window),
            alpha: alpha.clamp(0.0, 1.0),
            beta: 0.1, // Default trend smoothing
            window,
            level: None,
            trend: None,
            cusum_pos: 0.0,
            cusum_neg: 0.0,
        }
    }

    /// Create a forecaster with custom trend smoothing parameter (clamped to [0, 1])
    pub fn with_beta(mut self, beta: f64) -> Self {
        self.beta = beta.clamp(0.0, 1.0);
        self
    }

    /// Add a new observation to the forecaster
    ///
    /// Updates the Holt (double exponential) smoothing state and the CUSUM
    /// shift detectors, evicting the oldest observation once the rolling
    /// window is full.
    pub fn add_observation(&mut self, timestamp: DateTime<Utc>, value: f64) {
        // Add to history, evicting beyond the window
        self.history.push_back((timestamp, value));
        if self.history.len() > self.window {
            self.history.pop_front();
        }
        // Update smoothed level and trend (Holt's method)
        match (self.level, self.trend) {
            (None, None) => {
                // Initialize with first observation; trend unknown, start flat
                self.level = Some(value);
                self.trend = Some(0.0);
            }
            (Some(prev_level), Some(prev_trend)) => {
                // Update level: L_t = α * Y_t + (1 - α) * (L_{t-1} + T_{t-1})
                let new_level = self.alpha * value + (1.0 - self.alpha) * (prev_level + prev_trend);
                // Update trend: T_t = β * (L_t - L_{t-1}) + (1 - β) * T_{t-1}
                let new_trend = self.beta * (new_level - prev_level) + (1.0 - self.beta) * prev_trend;
                self.level = Some(new_level);
                self.trend = Some(new_trend);
                // Update CUSUM for regime change detection
                self.update_cusum(value);
            }
            // level and trend are always assigned together, so mixed
            // Some/None states cannot occur.
            _ => unreachable!(),
        }
    }

    /// Update CUSUM statistics for regime change detection
    ///
    /// Accumulates deviations of `value` from the historical mean into two
    /// one-sided sums; a sustained shift grows one of them until
    /// `detect_regime_change_probability` reports it.
    ///
    /// (The previous `expected` parameter was unused and has been removed;
    /// the detectors are driven by the rolling mean, not the smoothed level.)
    fn update_cusum(&mut self, value: f64) {
        let mean = self.get_mean();
        let std = self.get_std();
        if std > 0.0 {
            // Half a standard deviation of slack filters ordinary noise
            let threshold = 0.5 * std;
            let deviation = value - mean;
            // Positive CUSUM (detects upward shifts)
            self.cusum_pos = (self.cusum_pos + deviation - threshold).max(0.0);
            // Negative CUSUM (detects downward shifts)
            self.cusum_neg = (self.cusum_neg - deviation - threshold).max(0.0);
        }
    }

    /// Generate forecasts for future time steps
    ///
    /// # Arguments
    /// * `steps` - Number of future time steps to forecast
    ///
    /// # Returns
    /// Vector of forecast results with confidence intervals; empty when no
    /// observations have been recorded yet
    pub fn forecast(&self, steps: usize) -> Vec<Forecast> {
        if self.history.is_empty() {
            return Vec::new();
        }
        let level = self.level.unwrap_or(0.0);
        let trend = self.trend.unwrap_or(0.0);
        let std_error = self.get_prediction_error_std();
        // Get time delta from last two observations; assumed representative
        // of the sampling cadence for the whole horizon.
        let time_delta = if self.history.len() >= 2 {
            let (t1, _) = self.history[self.history.len() - 1];
            let (t0, _) = self.history[self.history.len() - 2];
            t1.signed_duration_since(t0)
        } else {
            Duration::hours(1) // Default 1 hour
        };
        let last_timestamp = self.history.back().unwrap().0;
        let current_trend = self.get_trend();
        let mut forecasts = Vec::with_capacity(steps);
        for h in 1..=steps {
            // Holt's linear trend forecast: F_{t+h} = L_t + h * T_t
            let forecast_value = level + (h as f64) * trend;
            // Prediction interval widens with horizon (sqrt(h)); 1.96 gives
            // a ~95% interval under a normality assumption
            let interval_width = 1.96 * std_error * (h as f64).sqrt();
            // Calculate anomaly probability based on deviation and CUSUM
            let anomaly_prob = self.calculate_anomaly_probability(forecast_value);
            forecasts.push(Forecast {
                timestamp: last_timestamp + time_delta * h as i32,
                predicted_value: forecast_value,
                confidence_low: forecast_value - interval_width,
                confidence_high: forecast_value + interval_width,
                trend: current_trend,
                anomaly_probability: anomaly_prob,
            });
        }
        forecasts
    }

    /// Detect probability of regime change using CUSUM statistics
    ///
    /// # Returns
    /// Probability between 0.0 and 1.0 that a regime change is occurring;
    /// 0.0 when fewer than 10 observations exist or variance is zero
    pub fn detect_regime_change_probability(&self) -> f64 {
        if self.history.len() < 10 {
            return 0.0; // Not enough data
        }
        let std = self.get_std();
        if std == 0.0 {
            return 0.0;
        }
        // CUSUM threshold (typically 4-5 standard deviations)
        let threshold = 4.0 * std;
        // Combine positive and negative CUSUM
        let max_cusum = self.cusum_pos.max(self.cusum_neg);
        // Convert to probability using sigmoid centered on the threshold
        let probability = 1.0 / (1.0 + (-0.5 * (max_cusum - threshold)).exp());
        probability.clamp(0.0, 1.0)
    }

    /// Get current trend direction
    ///
    /// Classifies the smoothed slope against a noise band of ±10% of the
    /// historical standard deviation.
    pub fn get_trend(&self) -> Trend {
        let trend_value = self.trend.unwrap_or(0.0);
        let std = self.get_std();
        // Use a fraction of std as threshold for "stable"
        let threshold = 0.1 * std;
        if trend_value > threshold {
            Trend::Rising
        } else if trend_value < -threshold {
            Trend::Falling
        } else {
            Trend::Stable
        }
    }

    /// Calculate mean of historical values (0.0 when empty)
    fn get_mean(&self) -> f64 {
        if self.history.is_empty() {
            return 0.0;
        }
        let sum: f64 = self.history.iter().map(|(_, v)| v).sum();
        sum / self.history.len() as f64
    }

    /// Calculate sample standard deviation of historical values
    /// (0.0 with fewer than two observations)
    fn get_std(&self) -> f64 {
        if self.history.len() < 2 {
            return 0.0;
        }
        let mean = self.get_mean();
        let variance: f64 = self.history
            .iter()
            .map(|(_, v)| (v - mean).powi(2))
            .sum::<f64>() / (self.history.len() - 1) as f64;
        variance.sqrt()
    }

    /// Calculate standard error of predictions
    ///
    /// Uses the RMSE of one-step-ahead simple-exponential-smoothing
    /// forecasts over the history; falls back to the plain standard
    /// deviation when too few residuals exist.
    fn get_prediction_error_std(&self) -> f64 {
        if self.history.len() < 3 {
            return self.get_std();
        }
        // Calculate residuals from one-step-ahead forecasts
        let mut errors = Vec::new();
        for i in 2..self.history.len() {
            let (_, actual) = self.history[i];
            // Simple exponential smoothing forecast using previous data
            let prev_values: Vec<f64> = self.history.iter()
                .take(i)
                .map(|(_, v)| *v)
                .collect();
            if let Some(predicted) = self.simple_forecast(&prev_values) {
                errors.push(actual - predicted);
            }
        }
        if errors.is_empty() {
            return self.get_std();
        }
        // Root mean squared error
        let mse: f64 = errors.iter().map(|e| e.powi(2)).sum::<f64>() / errors.len() as f64;
        mse.sqrt()
    }

    /// One-step simple exponential smoothing forecast over `values`
    /// (used only for in-sample prediction-error estimation)
    ///
    /// (The previous `steps` parameter was unused — an SES forecast is
    /// constant for every horizon — and has been removed.)
    fn simple_forecast(&self, values: &[f64]) -> Option<f64> {
        if values.is_empty() {
            return None;
        }
        let mut level = values[0];
        for &value in &values[1..] {
            level = self.alpha * value + (1.0 - self.alpha) * level;
        }
        // For SES, the forecast is the final smoothed level
        Some(level)
    }

    /// Calculate anomaly probability for a forecasted value
    ///
    /// Combines a z-score test (|z| > 2, i.e. outside ~95% of history) with
    /// the CUSUM regime-change probability, taking the larger of the two.
    fn calculate_anomaly_probability(&self, forecast_value: f64) -> f64 {
        let mean = self.get_mean();
        let std = self.get_std();
        if std == 0.0 {
            return 0.0;
        }
        // Z-score of the forecast
        let z_score = ((forecast_value - mean) / std).abs();
        // Combine with regime change probability
        let regime_prob = self.detect_regime_change_probability();
        // Anomaly if z-score > 2 (95% confidence) or regime change detected
        let z_anomaly_prob = if z_score > 2.0 {
            1.0 / (1.0 + (-(z_score - 2.0)).exp())
        } else {
            0.0
        };
        // Combine probabilities (max gives more sensitivity)
        z_anomaly_prob.max(regime_prob)
    }

    /// Get the number of observations in history
    pub fn len(&self) -> usize {
        self.history.len()
    }

    /// Check if forecaster has no observations
    pub fn is_empty(&self) -> bool {
        self.history.is_empty()
    }

    /// Get the smoothed level value (`None` before the first observation)
    pub fn get_level(&self) -> Option<f64> {
        self.level
    }

    /// Get the smoothed trend value (`None` before the first observation)
    pub fn get_trend_value(&self) -> Option<f64> {
        self.trend
    }
}
/// Cross-domain correlation forecaster
///
/// Holds one named `CoherenceForecaster` per domain so correlations and
/// regime changes can be compared across domains.
pub struct CrossDomainForecaster {
    /// (domain name, forecaster) pairs; looked up linearly by name
    forecasters: Vec<(String, CoherenceForecaster)>,
}
impl CrossDomainForecaster {
    /// Create an empty cross-domain forecaster.
    pub fn new() -> Self {
        Self {
            forecasters: Vec::new(),
        }
    }

    /// Register a domain together with its per-domain forecaster.
    pub fn add_domain(&mut self, domain: String, forecaster: CoherenceForecaster) {
        self.forecasters.push((domain, forecaster));
    }

    /// Pearson correlation between the two domains' recent histories.
    ///
    /// Pairs up the most recent observations of both series (truncated to
    /// the shorter one). Returns `None` when either domain is unknown, has
    /// no observations, fewer than two points overlap, or either series has
    /// zero variance.
    pub fn calculate_correlation(&self, domain1: &str, domain2: &str) -> Option<f64> {
        let lookup = |name: &str| {
            self.forecasters
                .iter()
                .find(|(d, _)| d == name)
                .map(|(_, f)| f)
        };
        let f1 = lookup(domain1)?;
        let f2 = lookup(domain2)?;
        if f1.is_empty() || f2.is_empty() {
            return None;
        }
        let min_len = f1.history.len().min(f2.history.len());
        if min_len < 2 {
            return None;
        }
        // Most recent `min_len` values from each series, newest first; the
        // shared ordering keeps the pairing consistent, and Pearson's r is
        // invariant to the traversal direction.
        let series1: Vec<f64> = f1.history.iter().rev().take(min_len).map(|(_, v)| *v).collect();
        let series2: Vec<f64> = f2.history.iter().rev().take(min_len).map(|(_, v)| *v).collect();
        let n = min_len as f64;
        let mean1 = series1.iter().sum::<f64>() / n;
        let mean2 = series2.iter().sum::<f64>() / n;
        let mut cov = 0.0;
        let mut var1 = 0.0;
        let mut var2 = 0.0;
        for (a, b) in series1.iter().zip(&series2) {
            let d1 = a - mean1;
            let d2 = b - mean2;
            cov += d1 * d2;
            var1 += d1 * d1;
            var2 += d2 * d2;
        }
        let denom = (var1 * var2).sqrt();
        if denom == 0.0 {
            None
        } else {
            Some(cov / denom)
        }
    }

    /// Forecast every registered domain `steps` ahead.
    pub fn forecast_all(&self, steps: usize) -> Vec<(String, Vec<Forecast>)> {
        self.forecasters
            .iter()
            .map(|(name, f)| (name.clone(), f.forecast(steps)))
            .collect()
    }

    /// Domains whose regime-change probability currently exceeds 0.5.
    pub fn detect_synchronized_regime_changes(&self) -> Vec<(String, f64)> {
        self.forecasters
            .iter()
            .map(|(name, f)| (name.clone(), f.detect_regime_change_probability()))
            .filter(|(_, p)| *p > 0.5)
            .collect()
    }
}
impl Default for CrossDomainForecaster {
    /// Equivalent to `CrossDomainForecaster::new`: no domains registered.
    fn default() -> Self {
        Self {
            forecasters: Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh forecaster holds no observations.
    #[test]
    fn test_forecaster_creation() {
        let forecaster = CoherenceForecaster::new(0.3, 100);
        assert!(forecaster.is_empty());
        assert_eq!(forecaster.len(), 0);
    }

    // The first observation initializes the smoothed level.
    #[test]
    fn test_add_observation() {
        let mut forecaster = CoherenceForecaster::new(0.3, 100);
        let now = Utc::now();
        forecaster.add_observation(now, 0.5);
        assert_eq!(forecaster.len(), 1);
        assert!(forecaster.get_level().is_some());
    }

    // Monotonically increasing values should classify as Rising.
    #[test]
    fn test_trend_detection() {
        let mut forecaster = CoherenceForecaster::new(0.3, 100);
        let now = Utc::now();
        // Add rising values
        for i in 0..10 {
            forecaster.add_observation(
                now + Duration::hours(i),
                0.5 + (i as f64) * 0.1
            );
        }
        let trend = forecaster.get_trend();
        assert_eq!(trend, Trend::Rising);
    }

    // Forecasts extend past the last observation and carry widening
    // confidence intervals around the point forecast.
    #[test]
    fn test_forecast_generation() {
        let mut forecaster = CoherenceForecaster::new(0.3, 100);
        let now = Utc::now();
        // Add some observations
        for i in 0..10 {
            forecaster.add_observation(
                now + Duration::hours(i),
                0.5 + (i as f64) * 0.05
            );
        }
        let forecasts = forecaster.forecast(5);
        assert_eq!(forecasts.len(), 5);
        // Check that forecasts are in the future
        for forecast in forecasts {
            assert!(forecast.timestamp > now + Duration::hours(9));
            assert!(forecast.confidence_low < forecast.predicted_value);
            assert!(forecast.confidence_high > forecast.predicted_value);
        }
    }

    // A sudden level shift should raise the CUSUM-based regime-change
    // probability above its stable-period baseline.
    #[test]
    fn test_regime_change_detection() {
        let mut forecaster = CoherenceForecaster::new(0.3, 100);
        let now = Utc::now();
        // Add stable values
        for i in 0..20 {
            forecaster.add_observation(now + Duration::hours(i), 0.5);
        }
        // Should have low regime change probability
        let prob1 = forecaster.detect_regime_change_probability();
        assert!(prob1 < 0.3);
        // Add sudden shift
        for i in 20..25 {
            forecaster.add_observation(now + Duration::hours(i), 0.9);
        }
        // Should detect regime change
        let prob2 = forecaster.detect_regime_change_probability();
        assert!(prob2 > prob1);
    }

    // Two series that differ only by a constant offset are perfectly
    // linearly related, so Pearson's r should be near 1.
    #[test]
    fn test_cross_domain_correlation() {
        let mut cross = CrossDomainForecaster::new();
        let mut f1 = CoherenceForecaster::new(0.3, 100);
        let mut f2 = CoherenceForecaster::new(0.3, 100);
        let now = Utc::now();
        // Add correlated data
        for i in 0..20 {
            let value = 0.5 + (i as f64) * 0.01;
            f1.add_observation(now + Duration::hours(i), value);
            f2.add_observation(now + Duration::hours(i), value + 0.1);
        }
        cross.add_domain("domain1".to_string(), f1);
        cross.add_domain("domain2".to_string(), f2);
        let correlation = cross.calculate_correlation("domain1", "domain2");
        assert!(correlation.is_some());
        // Should be highly correlated
        let corr_value = correlation.unwrap();
        assert!(corr_value > 0.9, "Correlation was {}", corr_value);
    }

    // The rolling window evicts the oldest observations once full.
    #[test]
    fn test_window_size_limit() {
        let mut forecaster = CoherenceForecaster::new(0.3, 10);
        let now = Utc::now();
        // Add more observations than window size
        for i in 0..20 {
            forecaster.add_observation(now + Duration::hours(i), 0.5);
        }
        // Should only keep last 10
        assert_eq!(forecaster.len(), 10);
    }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,757 @@
//! HNSW (Hierarchical Navigable Small World) Index
//!
//! Production-quality implementation of the HNSW algorithm for approximate
//! nearest neighbor search in high-dimensional vector spaces.
//!
//! ## Algorithm Overview
//!
//! HNSW builds a multi-layer graph structure where:
//! - Layer 0 contains all vectors
//! - Higher layers contain progressively fewer vectors (exponentially decaying)
//! - Each layer is a navigable small world graph with bounded degree
//! - Search proceeds from top layer down, greedy navigating to nearest neighbors
//!
//! ## Performance Characteristics
//!
//! - **Search**: O(log n) approximate nearest neighbor queries
//! - **Insert**: O(log n) amortized insertion time
//! - **Space**: O(n * M) where M is max connections per layer
//! - **Accuracy**: Configurable via ef_construction and ef_search parameters
//!
//! ## References
//!
//! - Malkov, Y. A., & Yashunin, D. A. (2018). "Efficient and robust approximate
//! nearest neighbor search using Hierarchical Navigable Small World graphs"
//! IEEE Transactions on Pattern Analysis and Machine Intelligence.
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
use std::sync::{Arc, RwLock};
use chrono::{DateTime, Utc};
use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::ruvector_native::SemanticVector;
use crate::FrameworkError;
/// HNSW index configuration
///
/// All parameters trade accuracy against speed/memory; see `Default` for the
/// recommended starting values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Maximum number of bi-directional links per node per layer (M)
    /// Higher values improve recall but increase memory and search time
    /// Typical range: 8-64, default: 16
    pub m: usize,
    /// Maximum connections for layer 0 (typically M * 2)
    pub m_max_0: usize,
    /// Size of dynamic candidate list during construction (ef_construction)
    /// Higher values improve graph quality but slow construction
    /// Typical range: 100-500, default: 200
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search (ef_search)
    /// Higher values improve recall but slow search; effective value is
    /// max(ef_search, k) at query time
    /// Typical range: 50-200, default: 50
    pub ef_search: usize,
    /// Layer generation probability parameter (ml)
    /// Set to 1/ln(M) by default (see `Default`); controls the exponential
    /// decay of how many nodes reach higher layers
    pub ml: f64,
    /// Vector dimension (must be consistent across all inserted vectors)
    pub dimension: usize,
    /// Distance metric
    pub metric: DistanceMetric,
}
/// Defaults follow the parameter ranges recommended in the HNSW paper:
/// M = 16, layer-0 degree 2M, ef_construction = 200, ef_search = 50,
/// cosine metric over 128-dimensional vectors.
impl Default for HnswConfig {
    fn default() -> Self {
        const M: usize = 16;
        Self {
            m: M,
            m_max_0: M * 2,
            ef_construction: 200,
            ef_search: 50,
            // 1/ln(M) yields the standard exponential layer distribution.
            ml: 1.0 / (M as f64).ln(),
            dimension: 128,
            metric: DistanceMetric::Cosine,
        }
    }
}
/// Distance metrics supported by HNSW
///
/// All metrics return lower = more similar, so the same graph-search code
/// works for each variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine similarity (converted to angular distance)
    /// Distance = arccos(similarity) / π
    /// Range: [0, 1] where 0 = identical, 1 = opposite
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Manhattan (L1) distance
    Manhattan,
}
/// A node in the HNSW graph
///
/// Stores the vector payload plus its adjacency lists for every layer it
/// participates in (layers 0..=level).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HnswNode {
    /// Vector data
    vector: Vec<f32>,
    /// External identifier from SemanticVector
    external_id: String,
    /// Timestamp when added
    timestamp: DateTime<Utc>,
    /// Maximum layer this node appears in
    level: usize,
    /// Connections per layer: connections[layer] = set of neighbor node IDs
    /// Layer 0 can have up to m_max_0 connections, others up to m
    connections: Vec<Vec<usize>>,
}
/// Search result with distance and metadata
///
/// Returned by `search_knn`/`search_threshold`, sorted ascending by
/// `distance`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswSearchResult {
    /// Node ID in the index
    pub node_id: usize,
    /// External identifier
    pub external_id: String,
    /// Distance to query vector (lower is more similar)
    pub distance: f32,
    /// Cosine similarity score (populated only when using the cosine metric)
    pub similarity: Option<f32>,
    /// Timestamp when vector was added
    pub timestamp: DateTime<Utc>,
}
/// Statistics about the HNSW index structure
///
/// Snapshot produced by `HnswIndex::stats`; useful for tuning M and the
/// ef parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswStats {
    /// Total number of nodes
    pub node_count: usize,
    /// Number of layers in the graph
    pub layer_count: usize,
    /// Nodes per layer (index = layer)
    pub nodes_per_layer: Vec<usize>,
    /// Average connections per node per layer
    pub avg_connections_per_layer: Vec<f64>,
    /// Total edges in the graph (directed link count)
    pub total_edges: usize,
    /// Entry point node ID (`None` for an empty index)
    pub entry_point: Option<usize>,
    /// Memory usage estimate in bytes (rough, not an exact accounting)
    pub estimated_memory_bytes: usize,
}
/// HNSW index for approximate nearest neighbor search
///
/// Thread-safe implementation using Arc<RwLock<>> for concurrent reads.
///
/// NOTE(review): only `rng` is lock-protected; `insert` takes `&mut self`,
/// so concurrent use across threads relies on external synchronization of
/// the whole index — confirm the thread-safety claim above.
pub struct HnswIndex {
    /// Configuration
    config: HnswConfig,
    /// All nodes in the index; a node's ID is its position in this Vec
    nodes: Vec<HnswNode>,
    /// Entry point for search (node with highest layer)
    entry_point: Option<usize>,
    /// Maximum layer currently in use
    max_layer: usize,
    /// Random number generator for layer assignment
    rng: Arc<RwLock<rand::rngs::StdRng>>,
}
impl HnswIndex {
/// Create a new HNSW index with default configuration
pub fn new() -> Self {
Self::with_config(HnswConfig::default())
}
    /// Create a new HNSW index with custom configuration
    ///
    /// The index starts empty (no nodes, no entry point); the RNG used for
    /// probabilistic layer assignment is seeded from OS entropy.
    pub fn with_config(config: HnswConfig) -> Self {
        // Local import keeps the SeedableRng trait scoped to this method.
        use rand::SeedableRng;
        Self {
            config,
            nodes: Vec::new(),
            entry_point: None,
            max_layer: 0,
            rng: Arc::new(RwLock::new(rand::rngs::StdRng::from_entropy())),
        }
    }
    /// Insert a vector into the index
    ///
    /// Standard HNSW insertion: sample a top layer for the node, greedily
    /// descend from the entry point to that layer, then build bidirectional
    /// links layer by layer down to 0, pruning any neighbor that exceeds its
    /// degree bound.
    ///
    /// ## Arguments
    ///
    /// - `vector`: The SemanticVector to insert
    ///
    /// ## Returns
    ///
    /// The assigned node ID
    ///
    /// ## Errors
    ///
    /// `FrameworkError::Config` if the embedding's length does not match the
    /// configured dimension.
    pub fn insert(&mut self, vector: SemanticVector) -> Result<usize, FrameworkError> {
        if vector.embedding.len() != self.config.dimension {
            return Err(FrameworkError::Config(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                vector.embedding.len()
            )));
        }
        // Node IDs are positions in `self.nodes`, so the next ID is len().
        let node_id = self.nodes.len();
        // random_level() is defined elsewhere in this impl; presumably it
        // samples the exponential layer distribution driven by config.ml —
        // confirm against its definition.
        let level = self.random_level();
        // Create new node
        let mut new_node = HnswNode {
            vector: vector.embedding,
            external_id: vector.id,
            timestamp: vector.timestamp,
            level,
            connections: vec![Vec::new(); level + 1],
        };
        // Insert into graph
        if self.entry_point.is_none() {
            // First node - becomes entry point
            self.nodes.push(new_node);
            self.entry_point = Some(node_id);
            self.max_layer = level;
            return Ok(node_id);
        }
        // Search for nearest neighbors at insertion point
        let entry_point = self.entry_point.unwrap();
        let mut current_nearest = vec![entry_point];
        // Traverse from top layer down to level+1 with a beam of 1
        // (greedy descent: only the closest node is carried downward).
        for lc in (level + 1..=self.max_layer).rev() {
            current_nearest = self.search_layer(&new_node.vector, &current_nearest, 1, lc);
        }
        // Insert from level down to 0
        for lc in (0..=level).rev() {
            let candidates = self.search_layer(&new_node.vector, &current_nearest, self.config.ef_construction, lc);
            // Select M neighbors (layer 0 allows the larger m_max_0 degree)
            let m = if lc == 0 { self.config.m_max_0 } else { self.config.m };
            let neighbors = self.select_neighbors(&new_node.vector, candidates, m);
            // Add bidirectional links
            for &neighbor_id in &neighbors {
                // Add link from new node to neighbor
                new_node.connections[lc].push(neighbor_id);
            }
            current_nearest = neighbors.clone();
        }
        self.nodes.push(new_node);
        // Add reverse links and prune if necessary
        for lc in 0..=level {
            let neighbors: Vec<usize> = self.nodes[node_id].connections[lc].clone();
            for neighbor_id in neighbors {
                // Only add reverse link if neighbor has this layer
                if lc < self.nodes[neighbor_id].connections.len() {
                    self.nodes[neighbor_id].connections[lc].push(node_id);
                    // Prune if exceeded max connections
                    let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m };
                    if self.nodes[neighbor_id].connections[lc].len() > m_max {
                        // Clones satisfy the borrow checker: select_neighbors
                        // takes &self while we mutate the neighbor's list.
                        let neighbor_vec = self.nodes[neighbor_id].vector.clone();
                        let candidates = self.nodes[neighbor_id].connections[lc].clone();
                        let pruned = self.select_neighbors(&neighbor_vec, candidates, m_max);
                        self.nodes[neighbor_id].connections[lc] = pruned;
                    }
                }
            }
        }
        // Update entry point if new node is at higher layer
        if level > self.max_layer {
            self.max_layer = level;
            self.entry_point = Some(node_id);
        }
        Ok(node_id)
    }
/// Insert a batch of vectors
///
/// More efficient than inserting one at a time for large batches.
pub fn insert_batch(&mut self, vectors: Vec<SemanticVector>) -> Result<Vec<usize>, FrameworkError> {
let mut ids = Vec::with_capacity(vectors.len());
for vector in vectors {
ids.push(self.insert(vector)?);
}
Ok(ids)
}
    /// Search for k nearest neighbors
    ///
    /// Greedily descends from the entry point through the upper layers with
    /// a beam of 1, then runs the full beam search (ef = max(ef_search, k))
    /// on layer 0.
    ///
    /// ## Arguments
    ///
    /// - `query`: Query vector (must match index dimension)
    /// - `k`: Number of neighbors to return
    ///
    /// ## Returns
    ///
    /// Up to k nearest neighbors, sorted by distance (ascending); empty when
    /// the index is empty
    ///
    /// ## Errors
    ///
    /// `FrameworkError::Config` if the query length does not match the
    /// configured dimension.
    pub fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<HnswSearchResult>, FrameworkError> {
        if query.len() != self.config.dimension {
            return Err(FrameworkError::Config(format!(
                "Query dimension mismatch: expected {}, got {}",
                self.config.dimension,
                query.len()
            )));
        }
        if self.entry_point.is_none() {
            return Ok(Vec::new());
        }
        let entry_point = self.entry_point.unwrap();
        let mut current_nearest = vec![entry_point];
        // Traverse from top layer down to layer 1 (beam width 1)
        for lc in (1..=self.max_layer).rev() {
            current_nearest = self.search_layer(query, &current_nearest, 1, lc);
        }
        // Search layer 0 with ef_search (never narrower than k)
        let ef = self.config.ef_search.max(k);
        let candidates = self.search_layer(query, &current_nearest, ef, 0);
        // Convert to search results; candidates arrive sorted ascending by
        // distance, so take(k) keeps the k best. Distances are recomputed
        // here for the result payload.
        let results: Vec<HnswSearchResult> = candidates
            .iter()
            .take(k)
            .map(|&node_id| {
                let node = &self.nodes[node_id];
                let distance = self.distance(query, &node.vector);
                let similarity = if self.config.metric == DistanceMetric::Cosine {
                    Some(self.cosine_similarity(query, &node.vector))
                } else {
                    None
                };
                HnswSearchResult {
                    node_id,
                    external_id: node.external_id.clone(),
                    distance,
                    similarity,
                    timestamp: node.timestamp,
                }
            })
            .collect();
        Ok(results)
    }
/// Search for all neighbors within a distance threshold
///
/// ## Arguments
///
/// - `query`: Query vector
/// - `threshold`: Maximum distance (exclusive)
/// - `max_results`: Maximum number of results to return (None for unlimited)
///
/// ## Returns
///
/// All neighbors within threshold, sorted by distance
pub fn search_threshold(
&self,
query: &[f32],
threshold: f32,
max_results: Option<usize>,
) -> Result<Vec<HnswSearchResult>, FrameworkError> {
// Search with large k first
let k = max_results.unwrap_or(1000).max(100);
let mut results = self.search_knn(query, k)?;
// Filter by threshold
results.retain(|r| r.distance < threshold);
// Limit results
if let Some(max) = max_results {
results.truncate(max);
}
Ok(results)
}
/// Get statistics about the index structure
pub fn stats(&self) -> HnswStats {
let node_count = self.nodes.len();
let layer_count = self.max_layer + 1;
let mut nodes_per_layer = vec![0; layer_count];
let mut connections_per_layer = vec![0; layer_count];
for node in &self.nodes {
for layer in 0..=node.level {
nodes_per_layer[layer] += 1;
connections_per_layer[layer] += node.connections[layer].len();
}
}
let avg_connections_per_layer: Vec<f64> = connections_per_layer
.iter()
.zip(&nodes_per_layer)
.map(|(conn, nodes)| {
if *nodes > 0 {
*conn as f64 / *nodes as f64
} else {
0.0
}
})
.collect();
let total_edges: usize = connections_per_layer.iter().sum();
// Estimate memory: each node stores vector + metadata + connections
let estimated_memory_bytes = node_count
* (self.config.dimension * 4 // vector (f32)
+ 100 // metadata overhead
+ self.config.m * 8 * layer_count); // connections (usize)
HnswStats {
node_count,
layer_count,
nodes_per_layer,
avg_connections_per_layer,
total_edges,
entry_point: self.entry_point,
estimated_memory_bytes,
}
}
// ===== Private helper methods =====
/// Search a single layer for nearest neighbors
///
/// Greedy best-first search restricted to `layer`: starting from
/// `entry_points`, repeatedly expands the closest unexplored candidate and
/// keeps the `ef` nearest nodes seen so far.
///
/// `candidates` is a min-heap (via `Reverse`) of nodes still to expand;
/// `nearest` is a max-heap bounding the running result set to `ef` entries.
///
/// Returns node IDs sorted by ascending distance to `query`.
fn search_layer(&self, query: &[f32], entry_points: &[usize], ef: usize, layer: usize) -> Vec<usize> {
    let mut visited = HashSet::new();
    let mut candidates = BinaryHeap::new();
    let mut nearest = BinaryHeap::new();
    // Seed both heaps with the entry points.
    for &ep in entry_points {
        let dist = self.distance(query, &self.nodes[ep].vector);
        candidates.push((Reverse(OrderedFloat(dist)), ep));
        nearest.push((OrderedFloat(dist), ep));
        visited.insert(ep);
    }
    while let Some((Reverse(OrderedFloat(dist)), current)) = candidates.pop() {
        // Check if we should continue searching: once the closest remaining
        // candidate is farther than the worst node in `nearest`, no further
        // expansion can improve the result set.
        if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
            if dist > max_dist {
                break;
            }
        }
        // Explore neighbors of `current`, but only connections on this layer.
        if current < self.nodes.len() && layer <= self.nodes[current].level {
            for &neighbor in &self.nodes[current].connections[layer] {
                // `insert` returns true only for not-yet-visited nodes.
                if visited.insert(neighbor) {
                    let neighbor_dist = self.distance(query, &self.nodes[neighbor].vector);
                    if let Some(&(OrderedFloat(max_dist), _)) = nearest.peek() {
                        // Accept if it improves on the current worst, or the
                        // result set is not yet full.
                        if neighbor_dist < max_dist || nearest.len() < ef {
                            candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
                            nearest.push((OrderedFloat(neighbor_dist), neighbor));
                            // Evict the worst entry to keep at most `ef` results.
                            if nearest.len() > ef {
                                nearest.pop();
                            }
                        }
                    } else {
                        // `nearest` is empty (no entry points seeded): accept unconditionally.
                        candidates.push((Reverse(OrderedFloat(neighbor_dist)), neighbor));
                        nearest.push((OrderedFloat(neighbor_dist), neighbor));
                    }
                }
            }
        }
    }
    // Extract node IDs sorted by distance (ascending)
    let mut sorted_nearest: Vec<_> = nearest.into_iter().collect();
    sorted_nearest.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    sorted_nearest.into_iter().map(|(_, id)| id).collect()
}
/// Select M neighbors from candidates using heuristic
fn select_neighbors(&self, base: &[f32], candidates: Vec<usize>, m: usize) -> Vec<usize> {
if candidates.len() <= m {
return candidates;
}
// Simple heuristic: keep nearest M by distance
let mut with_distances: Vec<_> = candidates
.into_iter()
.map(|id| {
let dist = self.distance(base, &self.nodes[id].vector);
(OrderedFloat(dist), id)
})
.collect();
with_distances.sort_by_key(|(dist, _)| *dist);
with_distances.into_iter().take(m).map(|(_, id)| id).collect()
}
/// Compute distance between two vectors
///
/// Interpretation depends on the configured metric:
/// - `Cosine`: angular distance `arccos(similarity) / π`, in `[0, 1]`
/// - `Euclidean`: L2 norm of the difference
/// - `Manhattan`: L1 norm of the difference
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
    match self.config.metric {
        DistanceMetric::Cosine => {
            let sim = self.cosine_similarity(a, b);
            // Clamp before arccos so floating-point drift cannot produce NaN.
            let clamped = sim.max(-1.0).min(1.0);
            clamped.acos() / std::f32::consts::PI
        }
        DistanceMetric::Euclidean => {
            let squared: f32 = a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum();
            squared.sqrt()
        }
        DistanceMetric::Manhattan => a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .sum(),
    }
}
/// Compute cosine similarity between two vectors
///
/// Returns a value clamped to `[-1, 1]`; a zero-magnitude input on either
/// side yields 0.0 rather than dividing by zero.
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
    // Single pass over the zipped pair accumulates dot product and both
    // squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }
    (dot / (sq_a.sqrt() * sq_b.sqrt())).max(-1.0).min(1.0)
}
/// Randomly assign a layer to a new node
///
/// Samples from the standard HNSW geometric-like distribution:
/// `floor(-ln(U) * ml)` for uniform `U` in (0, 1), so higher layers are
/// exponentially less likely.
fn random_level(&self) -> usize {
    let mut guard = self.rng.write().unwrap();
    let sample: f64 = guard.gen();
    (-sample.ln() * self.config.ml).floor() as usize
}
/// Get the underlying vector for a node, or `None` if the ID is out of range.
pub fn get_vector(&self, node_id: usize) -> Option<&Vec<f32>> {
    self.nodes.get(node_id).map(|entry| &entry.vector)
}

/// Get the external (caller-supplied) ID for a node, or `None` if the ID is
/// out of range.
pub fn get_external_id(&self, node_id: usize) -> Option<&str> {
    self.nodes.get(node_id).map(|entry| entry.external_id.as_str())
}

/// Total number of nodes currently stored in the index.
pub fn len(&self) -> usize {
    self.nodes.len()
}

/// True when the index holds no nodes at all.
pub fn is_empty(&self) -> bool {
    self.nodes.is_empty()
}
}
impl Default for HnswIndex {
fn default() -> Self {
Self::new()
}
}
/// Total-order adapter over `f32` so distances can be stored in a `BinaryHeap`.
///
/// Incomparable pairs (i.e. NaN on either side) are treated as `Equal` by the
/// `Ord` impl; distances produced by this index are never NaN, so this
/// fallback is never exercised in practice.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct OrderedFloat(f32);

impl Eq for OrderedFloat {}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.0.partial_cmp(&other.0) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::ruvector_native::Domain;
    // Helper: build a SemanticVector with a fixed domain/timestamp and the
    // given ID and embedding, so tests only vary what they care about.
    fn create_test_vector(id: &str, embedding: Vec<f32>) -> SemanticVector {
        SemanticVector {
            id: id.to_string(),
            embedding,
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }
    // Insert three 3-d vectors and verify k-NN returns the exact match first,
    // then the nearest non-identical vector.
    #[test]
    fn test_hnsw_basic_insert_search() {
        let config = HnswConfig {
            dimension: 3,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        // Insert vectors
        let v1 = create_test_vector("v1", vec![1.0, 0.0, 0.0]);
        let v2 = create_test_vector("v2", vec![0.0, 1.0, 0.0]);
        let v3 = create_test_vector("v3", vec![0.9, 0.1, 0.0]);
        index.insert(v1).unwrap();
        index.insert(v2).unwrap();
        index.insert(v3).unwrap();
        assert_eq!(index.len(), 3);
        // Search for nearest to v1
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search_knn(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].external_id, "v1"); // Exact match
        assert_eq!(results[1].external_id, "v3"); // Close match
    }
    // insert_batch should return one internal ID per input and grow the index.
    #[test]
    fn test_hnsw_batch_insert() {
        let config = HnswConfig {
            dimension: 2,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        let vectors = vec![
            create_test_vector("v1", vec![1.0, 0.0]),
            create_test_vector("v2", vec![0.0, 1.0]),
            create_test_vector("v3", vec![1.0, 1.0]),
        ];
        let ids = index.insert_batch(vectors).unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(index.len(), 3);
    }
    // Threshold search must only return hits with distance strictly below
    // the threshold.
    #[test]
    fn test_hnsw_threshold_search() {
        let config = HnswConfig {
            dimension: 2,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        // Insert vectors at different distances
        index.insert(create_test_vector("close", vec![1.0, 0.1])).unwrap();
        index.insert(create_test_vector("medium", vec![0.7, 0.7])).unwrap();
        index.insert(create_test_vector("far", vec![0.0, 1.0])).unwrap();
        let query = vec![1.0, 0.0];
        let results = index.search_threshold(&query, 0.3, None).unwrap();
        // Should find only close vectors
        assert!(results.len() >= 1);
        assert!(results.iter().all(|r| r.distance < 0.3));
    }
    // Under the cosine metric, an identical vector ranks first (distance ~0)
    // and the antiparallel vector ranks last.
    #[test]
    fn test_hnsw_cosine_similarity() {
        let config = HnswConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        let v1 = create_test_vector("identical", vec![1.0, 0.0, 0.0]);
        let v2 = create_test_vector("orthogonal", vec![0.0, 1.0, 0.0]);
        let v3 = create_test_vector("opposite", vec![-1.0, 0.0, 0.0]);
        index.insert(v1).unwrap();
        index.insert(v2).unwrap();
        index.insert(v3).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search_knn(&query, 3).unwrap();
        // Identical should be closest
        assert_eq!(results[0].external_id, "identical");
        assert!(results[0].distance < 0.01);
        // Opposite should be farthest
        assert_eq!(results[2].external_id, "opposite");
    }
    // stats() should count all nodes in layer 0 and report at least one edge
    // once several nodes have been inserted.
    #[test]
    fn test_hnsw_stats() {
        let config = HnswConfig {
            dimension: 2,
            m: 4,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        for i in 0..10 {
            let vec = create_test_vector(&format!("v{}", i), vec![i as f32, i as f32]);
            index.insert(vec).unwrap();
        }
        let stats = index.stats();
        assert_eq!(stats.node_count, 10);
        assert!(stats.layer_count > 0);
        assert_eq!(stats.nodes_per_layer[0], 10); // All nodes in layer 0
        assert!(stats.total_edges > 0);
    }
    // Inserting a vector whose length differs from the configured dimension
    // must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let config = HnswConfig {
            dimension: 3,
            ..Default::default()
        };
        let mut index = HnswIndex::with_config(config);
        let bad_vector = create_test_vector("bad", vec![1.0, 2.0]); // Wrong dimension
        let result = index.insert(bad_vector);
        assert!(result.is_err());
    }
    // Searching an empty index succeeds and returns no results.
    #[test]
    fn test_empty_index_search() {
        let index = HnswIndex::new();
        let query = vec![1.0; 128];
        let results = index.search_knn(&query, 5).unwrap();
        assert!(results.is_empty());
    }
}

View File

@@ -0,0 +1,342 @@
//! Data ingestion pipeline for streaming data into RuVector
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::{DataRecord, DataSource, FrameworkError, Result};
/// Configuration for data ingestion
///
/// Controls batching, retry behaviour, deduplication, and rate limiting for
/// [`DataIngester`]. Sensible defaults are available via
/// `IngestionConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionConfig {
    /// Batch size for fetching (records requested per `fetch_batch` call)
    pub batch_size: usize,
    /// Maximum concurrent fetches
    /// (NOTE(review): not read by the visible ingestion paths — confirm usage)
    pub max_concurrent: usize,
    /// Retry count on failure (additional attempts after the first request)
    pub retry_count: u32,
    /// Delay between retries (ms); doubled on each subsequent retry
    pub retry_delay_ms: u64,
    /// Enable deduplication by record ID across the ingester's lifetime
    pub deduplicate: bool,
    /// Rate limit (requests per second, 0 = unlimited)
    pub rate_limit: u32,
}
impl Default for IngestionConfig {
fn default() -> Self {
Self {
batch_size: 1000,
max_concurrent: 4,
retry_count: 3,
retry_delay_ms: 1000,
deduplicate: true,
rate_limit: 10,
}
}
}
/// Configuration for a specific data source
///
/// NOTE(review): not consumed by the visible ingestion code — appears intended
/// for concrete `DataSource` implementations; confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
    /// Source identifier
    pub source_id: String,
    /// API base URL
    pub base_url: String,
    /// API key (if required)
    pub api_key: Option<String>,
    /// Additional headers to attach to requests
    pub headers: HashMap<String, String>,
    /// Custom parameters (source-specific query options)
    pub params: HashMap<String, String>,
}
/// Statistics for ingestion process
///
/// Held behind an `RwLock` inside [`DataIngester`] and updated incrementally
/// while batches are fetched.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IngestionStats {
    /// Total records fetched
    pub records_fetched: u64,
    /// Batches processed
    pub batches_processed: u64,
    /// Retries performed
    pub retries: u64,
    /// Errors encountered
    pub errors: u64,
    /// Duplicates skipped
    pub duplicates_skipped: u64,
    /// Bytes downloaded
    /// (NOTE(review): never written by the visible ingestion paths)
    pub bytes_downloaded: u64,
    /// Average batch fetch time (ms)
    /// (NOTE(review): never written by the visible ingestion paths)
    pub avg_batch_time_ms: f64,
}
/// Data ingestion pipeline
///
/// Fetches paginated batches from a [`DataSource`] with retry, optional
/// cross-batch deduplication, and rate limiting. The statistics and seen-ID
/// set are shared behind locks so spawned streaming tasks can update them.
pub struct DataIngester {
    // Behaviour knobs: batch size, retries, dedupe, rate limit.
    config: IngestionConfig,
    // Running counters, shared with background streaming tasks.
    stats: Arc<std::sync::RwLock<IngestionStats>>,
    // Every record ID seen so far, used for deduplication.
    seen_ids: Arc<std::sync::RwLock<std::collections::HashSet<String>>>,
}
impl DataIngester {
    /// Create a new data ingester with the given configuration.
    pub fn new(config: IngestionConfig) -> Self {
        Self {
            config,
            stats: Arc::new(std::sync::RwLock::new(IngestionStats::default())),
            seen_ids: Arc::new(std::sync::RwLock::new(std::collections::HashSet::new())),
        }
    }
    /// Ingest all data from a source, following pagination cursors until the
    /// source returns an empty batch or no next cursor.
    ///
    /// Records are deduplicated (when enabled) against every ID this ingester
    /// has seen, and both `batches_processed` and `records_fetched` statistics
    /// are updated per batch.
    pub async fn ingest_all<S: DataSource>(&self, source: &S) -> Result<Vec<DataRecord>> {
        let mut all_records = Vec::new();
        let mut cursor: Option<String> = None;
        loop {
            let (batch, next_cursor) = self
                .fetch_with_retry(source, cursor.clone(), self.config.batch_size)
                .await?;
            if batch.is_empty() {
                break;
            }
            // Deduplicate if enabled
            let records = if self.config.deduplicate {
                self.deduplicate_batch(batch)
            } else {
                batch
            };
            {
                let mut stats = self.stats.write().unwrap();
                stats.batches_processed += 1;
                // Fix: previously only the streaming path updated
                // `records_fetched`, leaving ingest_all stats at zero.
                stats.records_fetched += records.len() as u64;
            }
            all_records.extend(records);
            cursor = next_cursor;
            if cursor.is_none() {
                break;
            }
            // Rate limiting: spread requests to at most `rate_limit` per second.
            if self.config.rate_limit > 0 {
                let delay = 1000 / self.config.rate_limit as u64;
                tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
            }
        }
        Ok(all_records)
    }
    /// Stream records with backpressure via a bounded channel.
    ///
    /// Spawns a background task that pages through `source` and sends each
    /// (optionally deduplicated) record into the returned receiver. The task
    /// stops when the source is exhausted, an error occurs (counted in
    /// `errors`), or the receiver is dropped.
    pub async fn stream_records<S: DataSource + 'static>(
        &self,
        source: Arc<S>,
        buffer_size: usize,
    ) -> Result<mpsc::Receiver<DataRecord>> {
        let (tx, rx) = mpsc::channel(buffer_size);
        let config = self.config.clone();
        let stats = self.stats.clone();
        let seen_ids = self.seen_ids.clone();
        tokio::spawn(async move {
            let mut cursor: Option<String> = None;
            loop {
                match source
                    .fetch_batch(cursor.clone(), config.batch_size)
                    .await
                {
                    Ok((batch, next_cursor)) => {
                        if batch.is_empty() {
                            break;
                        }
                        for record in batch {
                            // Deduplicate: `insert` returns false when the ID
                            // was already present. The lock guard is scoped so
                            // it is released before the `.await` below.
                            if config.deduplicate {
                                let mut ids = seen_ids.write().unwrap();
                                if !ids.insert(record.id.clone()) {
                                    continue;
                                }
                            }
                            if tx.send(record).await.is_err() {
                                return; // Receiver dropped
                            }
                            stats.write().unwrap().records_fetched += 1;
                        }
                        cursor = next_cursor;
                        if cursor.is_none() {
                            break;
                        }
                    }
                    Err(_) => {
                        stats.write().unwrap().errors += 1;
                        break;
                    }
                }
            }
        });
        Ok(rx)
    }
    /// Fetch a batch with retry logic.
    ///
    /// Performs up to `retry_count` additional attempts after the first, with
    /// exponential backoff (`retry_delay_ms * 2^(attempt-1)`), counting each
    /// retry; a final failure increments `errors`.
    async fn fetch_with_retry<S: DataSource>(
        &self,
        source: &S,
        cursor: Option<String>,
        batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        let mut last_error = None;
        for attempt in 0..=self.config.retry_count {
            if attempt > 0 {
                // Exponential backoff before each retry.
                let delay = self.config.retry_delay_ms * (1 << (attempt - 1));
                tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
                self.stats.write().unwrap().retries += 1;
            }
            match source.fetch_batch(cursor.clone(), batch_size).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                }
            }
        }
        self.stats.write().unwrap().errors += 1;
        Err(last_error.unwrap_or_else(|| FrameworkError::Ingestion("Unknown error".to_string())))
    }
    /// Deduplicate a batch of records against all IDs seen so far.
    ///
    /// Improvement: uses `HashSet::insert`'s return value instead of a
    /// `contains` + `insert` pair (one hash lookup per record), and tallies
    /// duplicates locally so the stats lock is taken at most once per batch
    /// rather than once per duplicate while the seen-ID lock is held.
    fn deduplicate_batch(&self, batch: Vec<DataRecord>) -> Vec<DataRecord> {
        let mut unique = Vec::with_capacity(batch.len());
        let mut duplicates: u64 = 0;
        {
            let mut seen = self.seen_ids.write().unwrap();
            for record in batch {
                if seen.insert(record.id.clone()) {
                    unique.push(record);
                } else {
                    duplicates += 1;
                }
            }
        }
        if duplicates > 0 {
            self.stats.write().unwrap().duplicates_skipped += duplicates;
        }
        unique
    }
    /// Get a snapshot of the current ingestion statistics.
    pub fn stats(&self) -> IngestionStats {
        self.stats.read().unwrap().clone()
    }
    /// Reset statistics to their zeroed defaults.
    pub fn reset_stats(&self) {
        *self.stats.write().unwrap() = IngestionStats::default();
    }
}
/// Trait for transforming records during ingestion
///
/// Implementations may rewrite, enrich, or normalise records. In
/// [`BatchIngester::ingest_all`], `filter` is consulted before `transform`
/// runs, so filtered-out records are never transformed.
#[async_trait]
pub trait RecordTransformer: Send + Sync {
    /// Transform a record
    async fn transform(&self, record: DataRecord) -> Result<DataRecord>;
    /// Filter records (return false to skip)
    ///
    /// The default implementation keeps every record.
    fn filter(&self, record: &DataRecord) -> bool {
        true
    }
}
/// Identity transformer (no-op)
///
/// Passes every record through unchanged; useful as a default when no
/// transformation is required.
pub struct IdentityTransformer;
#[async_trait]
impl RecordTransformer for IdentityTransformer {
    /// Return the record unchanged.
    async fn transform(&self, record: DataRecord) -> Result<DataRecord> {
        Ok(record)
    }
}
/// Batched ingestion with transformations
///
/// Couples a [`DataIngester`] with a [`RecordTransformer`] so fetched records
/// are filtered and transformed before being returned to the caller.
pub struct BatchIngester<T: RecordTransformer> {
    // Underlying ingester handling fetching, retries, and deduplication.
    ingester: DataIngester,
    // Transformer applied to each record that passes the filter.
    transformer: T,
}
impl<T: RecordTransformer> BatchIngester<T> {
    /// Create a new batch ingester with transformer.
    pub fn new(config: IngestionConfig, transformer: T) -> Self {
        let ingester = DataIngester::new(config);
        Self { ingester, transformer }
    }
    /// Ingest every record from `source`, then filter and transform each one.
    ///
    /// Records rejected by `transformer.filter` are dropped before
    /// transformation; any transform error aborts the whole run.
    pub async fn ingest_all<S: DataSource>(&self, source: &S) -> Result<Vec<DataRecord>> {
        let fetched = self.ingester.ingest_all(source).await?;
        let mut output = Vec::with_capacity(fetched.len());
        for item in fetched {
            if !self.transformer.filter(&item) {
                continue;
            }
            output.push(self.transformer.transform(item).await?);
        }
        Ok(output)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults should batch 1000 records and enable deduplication.
    #[test]
    fn test_default_config() {
        let cfg = IngestionConfig::default();
        assert_eq!(cfg.batch_size, 1000);
        assert!(cfg.deduplicate);
    }
    // A freshly constructed ingester reports zeroed statistics.
    #[test]
    fn test_ingester_creation() {
        let ingester = DataIngester::new(IngestionConfig::default());
        assert_eq!(ingester.stats().records_fetched, 0);
    }
}

View File

@@ -0,0 +1,470 @@
//! # RuVector Data Discovery Framework
//!
//! Core traits and types for building dataset integrations with RuVector's
//! vector memory, graph structures, and dynamic minimum cut algorithms.
//!
//! ## Architecture
//!
//! The framework provides three core abstractions:
//!
//! 1. **DataIngester**: Streaming data ingestion with batched graph/vector updates
//! 2. **CoherenceEngine**: Real-time coherence signal computation using min-cut
//! 3. **DiscoveryEngine**: Pattern detection for emerging structures and anomalies
//!
//! ## Quick Start
//!
//! ```rust,ignore
//! use ruvector_data_framework::{
//! DataIngester, CoherenceEngine, DiscoveryEngine,
//! IngestionConfig, CoherenceConfig, DiscoveryConfig,
//! };
//!
//! // Configure the discovery pipeline
//! let ingester = DataIngester::new(ingestion_config);
//! let coherence = CoherenceEngine::new(coherence_config);
//! let discovery = DiscoveryEngine::new(discovery_config);
//!
//! // Stream data and detect patterns
//! let stream = ingester.stream_from_source(source).await?;
//! let signals = coherence.compute_signals(stream).await?;
//! let patterns = discovery.detect_patterns(signals).await?;
//! ```
#![warn(missing_docs)]
#![warn(clippy::all)]
pub mod academic_clients;
pub mod api_clients;
pub mod arxiv_client;
pub mod biorxiv_client;
pub mod coherence;
pub mod crossref_client;
pub mod discovery;
pub mod dynamic_mincut;
pub mod economic_clients;
pub mod export;
pub mod finance_clients;
pub mod forecasting;
pub mod genomics_clients;
pub mod geospatial_clients;
pub mod government_clients;
pub mod hnsw;
pub mod cut_aware_hnsw;
pub mod ingester;
pub mod mcp_server;
pub mod medical_clients;
pub mod ml_clients;
pub mod news_clients;
pub mod optimized;
pub mod patent_clients;
pub mod persistence;
pub mod physics_clients;
pub mod realtime;
pub mod ruvector_native;
pub mod semantic_scholar;
pub mod space_clients;
pub mod streaming;
pub mod transportation_clients;
pub mod utils;
pub mod visualization;
pub mod wiki_clients;
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
// Re-exports
pub use academic_clients::{CoreClient, EricClient, UnpaywallClient};
pub use api_clients::{EdgarClient, Embedder, NoaaClient, OpenAlexClient, SimpleEmbedder};
#[cfg(feature = "onnx-embeddings")]
pub use api_clients::OnnxEmbedder;
#[cfg(feature = "onnx-embeddings")]
pub use ruvector_onnx_embeddings::{PretrainedModel, EmbedderConfig, PoolingStrategy};
pub use arxiv_client::ArxivClient;
pub use biorxiv_client::{BiorxivClient, MedrxivClient};
pub use crossref_client::CrossRefClient;
pub use economic_clients::{AlphaVantageClient, FredClient, WorldBankClient};
pub use finance_clients::{BlsClient, CoinGeckoClient, EcbClient, FinnhubClient, TwelveDataClient};
pub use genomics_clients::{EnsemblClient, GwasClient, NcbiClient, UniProtClient};
pub use geospatial_clients::{GeonamesClient, NominatimClient, OpenElevationClient, OverpassClient};
pub use government_clients::{
CensusClient, DataGovClient, EuOpenDataClient, UkGovClient, UNDataClient,
WorldBankClient as WorldBankGovClient,
};
pub use medical_clients::{ClinicalTrialsClient, FdaClient, PubMedClient};
pub use ml_clients::{
HuggingFaceClient, HuggingFaceDataset, HuggingFaceModel, OllamaClient, OllamaModel,
PapersWithCodeClient, PaperWithCodeDataset, PaperWithCodePaper, ReplicateClient,
ReplicateModel, TogetherAiClient, TogetherModel,
};
pub use news_clients::{GuardianClient, HackerNewsClient, NewsDataClient, RedditClient};
pub use patent_clients::{EpoClient, UsptoPatentClient};
pub use physics_clients::{ArgoClient, CernOpenDataClient, GeoUtils, MaterialsProjectClient, UsgsEarthquakeClient};
pub use semantic_scholar::SemanticScholarClient;
pub use space_clients::{AstronomyClient, ExoplanetClient, NasaClient, SpaceXClient};
pub use transportation_clients::{GtfsClient, MobilityDatabaseClient, OpenChargeMapClient, OpenRouteServiceClient};
pub use wiki_clients::{WikidataClient, WikidataEntity, WikipediaClient};
pub use coherence::{
CoherenceBoundary, CoherenceConfig, CoherenceEngine, CoherenceEvent, CoherenceSignal,
};
pub use cut_aware_hnsw::{
CutAwareHNSW, CutAwareConfig, CutAwareMetrics, CoherenceZone,
SearchResult as CutAwareSearchResult, EdgeUpdate as CutAwareEdgeUpdate, UpdateKind, LayerCutStats,
};
pub use discovery::{
DiscoveryConfig, DiscoveryEngine, DiscoveryPattern, PatternCategory, PatternStrength,
};
pub use dynamic_mincut::{
CutGatedSearch, CutWatcherConfig, DynamicCutWatcher, DynamicMinCutError,
EdgeUpdate as DynamicEdgeUpdate, EdgeUpdateType, EulerTourTree, HNSWGraph,
LocalCut, LocalMinCutProcedure, WatcherStats,
};
pub use export::{
export_all, export_coherence_csv, export_dot, export_graphml, export_patterns_csv,
export_patterns_with_evidence_csv, ExportFilter,
};
pub use forecasting::{CoherenceForecaster, CrossDomainForecaster, Forecast, Trend};
pub use ingester::{DataIngester, IngestionConfig, IngestionStats, SourceConfig};
pub use realtime::{FeedItem, FeedSource, NewsAggregator, NewsSource, RealTimeEngine};
pub use ruvector_native::{
CoherenceHistoryEntry, CoherenceSnapshot, Domain, DiscoveredPattern,
GraphExport, NativeDiscoveryEngine, NativeEngineConfig, SemanticVector,
};
pub use streaming::{StreamingConfig, StreamingEngine, StreamingEngineBuilder, StreamingMetrics};
/// Framework error types
///
/// `Network` and `Serialization` convert automatically from their underlying
/// error types via `#[from]`, so `?` works directly on `reqwest` and
/// `serde_json` results.
#[derive(Error, Debug)]
pub enum FrameworkError {
    /// Data ingestion failed
    #[error("Ingestion error: {0}")]
    Ingestion(String),
    /// Coherence computation failed
    #[error("Coherence error: {0}")]
    Coherence(String),
    /// Discovery algorithm failed
    #[error("Discovery error: {0}")]
    Discovery(String),
    /// Network/API error (wraps `reqwest::Error`)
    #[error("Network error: {0}")]
    Network(#[from] reqwest::Error),
    /// Serialization error (wraps `serde_json::Error`)
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Graph operation failed
    #[error("Graph error: {0}")]
    Graph(String),
    /// Configuration error
    #[error("Config error: {0}")]
    Config(String),
}
/// Result type for framework operations, fixing the error to [`FrameworkError`]
pub type Result<T> = std::result::Result<T, FrameworkError>;
/// A timestamped data record from any source
///
/// The common currency of the framework: every source normalises its payloads
/// into this shape before ingestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataRecord {
    /// Unique identifier (also the deduplication key during ingestion)
    pub id: String,
    /// Source dataset (e.g., "openalex", "noaa", "edgar")
    pub source: String,
    /// Record type within source (e.g., "work", "author", "filing")
    pub record_type: String,
    /// Timestamp when data was observed/published
    pub timestamp: DateTime<Utc>,
    /// Raw data payload (arbitrary JSON from the source)
    pub data: serde_json::Value,
    /// Pre-computed embedding vector (optional)
    pub embedding: Option<Vec<f32>>,
    /// Relationships to other records
    pub relationships: Vec<Relationship>,
}
/// A relationship between two records
///
/// Directed edge from the owning [`DataRecord`] to `target_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Target record ID
    pub target_id: String,
    /// Relationship type (e.g., "cites", "authored_by", "filed_by")
    pub rel_type: String,
    /// Relationship weight/strength
    pub weight: f64,
    /// Additional properties (arbitrary JSON per key)
    pub properties: HashMap<String, serde_json::Value>,
}
/// Trait for data sources that can be ingested
///
/// Pagination is cursor-based: `fetch_batch` returns the cursor for the next
/// page, and either a `None` cursor or an empty batch signals exhaustion
/// (this is how `DataIngester` terminates its loops).
#[async_trait]
pub trait DataSource: Send + Sync {
    /// Source identifier
    fn source_id(&self) -> &str;
    /// Fetch a batch of records starting from cursor
    ///
    /// Returns the records plus the cursor for the next page (`None` when no
    /// further pages exist). A `None` input cursor requests the first page.
    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)>;
    /// Get total record count (if known; `None` when the source cannot say)
    async fn total_count(&self) -> Result<Option<u64>>;
    /// Check if source is available
    async fn health_check(&self) -> Result<bool>;
}
/// Trait for computing embeddings from records
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Compute embedding for a single record
    async fn embed_record(&self, record: &DataRecord) -> Result<Vec<f32>>;
    /// Compute embeddings for a batch of records
    ///
    /// Returned vector is expected to be parallel to `records` (one embedding
    /// per input, same order).
    async fn embed_batch(&self, records: &[DataRecord]) -> Result<Vec<Vec<f32>>>;
    /// Embedding dimension (length of every vector this provider produces)
    fn dimension(&self) -> usize;
}
/// Trait for graph building from records
///
/// Nodes are identified by the `u64` handles returned from `add_node`, which
/// `add_edge` then connects.
pub trait GraphBuilder: Send + Sync {
    /// Add a node from a data record, returning its graph handle
    fn add_node(&mut self, record: &DataRecord) -> Result<u64>;
    /// Add an edge between nodes (handles previously returned by `add_node`)
    fn add_edge(&mut self, source: u64, target: u64, weight: f64, rel_type: &str) -> Result<()>;
    /// Get node count
    fn node_count(&self) -> usize;
    /// Get edge count
    fn edge_count(&self) -> usize;
}
/// Temporal window for time-series analysis
///
/// The window is half-open: `start` is inclusive and `end` is exclusive
/// (see [`TemporalWindow::contains`]).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TemporalWindow {
    /// Window start (inclusive)
    pub start: DateTime<Utc>,
    /// Window end (exclusive)
    pub end: DateTime<Utc>,
    /// Window identifier (for sliding windows)
    pub window_id: u64,
}
impl TemporalWindow {
    /// Create a new temporal window
    pub fn new(start: DateTime<Utc>, end: DateTime<Utc>, window_id: u64) -> Self {
        Self { start, end, window_id }
    }
    /// Duration in seconds (negative if `end` precedes `start`)
    pub fn duration_secs(&self) -> i64 {
        (self.end - self.start).num_seconds()
    }
    /// Check whether a timestamp falls in the half-open interval `[start, end)`
    pub fn contains(&self, timestamp: DateTime<Utc>) -> bool {
        self.start <= timestamp && timestamp < self.end
    }
}
/// Statistics for a discovery session
///
/// Filled in incrementally by [`DiscoveryPipeline::run`] as each phase
/// completes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DiscoveryStats {
    /// Records processed
    pub records_processed: u64,
    /// Nodes in graph
    pub nodes_created: u64,
    /// Edges in graph
    pub edges_created: u64,
    /// Coherence signals computed
    pub signals_computed: u64,
    /// Patterns discovered
    pub patterns_discovered: u64,
    /// Processing duration in milliseconds
    pub duration_ms: u64,
    /// Peak memory usage in bytes
    /// (NOTE(review): never written by the visible pipeline code)
    pub peak_memory_bytes: u64,
}
/// Configuration for the entire discovery pipeline
///
/// Bundles the per-stage configurations consumed by [`DiscoveryPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Ingestion configuration
    pub ingestion: IngestionConfig,
    /// Coherence engine configuration
    pub coherence: CoherenceConfig,
    /// Discovery engine configuration
    pub discovery: DiscoveryConfig,
    /// Enable parallel processing
    /// (NOTE(review): not read by the visible pipeline code — confirm usage)
    pub parallel: bool,
    /// Checkpoint interval (records)
    pub checkpoint_interval: u64,
    /// Output directory for results
    pub output_dir: String,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
ingestion: IngestionConfig::default(),
coherence: CoherenceConfig::default(),
discovery: DiscoveryConfig::default(),
parallel: true,
checkpoint_interval: 10_000,
output_dir: "./discovery_output".to_string(),
}
}
}
/// Main discovery pipeline orchestrator
///
/// Owns one engine per phase (ingestion → coherence → discovery); each engine
/// receives its own clone of the relevant stage configuration at construction.
pub struct DiscoveryPipeline {
    // Full configuration retained for reference; the engines below hold
    // their own cloned stage configs.
    config: PipelineConfig,
    // Phase 1: data ingestion.
    ingester: DataIngester,
    // Phase 2: coherence signal computation.
    coherence: CoherenceEngine,
    // Phase 3: pattern detection.
    discovery: DiscoveryEngine,
    // Session statistics, shared for concurrent readers.
    stats: Arc<std::sync::RwLock<DiscoveryStats>>,
}
impl DiscoveryPipeline {
    /// Create a new discovery pipeline from a complete configuration.
    pub fn new(config: PipelineConfig) -> Self {
        Self {
            ingester: DataIngester::new(config.ingestion.clone()),
            coherence: CoherenceEngine::new(config.coherence.clone()),
            discovery: DiscoveryEngine::new(config.discovery.clone()),
            stats: Arc::new(std::sync::RwLock::new(DiscoveryStats::default())),
            config,
        }
    }
    /// Run the discovery pipeline on a data source.
    ///
    /// Executes three phases — ingestion, coherence computation, and pattern
    /// detection — updating the shared statistics after each phase.
    pub async fn run<S: DataSource>(&mut self, source: S) -> Result<Vec<DiscoveryPattern>> {
        let started = std::time::Instant::now();
        // Phase 1: pull every record from the source.
        tracing::info!("Starting ingestion from source: {}", source.source_id());
        let records = self.ingester.ingest_all(&source).await?;
        self.stats.write().unwrap().records_processed = records.len() as u64;
        // Phase 2: build the graph and derive coherence signals.
        tracing::info!("Computing coherence signals over {} records", records.len());
        let signals = self.coherence.compute_from_records(&records)?;
        {
            let mut snapshot = self.stats.write().unwrap();
            snapshot.signals_computed = signals.len() as u64;
            snapshot.nodes_created = self.coherence.node_count() as u64;
            snapshot.edges_created = self.coherence.edge_count() as u64;
        }
        // Phase 3: mine the signals for discovery patterns.
        tracing::info!("Detecting discovery patterns");
        let patterns = self.discovery.detect(&signals)?;
        {
            let mut snapshot = self.stats.write().unwrap();
            snapshot.patterns_discovered = patterns.len() as u64;
            snapshot.duration_ms = started.elapsed().as_millis() as u64;
        }
        tracing::info!(
            "Discovery complete: {} patterns found in {}ms",
            patterns.len(),
            started.elapsed().as_millis()
        );
        Ok(patterns)
    }
    /// Snapshot of the current session statistics.
    pub fn stats(&self) -> DiscoveryStats {
        self.stats.read().unwrap().clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Window arithmetic: duration plus half-open [start, end) containment.
    #[test]
    fn test_temporal_window() {
        let begin = Utc::now();
        let finish = begin + chrono::Duration::hours(1);
        let window = TemporalWindow::new(begin, finish, 1);
        assert_eq!(window.duration_secs(), 3600);
        assert!(window.contains(begin + chrono::Duration::minutes(30)));
        assert!(!window.contains(begin - chrono::Duration::minutes(1)));
        assert!(!window.contains(finish + chrono::Duration::minutes(1)));
    }
    // Defaults: parallelism enabled, checkpoint every 10k records.
    #[test]
    fn test_default_pipeline_config() {
        let cfg = PipelineConfig::default();
        assert!(cfg.parallel);
        assert_eq!(cfg.checkpoint_interval, 10_000);
    }
    // A DataRecord must survive a JSON round trip intact.
    #[test]
    fn test_data_record_serialization() {
        let original = DataRecord {
            id: "test-1".to_string(),
            source: "test".to_string(),
            record_type: "document".to_string(),
            timestamp: Utc::now(),
            data: serde_json::json!({"title": "Test"}),
            embedding: Some(vec![0.1, 0.2, 0.3]),
            relationships: vec![],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DataRecord = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, original.id);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,915 @@
//! Medical data API integrations for PubMed, ClinicalTrials.gov, and FDA
//!
//! This module provides async clients for fetching medical literature, clinical trials,
//! and FDA data, converting responses to SemanticVector format for RuVector discovery.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Custom deserializer that handles both string and integer values
///
/// Some APIs return numeric fields inconsistently as either JSON numbers or
/// numeric strings; this visitor accepts both, and treats JSON null/absent
/// values as `None`.
///
/// # Errors
///
/// Returns a deserialization error when a string is not a valid `i32`, or
/// when a numeric value does not fit in `i32` (previously out-of-range
/// numbers were silently truncated by an `as i32` cast).
fn deserialize_number_from_string<'de, D>(deserializer: D) -> std::result::Result<Option<i32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{self, Visitor};
    struct NumberOrStringVisitor;
    impl<'de> Visitor<'de> for NumberOrStringVisitor {
        type Value = Option<i32>;
        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a number or numeric string")
        }
        fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Fix: checked conversion instead of a wrapping `as i32` cast.
            i32::try_from(v).map(Some).map_err(de::Error::custom)
        }
        fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Fix: checked conversion instead of a wrapping `as i32` cast.
            i32::try_from(v).map(Some).map_err(de::Error::custom)
        }
        fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            v.parse::<i32>().map(Some).map_err(de::Error::custom)
        }
        fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }
        fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }
    }
    deserializer.deserialize_any(NumberOrStringVisitor)
}
/// Rate limiting configuration
///
/// Minimum delays between successive requests, chosen to stay under each
/// provider's published request-rate limits.
const NCBI_RATE_LIMIT_MS: u64 = 334; // ~3 requests/second without API key
const NCBI_WITH_KEY_RATE_LIMIT_MS: u64 = 100; // 10 requests/second with key
const FDA_RATE_LIMIT_MS: u64 = 250; // Conservative 4 requests/second
const CLINICALTRIALS_RATE_LIMIT_MS: u64 = 100; // 10 requests/second
const MAX_RETRIES: u32 = 3; // retry budget for failed requests
const RETRY_DELAY_MS: u64 = 1000; // base delay before retrying
// ============================================================================
// PubMed E-utilities Client
// ============================================================================
/// PubMed ESearch API response
#[derive(Debug, Deserialize)]
struct PubMedSearchResponse {
    esearchresult: ESearchResult,
}
/// Payload of an ESearch response: the matching PMIDs plus total hit count.
#[derive(Debug, Deserialize)]
struct ESearchResult {
    /// PubMed IDs matching the query (empty when nothing matched)
    #[serde(default)]
    idlist: Vec<String>,
    /// Total result count; the API returns this as a string
    /// (NOTE(review): unused in the visible client code)
    #[serde(default)]
    count: String,
}
/// PubMed EFetch API response (simplified)
///
/// Mirrors the shape of the EFetch XML (`<PubmedArticleSet>` etc.); only the
/// fields this module consumes are modelled, everything else is ignored.
#[derive(Debug, Deserialize)]
struct PubMedFetchResponse {
    #[serde(rename = "PubmedArticleSet")]
    pubmed_article_set: Option<PubmedArticleSet>,
}
/// `<PubmedArticleSet>`: container for zero or more `<PubmedArticle>` entries.
#[derive(Debug, Deserialize)]
struct PubmedArticleSet {
    #[serde(rename = "PubmedArticle", default)]
    articles: Vec<PubmedArticle>,
}
/// `<PubmedArticle>`: a single article record.
#[derive(Debug, Deserialize)]
struct PubmedArticle {
    #[serde(rename = "MedlineCitation")]
    medline_citation: MedlineCitation,
}
/// `<MedlineCitation>`: PMID, article body, and completion date.
#[derive(Debug, Deserialize)]
struct MedlineCitation {
    #[serde(rename = "PMID")]
    pmid: PmidObject,
    #[serde(rename = "Article")]
    article: Article,
    #[serde(rename = "DateCompleted", default)]
    date_completed: Option<DateCompleted>,
}
/// `<PMID>`: element whose text content is the PubMed ID.
#[derive(Debug, Deserialize)]
struct PmidObject {
    #[serde(rename = "$value", default)]
    value: String,
}
/// `<Article>`: title, abstract, and author list (all optional in the feed).
#[derive(Debug, Deserialize)]
struct Article {
    #[serde(rename = "ArticleTitle", default)]
    article_title: Option<String>,
    #[serde(rename = "Abstract", default)]
    abstract_data: Option<AbstractData>,
    #[serde(rename = "AuthorList", default)]
    author_list: Option<AuthorList>,
}
/// `<Abstract>`: one or more `<AbstractText>` sections.
#[derive(Debug, Deserialize)]
struct AbstractData {
    #[serde(rename = "AbstractText", default)]
    abstract_text: Vec<AbstractText>,
}
/// `<AbstractText>`: text content of a single abstract section.
#[derive(Debug, Deserialize)]
struct AbstractText {
    #[serde(rename = "$value", default)]
    value: Option<String>,
}
/// `<AuthorList>`: container for `<Author>` entries.
#[derive(Debug, Deserialize)]
struct AuthorList {
    #[serde(rename = "Author", default)]
    authors: Vec<Author>,
}
/// `<Author>`: surname and forename (either may be absent).
#[derive(Debug, Deserialize)]
struct Author {
    #[serde(rename = "LastName", default)]
    last_name: Option<String>,
    #[serde(rename = "ForeName", default)]
    fore_name: Option<String>,
}
/// `<DateCompleted>`: year/month/day as separate string elements.
#[derive(Debug, Deserialize)]
struct DateCompleted {
    #[serde(rename = "Year", default)]
    year: Option<String>,
    #[serde(rename = "Month", default)]
    month: Option<String>,
    #[serde(rename = "Day", default)]
    day: Option<String>,
}
/// Client for PubMed medical literature database
pub struct PubMedClient {
client: Client,
base_url: String,
api_key: Option<String>,
rate_limit_delay: Duration,
embedder: Arc<SimpleEmbedder>,
}
impl PubMedClient {
    /// Create a new PubMed client
    ///
    /// # Arguments
    /// * `api_key` - Optional NCBI API key (get from https://www.ncbi.nlm.nih.gov/account/)
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new(api_key: Option<String>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;
        // An API key raises the permitted rate from ~3 req/s to 10 req/s.
        let rate_limit_delay = if api_key.is_some() {
            Duration::from_millis(NCBI_WITH_KEY_RATE_LIMIT_MS)
        } else {
            Duration::from_millis(NCBI_RATE_LIMIT_MS)
        };
        Ok(Self {
            client,
            base_url: "https://eutils.ncbi.nlm.nih.gov/entrez/eutils".to_string(),
            api_key,
            rate_limit_delay,
            embedder: Arc::new(SimpleEmbedder::new(384)), // Higher dimension for medical text
        })
    }
    /// Search PubMed articles by query
    ///
    /// Two-step E-utilities flow: ESearch for PMIDs, then EFetch for records.
    ///
    /// # Arguments
    /// * `query` - Search query (e.g., "COVID-19 vaccine", "alzheimer's treatment")
    /// * `max_results` - Maximum number of results to return
    pub async fn search_articles(
        &self,
        query: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        // Step 1: Search for PMIDs
        let pmids = self.search_pmids(query, max_results).await?;
        if pmids.is_empty() {
            return Ok(Vec::new());
        }
        // Step 2: Fetch full abstracts for PMIDs
        self.fetch_abstracts(&pmids).await
    }
    /// Search for PMIDs matching query (ESearch endpoint, JSON mode).
    async fn search_pmids(&self, query: &str, max_results: usize) -> Result<Vec<String>> {
        let mut url = format!(
            "{}/esearch.fcgi?db=pubmed&term={}&retmode=json&retmax={}",
            self.base_url,
            urlencoding::encode(query),
            max_results
        );
        if let Some(key) = &self.api_key {
            url.push_str(&format!("&api_key={}", key));
        }
        // Throttle before every outbound request.
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let search_response: PubMedSearchResponse = response.json().await?;
        Ok(search_response.esearchresult.idlist)
    }
    /// Fetch full article abstracts by PMIDs
    ///
    /// # Arguments
    /// * `pmids` - List of PubMed IDs to fetch
    pub async fn fetch_abstracts(&self, pmids: &[String]) -> Result<Vec<SemanticVector>> {
        if pmids.is_empty() {
            return Ok(Vec::new());
        }
        // Batch PMIDs (max 200 per request)
        let mut all_vectors = Vec::new();
        for chunk in pmids.chunks(200) {
            let pmid_list = chunk.join(",");
            let mut url = format!(
                "{}/efetch.fcgi?db=pubmed&id={}&retmode=xml",
                self.base_url, pmid_list
            );
            if let Some(key) = &self.api_key {
                url.push_str(&format!("&api_key={}", key));
            }
            // One throttle interval per batch request.
            sleep(self.rate_limit_delay).await;
            let response = self.fetch_with_retry(&url).await?;
            let xml_text = response.text().await?;
            // Parse XML response
            let vectors = self.parse_xml_to_vectors(&xml_text)?;
            all_vectors.extend(vectors);
        }
        Ok(all_vectors)
    }
    /// Parse PubMed XML response to SemanticVectors
    ///
    /// # Errors
    /// Returns `FrameworkError::Config` when the XML does not match the
    /// expected EFetch schema.
    fn parse_xml_to_vectors(&self, xml: &str) -> Result<Vec<SemanticVector>> {
        // Use quick-xml for parsing
        let fetch_response: PubMedFetchResponse = quick_xml::de::from_str(xml)
            .map_err(|e| FrameworkError::Config(format!("XML parse error: {}", e)))?;
        let mut vectors = Vec::new();
        if let Some(article_set) = fetch_response.pubmed_article_set {
            for pubmed_article in article_set.articles {
                let citation = pubmed_article.medline_citation;
                let article = citation.article;
                let pmid = citation.pmid.value;
                let title = article.article_title.unwrap_or_else(|| "Untitled".to_string());
                // Extract abstract text (multi-section abstracts are joined)
                let abstract_text = article
                    .abstract_data
                    .as_ref()
                    .map(|abs| {
                        abs.abstract_text
                            .iter()
                            .filter_map(|at| at.value.clone())
                            .collect::<Vec<_>>()
                            .join(" ")
                    })
                    .unwrap_or_default();
                // Create combined text for embedding
                let text = format!("{} {}", title, abstract_text);
                let embedding = self.embedder.embed_text(&text);
                // Parse publication date; falls back to "now" when DateCompleted
                // is missing or any component fails to parse
                let timestamp = citation
                    .date_completed
                    .as_ref()
                    .and_then(|date| {
                        let year = date.year.as_ref()?.parse::<i32>().ok()?;
                        let month = date.month.as_ref()?.parse::<u32>().ok()?;
                        let day = date.day.as_ref()?.parse::<u32>().ok()?;
                        NaiveDate::from_ymd_opt(year, month, day)
                    })
                    .and_then(|d| d.and_hms_opt(0, 0, 0))
                    .map(|dt| dt.and_utc())
                    .unwrap_or_else(Utc::now);
                // Extract author names ("First Last", comma-separated);
                // authors without a last name are dropped
                let authors = article
                    .author_list
                    .as_ref()
                    .map(|al| {
                        al.authors
                            .iter()
                            .filter_map(|a| {
                                let last = a.last_name.as_deref().unwrap_or("");
                                let first = a.fore_name.as_deref().unwrap_or("");
                                if !last.is_empty() {
                                    Some(format!("{} {}", first, last))
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join(", ")
                    })
                    .unwrap_or_default();
                let mut metadata = HashMap::new();
                metadata.insert("pmid".to_string(), pmid.clone());
                metadata.insert("title".to_string(), title);
                metadata.insert("abstract".to_string(), abstract_text);
                metadata.insert("authors".to_string(), authors);
                metadata.insert("source".to_string(), "pubmed".to_string());
                vectors.push(SemanticVector {
                    id: format!("PMID:{}", pmid),
                    embedding,
                    domain: Domain::Medical,
                    timestamp,
                    metadata,
                });
            }
        }
        Ok(vectors)
    }
    /// Fetch with retry logic
    ///
    /// Retries up to MAX_RETRIES times on transport errors or HTTP 429, with a
    /// linearly increasing delay. Non-429 HTTP error statuses are returned
    /// as-is (callers surface them when decoding the body).
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
// ============================================================================
// ClinicalTrials.gov Client
// ============================================================================
/// ClinicalTrials.gov API response
// Mirrors the v2 /studies JSON; only the fields consumed below are modeled.
#[derive(Debug, Deserialize)]
struct ClinicalTrialsResponse {
    #[serde(default)]
    studies: Vec<ClinicalStudy>,
}
/// One study record from the v2 API.
#[derive(Debug, Deserialize)]
struct ClinicalStudy {
    #[serde(rename = "protocolSection")]
    protocol_section: ProtocolSection,
}
/// Study protocol: identification, status, and optional descriptive modules.
#[derive(Debug, Deserialize)]
struct ProtocolSection {
    #[serde(rename = "identificationModule")]
    identification: IdentificationModule,
    #[serde(rename = "statusModule")]
    status: StatusModule,
    #[serde(rename = "descriptionModule", default)]
    description: Option<DescriptionModule>,
    #[serde(rename = "conditionsModule", default)]
    conditions: Option<ConditionsModule>,
}
/// NCT registry id and short title.
#[derive(Debug, Deserialize)]
struct IdentificationModule {
    #[serde(rename = "nctId")]
    nct_id: String,
    #[serde(rename = "briefTitle", default)]
    brief_title: Option<String>,
}
/// Recruitment status and study start date.
#[derive(Debug, Deserialize)]
struct StatusModule {
    #[serde(rename = "overallStatus", default)]
    overall_status: Option<String>,
    #[serde(rename = "startDateStruct", default)]
    start_date: Option<DateStruct>,
}
/// Wrapper for a date string (parsed later as %Y-%m-%d).
#[derive(Debug, Deserialize)]
struct DateStruct {
    #[serde(default)]
    date: Option<String>,
}
/// Free-text study summary.
#[derive(Debug, Deserialize)]
struct DescriptionModule {
    #[serde(rename = "briefSummary", default)]
    brief_summary: Option<String>,
}
/// Medical conditions the study targets.
#[derive(Debug, Deserialize)]
struct ConditionsModule {
    #[serde(default)]
    conditions: Vec<String>,
}
/// Client for ClinicalTrials.gov database
pub struct ClinicalTrialsClient {
    // Reused reqwest client (30s timeout)
    client: Client,
    // v2 API root
    base_url: String,
    // Fixed inter-request delay
    rate_limit_delay: Duration,
    // Shared embedder for title + summary + conditions text
    embedder: Arc<SimpleEmbedder>,
}
impl ClinicalTrialsClient {
/// Create a new ClinicalTrials.gov client
pub fn new() -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(FrameworkError::Network)?;
Ok(Self {
client,
base_url: "https://clinicaltrials.gov/api/v2".to_string(),
rate_limit_delay: Duration::from_millis(CLINICALTRIALS_RATE_LIMIT_MS),
embedder: Arc::new(SimpleEmbedder::new(384)),
})
}
/// Search clinical trials by condition
///
/// # Arguments
/// * `condition` - Medical condition to search (e.g., "diabetes", "cancer")
/// * `status` - Optional recruitment status filter (e.g., "RECRUITING", "COMPLETED")
pub async fn search_trials(
&self,
condition: &str,
status: Option<&str>,
) -> Result<Vec<SemanticVector>> {
let mut url = format!(
"{}/studies?query.cond={}&pageSize=100",
self.base_url,
urlencoding::encode(condition)
);
if let Some(s) = status {
url.push_str(&format!("&filter.overallStatus={}", s));
}
sleep(self.rate_limit_delay).await;
let response = self.fetch_with_retry(&url).await?;
let trials_response: ClinicalTrialsResponse = response.json().await?;
let mut vectors = Vec::new();
for study in trials_response.studies {
let vector = self.study_to_vector(study)?;
vectors.push(vector);
}
Ok(vectors)
}
/// Convert clinical study to SemanticVector
fn study_to_vector(&self, study: ClinicalStudy) -> Result<SemanticVector> {
let protocol = study.protocol_section;
let nct_id = protocol.identification.nct_id;
let title = protocol
.identification
.brief_title
.unwrap_or_else(|| "Untitled Study".to_string());
let summary = protocol
.description
.as_ref()
.and_then(|d| d.brief_summary.clone())
.unwrap_or_default();
let conditions = protocol
.conditions
.as_ref()
.map(|c| c.conditions.join(", "))
.unwrap_or_default();
let status = protocol
.status
.overall_status
.unwrap_or_else(|| "UNKNOWN".to_string());
// Create text for embedding
let text = format!("{} {} {}", title, summary, conditions);
let embedding = self.embedder.embed_text(&text);
// Parse start date
let timestamp = protocol
.status
.start_date
.as_ref()
.and_then(|sd| sd.date.as_ref())
.and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
.and_then(|d| d.and_hms_opt(0, 0, 0))
.map(|dt| dt.and_utc())
.unwrap_or_else(Utc::now);
let mut metadata = HashMap::new();
metadata.insert("nct_id".to_string(), nct_id.clone());
metadata.insert("title".to_string(), title);
metadata.insert("summary".to_string(), summary);
metadata.insert("conditions".to_string(), conditions);
metadata.insert("status".to_string(), status);
metadata.insert("source".to_string(), "clinicaltrials".to_string());
Ok(SemanticVector {
id: format!("NCT:{}", nct_id),
embedding,
domain: Domain::Medical,
timestamp,
metadata,
})
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
impl Default for ClinicalTrialsClient {
    /// Delegates to [`ClinicalTrialsClient::new`]; panics if construction fails.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create ClinicalTrials client")
    }
}
// ============================================================================
// FDA OpenFDA Client
// ============================================================================
/// OpenFDA drug adverse event response
// Note: `results` has no #[serde(default)], so an empty/absent results key is
// a deserialization error; callers short-circuit on HTTP 404 before decoding.
#[derive(Debug, Deserialize)]
struct FdaDrugEventResponse {
    results: Vec<FdaDrugEvent>,
}
/// One adverse-event safety report.
#[derive(Debug, Deserialize)]
struct FdaDrugEvent {
    #[serde(rename = "safetyreportid")]
    safety_report_id: String,
    /// Report receive date as YYYYMMDD
    #[serde(rename = "receivedate", default)]
    receive_date: Option<String>,
    #[serde(default)]
    patient: Option<FdaPatient>,
    /// 1 = serious; OpenFDA may encode this as a string, hence the custom
    /// number-or-string deserializer
    #[serde(default, deserialize_with = "deserialize_number_from_string")]
    serious: Option<i32>,
}
/// Drugs taken and reactions observed for one patient.
#[derive(Debug, Deserialize)]
struct FdaPatient {
    #[serde(default)]
    drug: Vec<FdaDrug>,
    #[serde(default)]
    reaction: Vec<FdaReaction>,
}
/// One drug entry (product name only).
#[derive(Debug, Deserialize)]
struct FdaDrug {
    #[serde(rename = "medicinalproduct", default)]
    medicinal_product: Option<String>,
}
/// One reaction entry (MedDRA preferred term).
#[derive(Debug, Deserialize)]
struct FdaReaction {
    #[serde(rename = "reactionmeddrapt", default)]
    reaction_meddra_pt: Option<String>,
}
/// OpenFDA device recall response
#[derive(Debug, Deserialize)]
struct FdaRecallResponse {
    results: Vec<FdaRecall>,
}
/// One device recall record.
#[derive(Debug, Deserialize)]
struct FdaRecall {
    #[serde(rename = "recall_number")]
    recall_number: String,
    #[serde(default)]
    reason_for_recall: Option<String>,
    #[serde(default)]
    product_description: Option<String>,
    /// Report date as YYYYMMDD
    #[serde(default)]
    report_date: Option<String>,
    /// Recall classification (Class I/II/III)
    #[serde(default)]
    classification: Option<String>,
}
/// Client for FDA OpenFDA API
pub struct FdaClient {
    // Reused reqwest client (30s timeout)
    client: Client,
    // OpenFDA API root
    base_url: String,
    // Fixed inter-request delay
    rate_limit_delay: Duration,
    // Shared embedder for event/recall text
    embedder: Arc<SimpleEmbedder>,
}
impl FdaClient {
    /// Create a new FDA OpenFDA client
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;
        Ok(Self {
            client,
            base_url: "https://api.fda.gov".to_string(),
            rate_limit_delay: Duration::from_millis(FDA_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(384)),
        })
    }
    /// Search drug adverse events
    ///
    /// Returns an empty list when OpenFDA reports no matches (HTTP 404).
    ///
    /// # Arguments
    /// * `drug_name` - Name of drug to search (e.g., "aspirin", "ibuprofen")
    pub async fn search_drug_events(&self, drug_name: &str) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/drug/event.json?search=patient.drug.medicinalproduct:\"{}\"&limit=100",
            self.base_url,
            urlencoding::encode(drug_name)
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        // FDA API may return 404 if no results - handle gracefully
        if response.status() == StatusCode::NOT_FOUND {
            return Ok(Vec::new());
        }
        let events_response: FdaDrugEventResponse = response.json().await?;
        let mut vectors = Vec::new();
        for event in events_response.results {
            let vector = self.drug_event_to_vector(event)?;
            vectors.push(vector);
        }
        Ok(vectors)
    }
    /// Search device recalls
    ///
    /// Returns an empty list when OpenFDA reports no matches (HTTP 404).
    ///
    /// # Arguments
    /// * `reason` - Reason for recall to search
    pub async fn search_recalls(&self, reason: &str) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}/device/recall.json?search=reason_for_recall:\"{}\"&limit=100",
            self.base_url,
            urlencoding::encode(reason)
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        if response.status() == StatusCode::NOT_FOUND {
            return Ok(Vec::new());
        }
        let recalls_response: FdaRecallResponse = response.json().await?;
        let mut vectors = Vec::new();
        for recall in recalls_response.results {
            let vector = self.recall_to_vector(recall)?;
            vectors.push(vector);
        }
        Ok(vectors)
    }
    /// Convert drug event to SemanticVector
    ///
    /// Embeds a "Drug: … Reactions: … Severity: …" summary; id is
    /// `FDA_EVENT:<safetyreportid>`.
    fn drug_event_to_vector(&self, event: FdaDrugEvent) -> Result<SemanticVector> {
        let mut drug_names = Vec::new();
        let mut reactions = Vec::new();
        if let Some(patient) = &event.patient {
            for drug in &patient.drug {
                if let Some(name) = &drug.medicinal_product {
                    drug_names.push(name.clone());
                }
            }
            for reaction in &patient.reaction {
                if let Some(r) = &reaction.reaction_meddra_pt {
                    reactions.push(r.clone());
                }
            }
        }
        let drugs_text = drug_names.join(", ");
        let reactions_text = reactions.join(", ");
        // Only serious == 1 counts as serious; None/other values do not.
        let serious = if event.serious == Some(1) {
            "serious"
        } else {
            "non-serious"
        };
        // Create text for embedding
        let text = format!("Drug: {} Reactions: {} Severity: {}", drugs_text, reactions_text, serious);
        let embedding = self.embedder.embed_text(&text);
        // Parse receive date (YYYYMMDD); fall back to "now"
        let timestamp = event
            .receive_date
            .as_ref()
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y%m%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);
        let mut metadata = HashMap::new();
        metadata.insert("report_id".to_string(), event.safety_report_id.clone());
        metadata.insert("drugs".to_string(), drugs_text);
        metadata.insert("reactions".to_string(), reactions_text);
        metadata.insert("serious".to_string(), serious.to_string());
        metadata.insert("source".to_string(), "fda_drug_events".to_string());
        Ok(SemanticVector {
            id: format!("FDA_EVENT:{}", event.safety_report_id),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }
    /// Convert recall to SemanticVector
    ///
    /// Embeds a "Product: … Reason: … Classification: …" summary; id is
    /// `FDA_RECALL:<recall_number>`.
    fn recall_to_vector(&self, recall: FdaRecall) -> Result<SemanticVector> {
        let reason = recall.reason_for_recall.unwrap_or_else(|| "Unknown reason".to_string());
        let product = recall.product_description.unwrap_or_else(|| "Unknown product".to_string());
        let classification = recall.classification.unwrap_or_else(|| "Unknown".to_string());
        // Create text for embedding
        let text = format!("Product: {} Reason: {} Classification: {}", product, reason, classification);
        let embedding = self.embedder.embed_text(&text);
        // Parse report date (YYYYMMDD); fall back to "now"
        let timestamp = recall
            .report_date
            .as_ref()
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y%m%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);
        let mut metadata = HashMap::new();
        metadata.insert("recall_number".to_string(), recall.recall_number.clone());
        metadata.insert("reason".to_string(), reason);
        metadata.insert("product".to_string(), product);
        metadata.insert("classification".to_string(), classification);
        metadata.insert("source".to_string(), "fda_recalls".to_string());
        Ok(SemanticVector {
            id: format!("FDA_RECALL:{}", recall.recall_number),
            embedding,
            domain: Domain::Medical,
            timestamp,
            metadata,
        })
    }
    /// Fetch with retry logic
    ///
    /// Retries up to MAX_RETRIES times on transport errors or HTTP 429.
    /// Deliberately returns non-success responses (e.g. 404) unmodified so
    /// callers can inspect the status themselves.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
impl Default for FdaClient {
    /// Delegates to [`FdaClient::new`]; panics if construction fails.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create FDA client")
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    #[tokio::test]
    async fn test_pubmed_client_creation() {
        // Construction requires no network access or API key.
        assert!(PubMedClient::new(None).is_ok());
    }
    #[tokio::test]
    async fn test_clinical_trials_client_creation() {
        assert!(ClinicalTrialsClient::new().is_ok());
    }
    #[tokio::test]
    async fn test_fda_client_creation() {
        assert!(FdaClient::new().is_ok());
    }
    #[test]
    fn test_rate_limiting() {
        // The configured delay must follow the presence of an API key.
        let anonymous = PubMedClient::new(None).unwrap();
        let keyed = PubMedClient::new(Some("test_key".to_string())).unwrap();
        assert_eq!(
            anonymous.rate_limit_delay,
            Duration::from_millis(NCBI_RATE_LIMIT_MS)
        );
        assert_eq!(
            keyed.rate_limit_delay,
            Duration::from_millis(NCBI_WITH_KEY_RATE_LIMIT_MS)
        );
    }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,667 @@
//! Patent database API integrations for USPTO PatentsView and EPO
//!
//! This module provides async clients for fetching patent data from:
//! - USPTO PatentsView API (Free, no authentication required)
//! - EPO Open Patent Services (Free tier available)
//!
//! Converts patent data to SemanticVector format for RuVector discovery.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration
// Client-side throttles; chosen conservatively for the two patent services.
const USPTO_RATE_LIMIT_MS: u64 = 200; // ~5 requests/second
const EPO_RATE_LIMIT_MS: u64 = 1000; // Conservative 1 request/second
const MAX_RETRIES: u32 = 3; // retry attempts after the initial request fails
const RETRY_DELAY_MS: u64 = 1000; // base backoff; multiplied by the retry count
// ============================================================================
// USPTO PatentsView API Client
// ============================================================================
/// USPTO PatentsView API response
// NOTE(review): the PatentSearch v1 API documents the total as `total_hits`;
// `total_patent_count` looks like the legacy field name — confirm against a
// live response. Both numeric fields are defaulted, so a name mismatch only
// loses the count, not the records.
#[derive(Debug, Deserialize)]
struct UsptoPatentsResponse {
    /// Page of matching patent records (empty when nothing matched)
    #[serde(default)]
    patents: Vec<UsptoPatent>,
    /// Number of records returned in this page
    #[serde(default)]
    count: i32,
    /// Total matches across all pages (legacy field name; see note above)
    #[serde(default)]
    total_patent_count: Option<i32>,
}
/// USPTO Patent record
///
/// The PatentSearch v1 API (search.patentsview.org) serializes the patent
/// identifier as `patent_id`, while the legacy PatentsView API used
/// `patent_number`. Every request in this module asks for `patent_id` (see
/// the `f=` field lists in `UsptoPatentClient`), so the identifier field
/// accepts both spellings via a serde alias; without the alias,
/// deserializing a v1 response would fail with
/// `missing field `patent_number``, since the field has no default.
#[derive(Debug, Deserialize)]
struct UsptoPatent {
    /// Patent number (v1 API key: `patent_id`; legacy key: `patent_number`)
    #[serde(alias = "patent_id")]
    patent_number: String,
    /// Patent title
    #[serde(default)]
    patent_title: Option<String>,
    /// Patent abstract
    #[serde(default)]
    patent_abstract: Option<String>,
    /// Grant date (YYYY-MM-DD)
    #[serde(default)]
    patent_date: Option<String>,
    /// Application filing date
    #[serde(default)]
    app_date: Option<String>,
    /// Assignees (organizations/companies)
    #[serde(default)]
    assignees: Vec<UsptoAssignee>,
    /// Inventors
    #[serde(default)]
    inventors: Vec<UsptoInventor>,
    /// CPC classifications
    #[serde(default)]
    cpcs: Vec<UsptoCpc>,
    /// Citation counts
    #[serde(default)]
    cited_patent_count: Option<i32>,
    #[serde(default)]
    citedby_patent_count: Option<i32>,
}
/// Assignee record: either an organization or an individual's name parts.
#[derive(Debug, Deserialize)]
struct UsptoAssignee {
    #[serde(default)]
    assignee_organization: Option<String>,
    #[serde(default)]
    assignee_individual_name_first: Option<String>,
    #[serde(default)]
    assignee_individual_name_last: Option<String>,
}
/// Inventor name parts (either may be absent).
#[derive(Debug, Deserialize)]
struct UsptoInventor {
    #[serde(default)]
    inventor_name_first: Option<String>,
    #[serde(default)]
    inventor_name_last: Option<String>,
}
/// CPC classification at three levels of specificity.
#[derive(Debug, Deserialize)]
struct UsptoCpc {
    /// CPC section (e.g., "Y02")
    #[serde(default)]
    cpc_section_id: Option<String>,
    /// CPC subclass (e.g., "Y02E")
    #[serde(default)]
    cpc_subclass_id: Option<String>,
    /// CPC group (e.g., "Y02E10/50")
    #[serde(default)]
    cpc_group_id: Option<String>,
}
/// USPTO citation response
// Currently unused: the citation endpoints return empty results until the
// /us_patent_citation/ integration lands (see get_citing_patents).
#[derive(Debug, Deserialize)]
struct UsptoCitationsResponse {
    #[serde(default)]
    patents: Vec<UsptoCitation>,
}
/// Minimal citation record (number + optional title).
#[derive(Debug, Deserialize)]
struct UsptoCitation {
    patent_number: String,
    #[serde(default)]
    patent_title: Option<String>,
}
/// Client for USPTO PatentsView API (PatentSearch API v2)
///
/// PatentsView provides free access to USPTO patent data with no authentication required.
/// Uses the new PatentSearch API (ElasticSearch-based) as of May 2025.
/// API documentation: https://search.patentsview.org/docs/
pub struct UsptoPatentClient {
    // Reused reqwest client (30s timeout, custom User-Agent)
    client: Client,
    // PatentSearch API root (https://search.patentsview.org/api/v1)
    base_url: String,
    // Fixed inter-request delay (~5 req/s)
    rate_limit_delay: Duration,
    // Shared embedder for title + abstract text
    embedder: Arc<SimpleEmbedder>,
}
impl UsptoPatentClient {
    /// Create a new USPTO PatentsView client
    ///
    /// No authentication required for the PatentsView API.
    /// Uses the new PatentSearch API at search.patentsview.org
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .user_agent("RuVector-Discovery/1.0")
            .build()
            .map_err(FrameworkError::Network)?;
        Ok(Self {
            client,
            base_url: "https://search.patentsview.org/api/v1".to_string(),
            rate_limit_delay: Duration::from_millis(USPTO_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(512)), // Higher dimension for technical text
        })
    }
    /// Search patents by keyword query
    ///
    /// # Arguments
    /// * `query` - Search keywords (e.g., "artificial intelligence", "solar cell")
    /// * `max_results` - Maximum number of results to return (max 1000 per page)
    ///
    /// # Example
    /// ```ignore
    /// let client = UsptoPatentClient::new()?;
    /// let patents = client.search_patents("quantum computing", 50).await?;
    /// ```
    pub async fn search_patents(
        &self,
        query: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        // Despite the doc above, page size is capped at 100 per request here.
        let per_page = max_results.min(100);
        let encoded_query = urlencoding::encode(query);
        // New PatentSearch API uses GET with query parameters
        // Query format: q=patent_title:*query* OR patent_abstract:*query*
        // NOTE(review): wildcard match against title OR abstract; confirm the
        // `o=` options object is accepted URL-encoded as-is by the v1 API.
        let url = format!(
            "{}/patent/?q=patent_title:*{}*%20OR%20patent_abstract:*{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_query, encoded_query, per_page
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;
        self.convert_patents_to_vectors(uspto_response.patents)
    }
    /// Search patents by assignee (company/organization name)
    ///
    /// # Arguments
    /// * `company_name` - Company or organization name (e.g., "IBM", "Google")
    /// * `max_results` - Maximum number of results to return
    ///
    /// # Example
    /// ```ignore
    /// let patents = client.search_by_assignee("Tesla Inc", 100).await?;
    /// ```
    pub async fn search_by_assignee(
        &self,
        company_name: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let per_page = max_results.min(100);
        let encoded_name = urlencoding::encode(company_name);
        // New PatentSearch API format
        let url = format!(
            "{}/patent/?q=assignees.assignee_organization:*{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_name, per_page
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;
        self.convert_patents_to_vectors(uspto_response.patents)
    }
    /// Search patents by CPC classification code
    ///
    /// # Arguments
    /// * `cpc_class` - CPC classification (e.g., "Y02E" for climate tech energy, "G06N" for AI)
    /// * `max_results` - Maximum number of results to return
    ///
    /// # Example - Climate Change Mitigation Technologies
    /// ```ignore
    /// let climate_patents = client.search_by_cpc("Y02", 200).await?;
    /// ```
    ///
    /// # Common CPC Classes
    /// * `Y02` - Climate change mitigation technologies
    /// * `Y02E` - Climate tech - Energy generation/transmission/distribution
    /// * `G06N` - Computing arrangements based on AI/ML/neural networks
    /// * `A61` - Medical or veterinary science
    /// * `H01` - Electric elements (batteries, solar cells, etc.)
    pub async fn search_by_cpc(
        &self,
        cpc_class: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let per_page = max_results.min(100);
        let encoded_cpc = urlencoding::encode(cpc_class);
        // New PatentSearch API - query cpcs.cpc_group field
        // Prefix match: "Y02" also matches "Y02E10/50" etc.
        let url = format!(
            "{}/patent/?q=cpcs.cpc_group:{}*&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":{},\"matched_subentities_only\":true}}",
            self.base_url, encoded_cpc, per_page
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;
        self.convert_patents_to_vectors(uspto_response.patents)
    }
    /// Get detailed information for a specific patent
    ///
    /// Returns `None` when the patent number matches nothing.
    ///
    /// # Arguments
    /// * `patent_number` - USPTO patent number (e.g., "10000000")
    ///
    /// # Example
    /// ```ignore
    /// let patent = client.get_patent("10123456").await?;
    /// ```
    pub async fn get_patent(&self, patent_number: &str) -> Result<Option<SemanticVector>> {
        // New PatentSearch API - direct patent lookup
        let url = format!(
            "{}/patent/?q=patent_id:{}&f=patent_id,patent_title,patent_abstract,patent_date,assignees,inventors,cpcs&o={{\"size\":1}}",
            self.base_url, patent_number
        );
        sleep(self.rate_limit_delay).await;
        let response = self.fetch_with_retry(&url).await?;
        let uspto_response: UsptoPatentsResponse = response.json().await?;
        let mut vectors = self.convert_patents_to_vectors(uspto_response.patents)?;
        Ok(vectors.pop())
    }
    /// Get citations for a patent (both citing and cited patents)
    ///
    /// # Arguments
    /// * `patent_number` - USPTO patent number
    ///
    /// # Returns
    /// Tuple of (patents that cite this patent, patents cited by this patent)
    /// — currently both are always empty (see the stubs below).
    pub async fn get_citations(
        &self,
        patent_number: &str,
    ) -> Result<(Vec<SemanticVector>, Vec<SemanticVector>)> {
        // Get patents that cite this one (forward citations)
        let citing = self.get_citing_patents(patent_number).await?;
        // Get patents cited by this one (backward citations)
        let cited = self.get_cited_patents(patent_number).await?;
        Ok((citing, cited))
    }
    /// Get patents that cite the given patent (forward citations)
    /// Note: Citation data requires separate API endpoints in PatentSearch API v2
    async fn get_citing_patents(&self, _patent_number: &str) -> Result<Vec<SemanticVector>> {
        // The new PatentSearch API handles citations differently
        // Forward citations are available via /api/v1/us_patent_citation/ endpoint
        // For now, return empty - full citation support requires additional implementation
        Ok(Vec::new())
    }
    /// Get patents cited by the given patent (backward citations)
    /// Note: Citation data requires separate API endpoints in PatentSearch API v2
    async fn get_cited_patents(&self, _patent_number: &str) -> Result<Vec<SemanticVector>> {
        // The new PatentSearch API handles citations differently
        // Backward citations are available via /api/v1/us_patent_citation/ endpoint
        // For now, return empty - full citation support requires additional implementation
        Ok(Vec::new())
    }
    /// Convert USPTO patent records to SemanticVectors
    ///
    /// Embeds title + abstract, derives the timestamp from the grant date
    /// (falling back to the filing date, then "now"), and flattens assignees,
    /// inventors, and CPC codes into metadata strings.
    fn convert_patents_to_vectors(&self, patents: Vec<UsptoPatent>) -> Result<Vec<SemanticVector>> {
        let mut vectors = Vec::new();
        for patent in patents {
            let title = patent.patent_title.unwrap_or_else(|| "Untitled Patent".to_string());
            let abstract_text = patent.patent_abstract.unwrap_or_default();
            // Create combined text for embedding
            let text = format!("{} {}", title, abstract_text);
            let embedding = self.embedder.embed_text(&text);
            // Parse grant date (prefer patent_date, fallback to app_date)
            // NOTE(review): the `f=` field lists above never request app_date,
            // so for those queries this fallback is always None.
            let timestamp = patent
                .patent_date
                .or(patent.app_date)
                .as_ref()
                .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
                .and_then(|d| d.and_hms_opt(0, 0, 0))
                .map(|dt| dt.and_utc())
                .unwrap_or_else(Utc::now);
            // Extract assignee names (organization preferred, else person name)
            let assignees = patent
                .assignees
                .iter()
                .map(|a| {
                    a.assignee_organization
                        .clone()
                        .or_else(|| {
                            let first = a.assignee_individual_name_first.as_deref().unwrap_or("");
                            let last = a.assignee_individual_name_last.as_deref().unwrap_or("");
                            if !first.is_empty() || !last.is_empty() {
                                Some(format!("{} {}", first, last).trim().to_string())
                            } else {
                                None
                            }
                        })
                        .unwrap_or_default()
                })
                .filter(|s| !s.is_empty())
                .collect::<Vec<_>>()
                .join(", ");
            // Extract inventor names
            let inventors = patent
                .inventors
                .iter()
                .filter_map(|i| {
                    let first = i.inventor_name_first.as_deref().unwrap_or("");
                    let last = i.inventor_name_last.as_deref().unwrap_or("");
                    if !first.is_empty() || !last.is_empty() {
                        Some(format!("{} {}", first, last).trim().to_string())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join(", ");
            // Extract CPC codes (most specific level available per entry)
            let cpc_codes = patent
                .cpcs
                .iter()
                .filter_map(|cpc| {
                    cpc.cpc_group_id
                        .clone()
                        .or_else(|| cpc.cpc_subclass_id.clone())
                        .or_else(|| cpc.cpc_section_id.clone())
                })
                .collect::<Vec<_>>()
                .join(", ");
            // Build metadata
            // NOTE(review): the citation-count fields are not in the requested
            // `f=` lists either, so these currently always record "0".
            let mut metadata = HashMap::new();
            metadata.insert("patent_number".to_string(), patent.patent_number.clone());
            metadata.insert("title".to_string(), title);
            metadata.insert("abstract".to_string(), abstract_text);
            metadata.insert("assignee".to_string(), assignees);
            metadata.insert("inventors".to_string(), inventors);
            metadata.insert("cpc_codes".to_string(), cpc_codes);
            metadata.insert(
                "citations_count".to_string(),
                patent.citedby_patent_count.unwrap_or(0).to_string(),
            );
            metadata.insert(
                "cited_count".to_string(),
                patent.cited_patent_count.unwrap_or(0).to_string(),
            );
            metadata.insert("source".to_string(), "uspto".to_string());
            vectors.push(SemanticVector {
                id: format!("US{}", patent.patent_number),
                embedding,
                domain: Domain::Research, // Could be Domain::Innovation if that variant exists
                timestamp,
                metadata,
            });
        }
        Ok(vectors)
    }
    /// GET request with retry logic
    ///
    /// Retries on transport errors and HTTP 429; unlike the medical clients,
    /// any other non-success status is converted to a Network error here.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        // error_for_status() is guaranteed Err on this branch.
                        return Err(FrameworkError::Network(
                            reqwest::Error::from(response.error_for_status().unwrap_err()),
                        ));
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
    /// POST request with retry logic (kept for backwards compatibility)
    ///
    /// Same retry/status policy as `fetch_with_retry`, for a JSON POST body.
    #[allow(dead_code)]
    async fn post_with_retry(
        &self,
        url: &str,
        json: &serde_json::Value,
    ) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.post(url).json(json).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        return Err(FrameworkError::Network(
                            reqwest::Error::from(response.error_for_status().unwrap_err()),
                        ));
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
impl Default for UsptoPatentClient {
    /// Default instance backed by [`UsptoPatentClient::new`].
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be constructed
    /// (e.g. TLS backend initialization failure).
    fn default() -> Self {
        Self::new().expect("Failed to create USPTO client")
    }
}
// ============================================================================
// EPO Open Patent Services Client (Placeholder)
// ============================================================================
/// Client for European Patent Office (EPO) Open Patent Services
///
/// Note: This is a placeholder for future EPO integration.
/// The EPO OPS API requires registration and OAuth authentication.
/// See: https://developers.epo.org/
pub struct EpoClient {
    /// Shared HTTP client used for all requests (30 s timeout).
    client: Client,
    /// Root URL of the EPO OPS REST services.
    base_url: String,
    /// OAuth consumer key from EPO developer registration (unused until OAuth lands).
    consumer_key: Option<String>,
    /// OAuth consumer secret paired with `consumer_key` (unused until OAuth lands).
    consumer_secret: Option<String>,
    /// Minimum delay between successive API calls (EPO_RATE_LIMIT_MS).
    rate_limit_delay: Duration,
    /// Text embedder (512-dimensional) for converting patent text to vectors.
    embedder: Arc<SimpleEmbedder>,
}
impl EpoClient {
    /// Create a new EPO client
    ///
    /// # Arguments
    /// * `consumer_key` - EPO API consumer key (from developer registration)
    /// * `consumer_secret` - EPO API consumer secret
    ///
    /// Registration required at: https://developers.epo.org/
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` if the HTTP client cannot be built.
    pub fn new(consumer_key: Option<String>, consumer_secret: Option<String>) -> Result<Self> {
        let built = Client::builder()
            .timeout(Duration::from_secs(30))
            .build();
        let client = built.map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://ops.epo.org/3.2/rest-services".to_string(),
            consumer_key,
            consumer_secret,
            rate_limit_delay: Duration::from_millis(EPO_RATE_LIMIT_MS),
            embedder: Arc::new(SimpleEmbedder::new(512)),
        })
    }

    /// Search European patents
    ///
    /// Note: Implementation requires OAuth authentication flow.
    /// This is a placeholder for future development and currently always
    /// returns a configuration error.
    pub async fn search_patents(
        &self,
        _query: &str,
        _max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let message = "EPO client not yet implemented. Requires OAuth authentication.";
        Err(FrameworkError::Config(message.to_string()))
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_uspto_client_creation() {
        assert!(UsptoPatentClient::new().is_ok());
    }

    #[tokio::test]
    async fn test_epo_client_creation() {
        assert!(EpoClient::new(None, None).is_ok());
    }

    #[test]
    fn test_default_client() {
        let uspto = UsptoPatentClient::default();
        let expected = Duration::from_millis(USPTO_RATE_LIMIT_MS);
        assert_eq!(uspto.rate_limit_delay, expected);
    }

    #[test]
    fn test_rate_limiting() {
        let uspto = UsptoPatentClient::new().unwrap();
        let expected = Duration::from_millis(USPTO_RATE_LIMIT_MS);
        assert_eq!(uspto.rate_limit_delay, expected);
    }

    #[test]
    fn test_cpc_classification_mapping() {
        // CPC codes are bucketed by length:
        // <= 3 chars -> section, exactly 4 -> subclass, longer -> group.
        let cases = [
            ("Y02", "cpc_section_id"),
            ("G06N", "cpc_subclass_id"),
            ("Y02E10/50", "cpc_group_id"),
        ];
        for (code, expected_field) in cases {
            let field = match code.len() {
                0..=3 => "cpc_section_id",
                4 => "cpc_subclass_id",
                _ => "cpc_group_id",
            };
            assert_eq!(field, expected_field, "Failed for CPC code: {}", code);
        }
    }

    #[tokio::test]
    #[ignore] // Requires network access
    async fn test_search_patents_integration() {
        let uspto = UsptoPatentClient::new().unwrap();
        // Either a valid result set or a network error — never a panic.
        match uspto.search_patents("quantum computing", 5).await {
            Ok(patents) => {
                assert!(patents.len() <= 5);
                for p in patents {
                    assert!(p.id.starts_with("US"));
                    assert_eq!(p.domain, Domain::Research);
                    assert!(!p.metadata.is_empty());
                }
            }
            // Network errors are acceptable in tests
            Err(e) => println!("Network test skipped: {}", e),
        }
    }

    #[tokio::test]
    #[ignore] // Requires network access
    async fn test_search_by_cpc_integration() {
        let uspto = UsptoPatentClient::new().unwrap();
        // CPC search for AI/ML patents should surface the requested class.
        match uspto.search_by_cpc("G06N", 5).await {
            Ok(patents) => {
                assert!(patents.len() <= 5);
                for p in patents {
                    let codes = p
                        .metadata
                        .get("cpc_codes")
                        .map(|s| s.as_str())
                        .unwrap_or("");
                    assert!(
                        codes.contains("G06N") || codes.is_empty(),
                        "Expected G06N in CPC codes, got: {}",
                        codes
                    );
                }
            }
            Err(e) => println!("Network test skipped: {}", e),
        }
    }

    #[test]
    fn test_embedding_dimension() {
        let uspto = UsptoPatentClient::new().unwrap();
        // Technical text embeds into 512-dimensional vectors.
        let embedding = uspto.embedder.embed_text("test patent description");
        assert_eq!(embedding.len(), 512);
    }
}

View File

@@ -0,0 +1,638 @@
//! Persistence Layer for RuVector Discovery Framework
//!
//! This module provides serialization/deserialization for the OptimizedDiscoveryEngine
//! and discovered patterns. Supports:
//! - Full engine state save/load
//! - Pattern-only save/load/append
//! - Optional gzip compression for large datasets
//! - Incremental pattern appends without rewriting entire files
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use chrono::{DateTime, Utc};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use serde::{Deserialize, Serialize};
use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine, SignificantPattern};
use crate::ruvector_native::{
CoherenceSnapshot, Domain, GraphEdge, GraphNode, SemanticVector,
};
use crate::{FrameworkError, Result};
/// Serializable state of the OptimizedDiscoveryEngine
///
/// This struct excludes non-serializable fields like AtomicU64 metrics
/// and caches, focusing on the core graph and history state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineState {
    /// Engine configuration
    pub config: OptimizedConfig,
    /// All semantic vectors
    pub vectors: Vec<SemanticVector>,
    /// Graph nodes, keyed by node id
    pub nodes: HashMap<u32, GraphNode>,
    /// Graph edges
    pub edges: Vec<GraphEdge>,
    /// Coherence history (timestamp, mincut value, snapshot)
    pub coherence_history: Vec<(DateTime<Utc>, f64, CoherenceSnapshot)>,
    /// Next node ID counter (the id assigned to the next inserted node)
    pub next_node_id: u32,
    /// Domain-specific node indices
    pub domain_nodes: HashMap<Domain, Vec<u32>>,
    /// Temporal analysis state: per-domain (timestamp, value) series
    pub domain_timeseries: HashMap<Domain, Vec<(DateTime<Utc>, f64)>>,
    /// Metadata about when this state was saved
    pub saved_at: DateTime<Utc>,
    /// Version for compatibility checking (crate version at save time)
    pub version: String,
}
impl EngineState {
    /// Construct an empty state for the given configuration.
    ///
    /// All graph and history collections start empty, `saved_at` is set to
    /// the current time, and `version` records this crate's version.
    pub fn new(config: OptimizedConfig) -> Self {
        let saved_at = Utc::now();
        let version = env!("CARGO_PKG_VERSION").to_string();
        Self {
            config,
            vectors: Vec::new(),
            nodes: HashMap::new(),
            edges: Vec::new(),
            coherence_history: Vec::new(),
            next_node_id: 0,
            domain_nodes: HashMap::new(),
            domain_timeseries: HashMap::new(),
            saved_at,
            version,
        }
    }
}
/// Options for saving/loading with compression
#[derive(Debug, Clone, Copy)]
pub struct PersistenceOptions {
    /// Enable gzip compression
    pub compress: bool,
    /// Compression level (0-9, higher = better compression but slower);
    /// only consulted when `compress` is true
    pub compression_level: u32,
    /// Pretty-print JSON (larger files, more readable)
    pub pretty: bool,
}
impl Default for PersistenceOptions {
fn default() -> Self {
Self {
compress: false,
compression_level: 6,
pretty: false,
}
}
}
impl PersistenceOptions {
    /// Create options with compression enabled
    pub fn compressed() -> Self {
        let mut opts = Self::default();
        opts.compress = true;
        opts
    }

    /// Create options with pretty-printed JSON
    pub fn pretty() -> Self {
        let mut opts = Self::default();
        opts.pretty = true;
        opts
    }
}
/// Save the OptimizedDiscoveryEngine state to a file
///
/// # Arguments
/// * `engine` - The engine to save
/// * `path` - Path to save to (will be created/overwritten)
/// * `options` - Persistence options (compression, formatting)
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::{OptimizedConfig, OptimizedDiscoveryEngine};
/// # use ruvector_data_framework::persistence::{save_engine, PersistenceOptions};
/// # use std::path::Path;
/// let engine = OptimizedDiscoveryEngine::new(OptimizedConfig::default());
/// save_engine(&engine, Path::new("engine_state.json"), &PersistenceOptions::default())?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn save_engine(
    engine: &OptimizedDiscoveryEngine,
    path: &Path,
    options: &PersistenceOptions,
) -> Result<()> {
    // Snapshot the serializable portion of the engine, then persist it.
    let snapshot = extract_state(engine);
    save_state(&snapshot, path, options)?;

    tracing::info!(
        "Saved engine state to {} ({} nodes, {} edges)",
        path.display(),
        snapshot.nodes.len(),
        snapshot.edges.len()
    );
    Ok(())
}
/// Load an OptimizedDiscoveryEngine from a saved state file
///
/// # Arguments
/// * `path` - Path to the saved state file
///
/// # Returns
/// A new OptimizedDiscoveryEngine with the loaded state
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::persistence::load_engine;
/// # use std::path::Path;
/// let engine = load_engine(Path::new("engine_state.json"))?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn load_engine(path: &Path) -> Result<OptimizedDiscoveryEngine> {
    let snapshot = load_state(path)?;

    tracing::info!(
        "Loaded engine state from {} ({} nodes, {} edges)",
        path.display(),
        snapshot.nodes.len(),
        snapshot.edges.len()
    );

    // Rebuild an engine instance from the deserialized snapshot.
    Ok(reconstruct_engine(snapshot))
}
/// Save discovered patterns to a JSON file
///
/// # Arguments
/// * `patterns` - Patterns to save
/// * `path` - Path to save to
/// * `options` - Persistence options
///
/// # Errors
/// Returns `FrameworkError::Discovery` on file-creation, compression, or
/// flush failure, and a serialization error if JSON encoding fails.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::SignificantPattern;
/// # use ruvector_data_framework::persistence::{save_patterns, PersistenceOptions};
/// # use std::path::Path;
/// let patterns: Vec<SignificantPattern> = vec![];
/// save_patterns(&patterns, Path::new("patterns.json"), &PersistenceOptions::default())?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn save_patterns(
    patterns: &[SignificantPattern],
    path: &Path,
    options: &PersistenceOptions,
) -> Result<()> {
    let file = File::create(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to create file {}: {}", path.display(), e))
    })?;
    let mut writer = BufWriter::new(file);

    if options.compress {
        // Stream JSON straight into the gzip encoder instead of materializing
        // the whole serialized document as a String first.
        let mut encoder = GzEncoder::new(writer, Compression::new(options.compression_level));
        if options.pretty {
            serde_json::to_writer_pretty(&mut encoder, patterns)?;
        } else {
            serde_json::to_writer(&mut encoder, patterns)?;
        }
        writer = encoder.finish().map_err(|e| {
            FrameworkError::Discovery(format!("Failed to finish compression: {}", e))
        })?;
    } else if options.pretty {
        serde_json::to_writer_pretty(&mut writer, patterns)?;
    } else {
        serde_json::to_writer(&mut writer, patterns)?;
    }

    // BufWriter's Drop silently swallows flush errors; flush explicitly so
    // callers see write failures.
    writer.flush().map_err(|e| {
        FrameworkError::Discovery(format!("Failed to flush {}: {}", path.display(), e))
    })?;

    tracing::info!("Saved {} patterns to {}", patterns.len(), path.display());
    Ok(())
}
/// Load patterns from a JSON file
///
/// Transparently handles both plain and gzip-compressed files by sniffing
/// the gzip magic bytes.
///
/// # Arguments
/// * `path` - Path to the patterns file
///
/// # Returns
/// Vector of loaded patterns
///
/// # Errors
/// Returns `FrameworkError::Discovery` if the file cannot be opened and a
/// deserialization error if the content is not valid (possibly compressed)
/// JSON. Unlike the previous version, no code path panics on I/O failure.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::persistence::load_patterns;
/// # use std::path::Path;
/// let patterns = load_patterns(Path::new("patterns.json"))?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn load_patterns(path: &Path) -> Result<Vec<SignificantPattern>> {
    // Detect gzip by its two magic bytes; a file too short to contain them
    // is treated as uncompressed JSON.
    let is_gzip = {
        let mut probe = File::open(path).map_err(|e| {
            FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
        })?;
        let mut magic = [0u8; 2];
        probe.read_exact(&mut magic).is_ok() && magic == [0x1f, 0x8b]
    };

    // Re-open for the actual read so parsing starts at offset 0.
    let file = File::open(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
    })?;
    let reader = BufReader::new(file);

    let patterns: Vec<SignificantPattern> = if is_gzip {
        serde_json::from_reader(GzDecoder::new(reader))?
    } else {
        serde_json::from_reader(reader)?
    };

    tracing::info!("Loaded {} patterns from {}", patterns.len(), path.display());
    Ok(patterns)
}
/// Append new patterns to an existing patterns file
///
/// This is more efficient than having callers load all patterns, add new
/// ones, and save the entire list themselves. Internally it still performs
/// a load–extend–save cycle, preserving the file's compression setting.
///
/// # Arguments
/// * `patterns` - New patterns to append
/// * `path` - Path to the existing patterns file
///
/// # Note
/// If the file doesn't exist, it will be created with the given patterns.
/// For compressed files, this will decompress, append, and recompress.
///
/// # Example
/// ```no_run
/// # use ruvector_data_framework::optimized::SignificantPattern;
/// # use ruvector_data_framework::persistence::append_patterns;
/// # use std::path::Path;
/// let new_patterns: Vec<SignificantPattern> = vec![];
/// append_patterns(&new_patterns, Path::new("patterns.json"))?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub fn append_patterns(patterns: &[SignificantPattern], path: &Path) -> Result<()> {
    // Nothing to do for an empty batch.
    if patterns.is_empty() {
        return Ok(());
    }

    // A missing file simply becomes a fresh save.
    if !path.exists() {
        return save_patterns(patterns, path, &PersistenceOptions::default());
    }

    // Remember whether the original file was compressed so we re-save in kind.
    let was_compressed = is_compressed(path)?;

    let mut combined = load_patterns(path)?;
    combined.extend_from_slice(patterns);

    let options = if was_compressed {
        PersistenceOptions::compressed()
    } else {
        PersistenceOptions::default()
    };
    save_patterns(&combined, path, &options)?;

    tracing::info!(
        "Appended {} patterns to {} (total: {})",
        patterns.len(),
        path.display(),
        combined.len()
    );
    Ok(())
}
// ============================================================================
// Internal Helper Functions
// ============================================================================
/// Extract serializable state from an OptimizedDiscoveryEngine
///
/// This uses reflection-like access to the engine's internal state.
/// In practice, you'd need to add getter methods to OptimizedDiscoveryEngine.
fn extract_state(_engine: &OptimizedDiscoveryEngine) -> EngineState {
// Note: This requires the OptimizedDiscoveryEngine to expose its internal state
// via getter methods. For now, we'll use a placeholder that you'll need to implement.
// Since we can't directly access private fields, we need the engine to provide
// a method like `pub fn extract_state(&self) -> EngineState`
// For now, return a minimal state with what we can access
// TODO: Uncomment when OptimizedDiscoveryEngine provides getter methods
// let _stats = engine.stats();
EngineState {
config: OptimizedConfig::default(), // Would need engine.config() method
vectors: Vec::new(), // Would need engine.vectors() method
nodes: HashMap::new(), // Would need engine.nodes() method
edges: Vec::new(), // Would need engine.edges() method
coherence_history: Vec::new(), // Would need engine.coherence_history() method
next_node_id: 0, // Would need engine.next_node_id() method
domain_nodes: HashMap::new(), // Would need engine.domain_nodes() method
domain_timeseries: HashMap::new(), // Would need engine.domain_timeseries() method
saved_at: Utc::now(),
version: env!("CARGO_PKG_VERSION").to_string(),
}
// TODO: Implement proper state extraction once OptimizedDiscoveryEngine
// exposes the necessary getter methods
}
/// Reconstruct an OptimizedDiscoveryEngine from saved state
fn reconstruct_engine(state: EngineState) -> OptimizedDiscoveryEngine {
// Similarly, this would require OptimizedDiscoveryEngine to have
// a constructor like `pub fn from_state(state: EngineState) -> Self`
// For now, create a new engine and note that full reconstruction
// would require additional methods
OptimizedDiscoveryEngine::new(state.config)
// TODO: Implement proper engine reconstruction once OptimizedDiscoveryEngine
// provides the necessary methods to restore state
}
/// Save engine state to a file with optional compression
///
/// # Errors
/// Returns `FrameworkError::Discovery` on file-creation, compression, or
/// flush failure, and a serialization error if JSON encoding fails.
fn save_state(state: &EngineState, path: &Path, options: &PersistenceOptions) -> Result<()> {
    let file = File::create(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to create file {}: {}", path.display(), e))
    })?;
    let mut writer = BufWriter::new(file);

    if options.compress {
        // Stream JSON directly into the gzip encoder instead of building the
        // full serialized document in memory first.
        let mut encoder = GzEncoder::new(writer, Compression::new(options.compression_level));
        if options.pretty {
            serde_json::to_writer_pretty(&mut encoder, state)?;
        } else {
            serde_json::to_writer(&mut encoder, state)?;
        }
        writer = encoder.finish().map_err(|e| {
            FrameworkError::Discovery(format!("Failed to finish compression: {}", e))
        })?;
    } else if options.pretty {
        serde_json::to_writer_pretty(&mut writer, state)?;
    } else {
        serde_json::to_writer(&mut writer, state)?;
    }

    // Flush explicitly; BufWriter's Drop would swallow flush errors.
    writer.flush().map_err(|e| {
        FrameworkError::Discovery(format!("Failed to flush {}: {}", path.display(), e))
    })?;
    Ok(())
}
/// Load engine state from a file with automatic compression detection
///
/// Sniffs the gzip magic bytes first, then opens the file exactly once for
/// the actual read. The previous version opened the file twice and used a
/// bare `unwrap()` on the second open, which could panic on a race with
/// file deletion.
///
/// # Errors
/// Returns `FrameworkError::Discovery` if the file cannot be opened and a
/// deserialization error if the content is not a valid `EngineState`.
fn load_state(path: &Path) -> Result<EngineState> {
    let is_gzip = is_compressed(path)?;

    let file = File::open(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
    })?;
    let reader = BufReader::new(file);

    let state = if is_gzip {
        serde_json::from_reader(GzDecoder::new(reader))?
    } else {
        serde_json::from_reader(reader)?
    };
    Ok(state)
}
/// Check if a file is gzip-compressed by reading magic bytes
///
/// Files too short to contain the two magic bytes (or unreadable past the
/// open) are reported as uncompressed.
fn is_compressed(path: &Path) -> Result<bool> {
    let mut file = File::open(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
    })?;
    let mut magic = [0u8; 2];
    // gzip streams always begin with 0x1f 0x8b.
    Ok(file.read_exact(&mut magic).is_ok() && magic == [0x1f, 0x8b])
}
/// Get file size in bytes
pub fn get_file_size(path: &Path) -> Result<u64> {
    std::fs::metadata(path)
        .map(|m| m.len())
        .map_err(|e| FrameworkError::Discovery(format!("Failed to get file metadata: {}", e)))
}
/// Calculate compression ratio for a file
///
/// Returns (compressed_size, uncompressed_size, ratio). For uncompressed
/// files all three reflect the on-disk size with a ratio of 1.0.
///
/// # Errors
/// Returns `FrameworkError::Discovery` if the file cannot be opened,
/// stat'ed, or decompressed. The previous version used a bare `unwrap()`
/// on `File::open` (panic path) and could return a NaN ratio for an empty
/// decompressed payload; both are fixed here.
pub fn compression_info(path: &Path) -> Result<(u64, u64, f64)> {
    let compressed_size = get_file_size(path)?;

    if !is_compressed(path)? {
        return Ok((compressed_size, compressed_size, 1.0));
    }

    // Decompress and count bytes.
    let file = File::open(path).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to open file {}: {}", path.display(), e))
    })?;
    let mut decoder = GzDecoder::new(BufReader::new(file));
    let mut buffer = Vec::new();
    let uncompressed_size = decoder.read_to_end(&mut buffer).map_err(|e| {
        FrameworkError::Discovery(format!("Failed to decompress: {}", e))
    })? as u64;

    // Guard against an empty payload so we never divide by zero (NaN ratio).
    let ratio = if uncompressed_size == 0 {
        0.0
    } else {
        compressed_size as f64 / uncompressed_size as f64
    };
    Ok((compressed_size, uncompressed_size, ratio))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimized::OptimizedConfig;
    use crate::ruvector_native::{DiscoveredPattern, PatternType, Evidence};
    use tempfile::NamedTempFile;

    #[test]
    fn test_engine_state_creation() {
        let config = OptimizedConfig::default();
        let state = EngineState::new(config.clone());

        // A fresh state is empty and mirrors the supplied config.
        assert_eq!(state.nodes.len(), 0);
        assert_eq!(state.next_node_id, 0);
        assert_eq!(state.config.similarity_threshold, config.similarity_threshold);
    }

    #[test]
    fn test_persistence_options() {
        let plain = PersistenceOptions::default();
        assert!(!plain.compress);
        assert!(!plain.pretty);

        assert!(PersistenceOptions::compressed().compress);
        assert!(PersistenceOptions::pretty().pretty);
    }

    #[test]
    fn test_save_load_patterns() {
        let tmp = NamedTempFile::new().unwrap();
        let target = tmp.path();

        let saved = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "test-1".to_string(),
                    pattern_type: PatternType::CoherenceBreak,
                    confidence: 0.85,
                    affected_nodes: vec![1, 2, 3],
                    detected_at: Utc::now(),
                    description: "Test pattern".to_string(),
                    evidence: vec![
                        Evidence {
                            evidence_type: "test".to_string(),
                            value: 1.0,
                            description: "Test evidence".to_string(),
                        }
                    ],
                    cross_domain_links: vec![],
                },
                p_value: 0.03,
                effect_size: 1.2,
                confidence_interval: (0.5, 1.5),
                is_significant: true,
            }
        ];

        save_patterns(&saved, target, &PersistenceOptions::default()).unwrap();

        // Round-trip: what comes back must match what went in.
        let restored = load_patterns(target).unwrap();
        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].pattern.id, "test-1");
        assert_eq!(restored[0].p_value, 0.03);
    }

    #[test]
    fn test_save_load_patterns_compressed() {
        let tmp = NamedTempFile::new().unwrap();
        let target = tmp.path();

        let saved = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "test-compressed".to_string(),
                    pattern_type: PatternType::Consolidation,
                    confidence: 0.90,
                    affected_nodes: vec![4, 5, 6],
                    detected_at: Utc::now(),
                    description: "Compressed test pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.01,
                effect_size: 2.0,
                confidence_interval: (1.0, 3.0),
                is_significant: true,
            }
        ];

        save_patterns(&saved, target, &PersistenceOptions::compressed()).unwrap();

        // The file on disk must actually carry the gzip magic bytes.
        assert!(is_compressed(target).unwrap());

        let restored = load_patterns(target).unwrap();
        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].pattern.id, "test-compressed");
    }

    #[test]
    fn test_append_patterns() {
        let tmp = NamedTempFile::new().unwrap();
        let target = tmp.path();

        let first = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "pattern-1".to_string(),
                    pattern_type: PatternType::EmergingCluster,
                    confidence: 0.75,
                    affected_nodes: vec![1],
                    detected_at: Utc::now(),
                    description: "First pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.05,
                effect_size: 1.0,
                confidence_interval: (0.0, 2.0),
                is_significant: false,
            }
        ];
        let second = vec![
            SignificantPattern {
                pattern: DiscoveredPattern {
                    id: "pattern-2".to_string(),
                    pattern_type: PatternType::Cascade,
                    confidence: 0.95,
                    affected_nodes: vec![2],
                    detected_at: Utc::now(),
                    description: "Second pattern".to_string(),
                    evidence: vec![],
                    cross_domain_links: vec![],
                },
                p_value: 0.001,
                effect_size: 3.0,
                confidence_interval: (2.0, 4.0),
                is_significant: true,
            }
        ];

        save_patterns(&first, target, &PersistenceOptions::default()).unwrap();
        append_patterns(&second, target).unwrap();

        // Both batches must be present, in insertion order.
        let restored = load_patterns(target).unwrap();
        assert_eq!(restored.len(), 2);
        assert_eq!(restored[0].pattern.id, "pattern-1");
        assert_eq!(restored[1].pattern.id, "pattern-2");
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,558 @@
//! Real-Time Data Feed Integration
//!
//! RSS/Atom feed parsing, WebSocket streaming, and REST API polling
//! for continuous data ingestion into RuVector discovery framework.
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio::time::interval;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Real-time engine for streaming data feeds
///
/// Polls the configured [`FeedSource`]s on a fixed interval from a spawned
/// background task and forwards batches of newly seen items (deduplicated
/// by item id) to the registered callback.
pub struct RealTimeEngine {
    /// Feed sources polled on each tick.
    feeds: Vec<FeedSource>,
    /// Time between polling rounds.
    update_interval: Duration,
    /// Callback invoked with each non-empty batch of new vectors.
    on_new_data: Option<Arc<dyn Fn(Vec<SemanticVector>) + Send + Sync>>,
    /// Item ids already emitted; suppresses duplicates across polls and feeds.
    dedup_cache: Arc<RwLock<HashSet<String>>>,
    /// Shared run flag; cleared by `stop()` to end the background task.
    running: Arc<RwLock<bool>>,
}
/// Types of feed sources
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FeedSource {
    /// RSS or Atom feed; `category` is mapped to a [`Domain`] at ingest time.
    Rss { url: String, category: String },
    /// REST API with polling.
    ///
    /// NOTE(review): the per-source `interval` is currently ignored by
    /// `process_feed`; polling cadence comes from the engine-wide
    /// `update_interval` — confirm before relying on it.
    RestPolling { url: String, interval: Duration },
    /// WebSocket streaming endpoint (not yet implemented).
    WebSocket { url: String },
}
/// News aggregator for multiple RSS feeds
pub struct NewsAggregator {
    /// Configured feed sources.
    sources: Vec<NewsSource>,
    /// Shared HTTP client used for all fetches.
    client: reqwest::Client,
}
/// Individual news source configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewsSource {
    /// Human-readable source name.
    pub name: String,
    /// RSS/Atom feed URL.
    pub feed_url: String,
    /// Domain that items from this source are tagged with.
    pub domain: Domain,
}
/// Parsed feed item
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedItem {
    /// Stable identifier (RSS `<guid>`, falling back to the link).
    pub id: String,
    pub title: String,
    pub description: String,
    pub link: String,
    /// Publication time parsed from RFC 2822 `<pubDate>`, if present.
    pub published: Option<chrono::DateTime<Utc>>,
    pub author: Option<String>,
    /// Not populated by the built-in RSS parser; reserved for richer parsers.
    pub categories: Vec<String>,
}
impl RealTimeEngine {
    /// Create a new real-time engine
    ///
    /// `update_interval` is the period between polling rounds once
    /// `start` has been called. No feeds or callback are registered yet.
    pub fn new(update_interval: Duration) -> Self {
        Self {
            feeds: Vec::new(),
            update_interval,
            on_new_data: None,
            dedup_cache: Arc::new(RwLock::new(HashSet::new())),
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Add a feed source to monitor
    ///
    /// Note: feeds added after `start` are not picked up by the already
    /// running task (the feed list is cloned at spawn time).
    pub fn add_feed(&mut self, source: FeedSource) {
        self.feeds.push(source);
    }

    /// Set callback for new data
    ///
    /// The callback receives each non-empty batch of newly seen vectors
    /// produced by a polling round; it replaces any previous callback.
    pub fn set_callback<F>(&mut self, callback: F)
    where
        F: Fn(Vec<SemanticVector>) + Send + Sync + 'static,
    {
        self.on_new_data = Some(Arc::new(callback));
    }

    /// Start the real-time engine
    ///
    /// Spawns a detached background task that polls every configured feed
    /// once per `update_interval` tick until `stop` is called.
    ///
    /// # Errors
    /// Returns `FrameworkError::Config` if the engine is already running.
    pub async fn start(&mut self) -> Result<()> {
        // Flip the run flag inside a short scope so the write lock is
        // released before spawning the worker.
        {
            let mut running = self.running.write().await;
            if *running {
                return Err(FrameworkError::Config(
                    "Engine already running".to_string(),
                ));
            }
            *running = true;
        }

        // Clone everything the background task needs; `self` is not moved.
        let feeds = self.feeds.clone();
        let callback = self.on_new_data.clone();
        let dedup_cache = self.dedup_cache.clone();
        let update_interval = self.update_interval;
        let running = self.running.clone();

        tokio::spawn(async move {
            let mut ticker = interval(update_interval);
            loop {
                ticker.tick().await;

                // Check if we should stop
                {
                    let is_running = running.read().await;
                    if !*is_running {
                        break;
                    }
                }

                // Process all feeds sequentially; a failing feed is logged
                // and does not abort the round.
                for feed in &feeds {
                    match Self::process_feed(feed, &dedup_cache).await {
                        Ok(vectors) => {
                            if !vectors.is_empty() {
                                if let Some(ref cb) = callback {
                                    cb(vectors);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!("Feed processing error: {}", e);
                        }
                    }
                }
            }
        });

        Ok(())
    }

    /// Stop the real-time engine
    ///
    /// Clears the run flag; the background task exits on its next tick.
    pub async fn stop(&mut self) {
        let mut running = self.running.write().await;
        *running = false;
    }

    /// Process a single feed source
    ///
    /// Dispatches on the feed kind. The `RestPolling` source's own
    /// `interval` field is ignored here (pattern below discards it);
    /// cadence comes from the engine-wide `update_interval`.
    async fn process_feed(
        feed: &FeedSource,
        dedup_cache: &Arc<RwLock<HashSet<String>>>,
    ) -> Result<Vec<SemanticVector>> {
        match feed {
            FeedSource::Rss { url, category } => {
                Self::process_rss_feed(url, category, dedup_cache).await
            }
            FeedSource::RestPolling { url, .. } => {
                Self::process_rest_feed(url, dedup_cache).await
            }
            FeedSource::WebSocket { url } => Self::process_websocket_feed(url, dedup_cache).await,
        }
    }

    /// Process RSS/Atom feed
    ///
    /// Fetches the feed body, parses it, and returns vectors for items not
    /// previously seen (tracked by item id in `dedup_cache`).
    async fn process_rss_feed(
        url: &str,
        category: &str,
        dedup_cache: &Arc<RwLock<HashSet<String>>>,
    ) -> Result<Vec<SemanticVector>> {
        // NOTE(review): a fresh reqwest client is built on every poll;
        // reusing one shared client would avoid repeated connection setup —
        // confirm before changing, as this method is static.
        let client = reqwest::Client::new();
        let response = client.get(url).send().await?;
        let content = response.text().await?;

        // Parse RSS/Atom feed
        let items = Self::parse_rss(&content)?;
        let mut vectors = Vec::new();

        // Hold the write lock for the whole batch so the contains/insert
        // pair is atomic with respect to other feeds.
        let mut cache = dedup_cache.write().await;
        for item in items {
            // Check for duplicates
            if cache.contains(&item.id) {
                continue;
            }
            // Add to dedup cache
            cache.insert(item.id.clone());
            // Convert to SemanticVector
            let domain = Self::category_to_domain(category);
            let vector = Self::item_to_vector(item, domain);
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Process REST API polling
    ///
    /// Expects the endpoint to return a JSON array of `FeedItem`s. All
    /// items are tagged `Domain::Research`.
    async fn process_rest_feed(
        url: &str,
        dedup_cache: &Arc<RwLock<HashSet<String>>>,
    ) -> Result<Vec<SemanticVector>> {
        let client = reqwest::Client::new();
        let response = client.get(url).send().await?;
        let items: Vec<FeedItem> = response.json().await?;

        let mut vectors = Vec::new();
        let mut cache = dedup_cache.write().await;
        for item in items {
            if cache.contains(&item.id) {
                continue;
            }
            cache.insert(item.id.clone());
            let vector = Self::item_to_vector(item, Domain::Research);
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Process WebSocket stream (simplified implementation)
    ///
    /// Currently a stub: logs a warning and yields no items.
    async fn process_websocket_feed(
        _url: &str,
        _dedup_cache: &Arc<RwLock<HashSet<String>>>,
    ) -> Result<Vec<SemanticVector>> {
        // WebSocket implementation would require tokio-tungstenite
        // For now, return empty - can be extended with actual WebSocket client
        tracing::warn!("WebSocket feeds not yet implemented");
        Ok(Vec::new())
    }

    /// Parse RSS/Atom XML into feed items
    ///
    /// Splits on literal `<item>`/`</item>` markers rather than doing real
    /// XML parsing, so `<item>` tags with attributes or CDATA-wrapped
    /// bodies are not handled.
    fn parse_rss(content: &str) -> Result<Vec<FeedItem>> {
        // Simple XML parsing for RSS 2.0
        // In production, use feed-rs or rss crate
        let mut items = Vec::new();

        // Basic RSS parsing (simplified)
        for item_block in content.split("<item>").skip(1) {
            if let Some(end) = item_block.find("</item>") {
                let item_xml = &item_block[..end];
                if let Some(item) = Self::parse_rss_item(item_xml) {
                    items.push(item);
                }
            }
        }
        Ok(items)
    }

    /// Parse a single RSS item from XML
    ///
    /// Returns `None` when the item has no `<title>`. The item id is the
    /// `<guid>`, falling back to the `<link>`.
    fn parse_rss_item(xml: &str) -> Option<FeedItem> {
        let title = Self::extract_tag(xml, "title")?;
        let description = Self::extract_tag(xml, "description").unwrap_or_default();
        let link = Self::extract_tag(xml, "link").unwrap_or_default();
        let guid = Self::extract_tag(xml, "guid").unwrap_or_else(|| link.clone());

        let published = Self::extract_tag(xml, "pubDate")
            .and_then(|date_str| chrono::DateTime::parse_from_rfc2822(&date_str).ok())
            .map(|dt| dt.with_timezone(&Utc));

        let author = Self::extract_tag(xml, "author");

        Some(FeedItem {
            id: guid,
            title,
            description,
            link,
            published,
            author,
            categories: Vec::new(),
        })
    }

    /// Extract content between XML tags
    ///
    /// Matches the first literal `<tag>`/`</tag>` pair (no attribute or
    /// namespace handling) and decodes the five common HTML entities.
    fn extract_tag(xml: &str, tag: &str) -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);

        let start = xml.find(&start_tag)? + start_tag.len();
        let end = xml.find(&end_tag)?;

        if start < end {
            let content = &xml[start..end];
            // Basic HTML entity decoding
            let decoded = content
                .replace("&lt;", "<")
                .replace("&gt;", ">")
                .replace("&amp;", "&")
                .replace("&quot;", "\"")
                .replace("&#39;", "'");
            Some(decoded.trim().to_string())
        } else {
            None
        }
    }

    /// Convert category string to Domain enum
    ///
    /// Unrecognized categories fall back to `Domain::CrossDomain`.
    fn category_to_domain(category: &str) -> Domain {
        match category.to_lowercase().as_str() {
            "climate" | "weather" | "environment" => Domain::Climate,
            "finance" | "economy" | "market" | "stock" => Domain::Finance,
            "research" | "science" | "academic" | "medical" => Domain::Research,
            _ => Domain::CrossDomain,
        }
    }

    /// Convert FeedItem to SemanticVector
    ///
    /// Embeds "title + description" via `simple_embedding` and carries the
    /// title/link/author through as metadata. Items with no `published`
    /// timestamp are stamped with the current time.
    fn item_to_vector(item: FeedItem, domain: Domain) -> SemanticVector {
        use std::collections::HashMap;

        // Create a simple embedding from title + description
        // In production, use actual embedding model
        let text = format!("{} {}", item.title, item.description);
        let embedding = Self::simple_embedding(&text);

        let mut metadata = HashMap::new();
        metadata.insert("title".to_string(), item.title.clone());
        metadata.insert("link".to_string(), item.link.clone());
        if let Some(author) = item.author {
            metadata.insert("author".to_string(), author);
        }

        SemanticVector {
            id: item.id,
            embedding,
            domain,
            timestamp: item.published.unwrap_or_else(Utc::now),
            metadata,
        }
    }

    /// Simple embedding generation (hash-based for demo)
    ///
    /// Each of the first 384 whitespace-separated words contributes one
    /// component (hashed to [-1, 1]); remaining components stay zero. The
    /// vector is then L2-normalized (skipped when all-zero to avoid a
    /// divide-by-zero).
    fn simple_embedding(text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Create 384-dimensional embedding from text hash
        let mut embedding = vec![0.0f32; 384];
        for (i, word) in text.split_whitespace().take(384).enumerate() {
            let mut hasher = DefaultHasher::new();
            word.hash(&mut hasher);
            let hash = hasher.finish();
            embedding[i] = (hash as f32 / u64::MAX as f32) * 2.0 - 1.0;
        }

        // Normalize
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }
        embedding
    }
}
impl NewsAggregator {
    /// Create a new news aggregator with no sources configured.
    ///
    /// The shared HTTP client uses a 30-second request timeout and a
    /// RuVector user agent.
    pub fn new() -> Self {
        Self {
            sources: Vec::new(),
            client: reqwest::Client::builder()
                .user_agent("RuVector/1.0")
                .timeout(Duration::from_secs(30))
                .build()
                // Static, valid builder options cannot fail; state the
                // invariant instead of a bare unwrap.
                .expect("static reqwest client configuration is valid"),
        }
    }
    /// Add a news source
    pub fn add_source(&mut self, source: NewsSource) {
        self.sources.push(source);
    }
    /// Add default free news sources
    ///
    /// Registers five public RSS feeds spanning the Climate, Finance,
    /// Research, and CrossDomain domains.
    pub fn add_default_sources(&mut self) {
        // Climate sources
        self.add_source(NewsSource {
            name: "NASA Earth Observatory".to_string(),
            feed_url: "https://earthobservatory.nasa.gov/feeds/image-of-the-day.rss".to_string(),
            domain: Domain::Climate,
        });
        // Financial sources
        self.add_source(NewsSource {
            name: "Yahoo Finance - Top Stories".to_string(),
            feed_url: "https://finance.yahoo.com/news/rssindex".to_string(),
            domain: Domain::Finance,
        });
        // Medical/Research sources
        self.add_source(NewsSource {
            name: "PubMed Recent".to_string(),
            feed_url: "https://pubmed.ncbi.nlm.nih.gov/rss/search/1nKx2zx8g-9UCGpQD5qVmN6jTvSRRxYqjD3T_nA-pSMjDlXr4u/?limit=100&utm_campaign=pubmed-2&fc=20210421200858".to_string(),
            domain: Domain::Research,
        });
        // General news sources
        self.add_source(NewsSource {
            name: "Reuters Top News".to_string(),
            feed_url: "https://www.reutersagency.com/feed/?taxonomy=best-topics&post_type=best".to_string(),
            domain: Domain::CrossDomain,
        });
        self.add_source(NewsSource {
            name: "AP News Top Stories".to_string(),
            feed_url: "https://apnews.com/index.rss".to_string(),
            domain: Domain::CrossDomain,
        });
    }
    /// Fetch latest items from all sources
    ///
    /// Per-source failures are logged and skipped so one broken feed does
    /// not poison the batch. Results are de-duplicated by vector id,
    /// sorted newest-first, and truncated to `limit`.
    pub async fn fetch_latest(&self, limit: usize) -> Result<Vec<SemanticVector>> {
        let mut all_vectors = Vec::new();
        let mut seen = HashSet::new();
        for source in &self.sources {
            match self.fetch_source(source, limit).await {
                Ok(vectors) => {
                    for vector in vectors {
                        // Drop duplicates that appear in multiple feeds.
                        if seen.insert(vector.id.clone()) {
                            all_vectors.push(vector);
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to fetch {}: {}", source.name, e);
                }
            }
        }
        // Sort by timestamp, most recent first
        all_vectors.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
        // Limit results
        all_vectors.truncate(limit);
        Ok(all_vectors)
    }
    /// Fetch from a single source
    ///
    /// # Errors
    /// Fails when the request cannot be sent, the server responds with a
    /// non-success HTTP status, or the body cannot be parsed as RSS.
    async fn fetch_source(&self, source: &NewsSource, limit: usize) -> Result<Vec<SemanticVector>> {
        let response = self
            .client
            .get(&source.feed_url)
            .send()
            .await?
            // Surface HTTP error statuses (404/500/...) instead of silently
            // parsing an error page as an empty feed.
            .error_for_status()?;
        let content = response.text().await?;
        let items = RealTimeEngine::parse_rss(&content)?;
        let mut vectors = Vec::new();
        for item in items.into_iter().take(limit) {
            vectors.push(RealTimeEngine::item_to_vector(item, source.domain));
        }
        Ok(vectors)
    }
}
impl Default for NewsAggregator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// extract_tag returns the inner text for present tags, None otherwise.
    #[test]
    fn test_extract_tag() {
        let xml = "<title>Test Title</title><description>Test Description</description>";
        assert_eq!(
            RealTimeEngine::extract_tag(xml, "title"),
            Some("Test Title".to_string())
        );
        assert_eq!(
            RealTimeEngine::extract_tag(xml, "description"),
            Some("Test Description".to_string())
        );
        assert_eq!(RealTimeEngine::extract_tag(xml, "missing"), None);
    }
    /// Known labels map case-insensitively; unknown labels fall back to
    /// CrossDomain.
    #[test]
    fn test_category_to_domain() {
        assert_eq!(
            RealTimeEngine::category_to_domain("climate"),
            Domain::Climate
        );
        assert_eq!(
            RealTimeEngine::category_to_domain("Finance"),
            Domain::Finance
        );
        assert_eq!(
            RealTimeEngine::category_to_domain("research"),
            Domain::Research
        );
        assert_eq!(
            RealTimeEngine::category_to_domain("other"),
            Domain::CrossDomain
        );
    }
    /// Embeddings are 384-dimensional and L2-normalized.
    #[test]
    fn test_simple_embedding() {
        let embedding = RealTimeEngine::simple_embedding("climate change impacts");
        assert_eq!(embedding.len(), 384);
        // Check normalization
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }
    /// The engine can be started and stopped without panicking; no network
    /// assertions are made (the feed URL is never required to resolve here).
    #[tokio::test]
    async fn test_realtime_engine_lifecycle() {
        let mut engine = RealTimeEngine::new(Duration::from_secs(1));
        engine.add_feed(FeedSource::Rss {
            url: "https://example.com/feed.rss".to_string(),
            category: "climate".to_string(),
        });
        // Start and stop
        assert!(engine.start().await.is_ok());
        engine.stop().await;
    }
    /// add_default_sources registers all five built-in feeds.
    #[test]
    fn test_news_aggregator() {
        let mut aggregator = NewsAggregator::new();
        aggregator.add_default_sources();
        assert!(aggregator.sources.len() >= 5);
    }
    /// A minimal RSS item fragment round-trips through parse_rss_item.
    #[test]
    fn test_parse_rss_item() {
        let xml = r#"
        <title>Test Article</title>
        <description>This is a test article</description>
        <link>https://example.com/article</link>
        <guid>article-123</guid>
        <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
        "#;
        let item = RealTimeEngine::parse_rss_item(xml);
        assert!(item.is_some());
        let item = item.unwrap();
        assert_eq!(item.title, "Test Article");
        assert_eq!(item.description, "This is a test article");
        assert_eq!(item.link, "https://example.com/article");
        assert_eq!(item.id, "article-123");
    }
}

View File

@@ -0,0 +1,854 @@
//! RuVector-Native Discovery Engine
//!
//! Deep integration with ruvector-core, ruvector-graph, and ruvector-mincut
//! for production-grade coherence analysis and pattern discovery.
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::utils::cosine_similarity;
/// Vector embedding for semantic similarity
/// Uses RuVector's native vector storage format
///
/// One entry per ingested document/observation; the engine indexes these
/// by position, so instances are append-only once stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticVector {
    /// Vector ID (external, human-readable identifier)
    pub id: String,
    /// Dense embedding (typically 384-1536 dimensions)
    pub embedding: Vec<f32>,
    /// Source domain
    pub domain: Domain,
    /// Timestamp (creation or publication time, UTC)
    pub timestamp: DateTime<Utc>,
    /// Metadata for filtering
    pub metadata: HashMap<String, String>,
}
/// Discovery domains
///
/// Fieldless `Copy` enum used as a HashMap key for per-domain indexing.
/// Declaration order is significant: ordering is derived from the
/// discriminant (`as u8`) elsewhere in this file, so append new variants
/// rather than reordering existing ones.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Domain {
    Climate,
    Finance,
    Research,
    Medical,
    Economic,
    Genomics,
    Physics,
    Seismic,
    Ocean,
    Space,
    Transportation,
    Geospatial,
    Government,
    CrossDomain,
}
/// RuVector-native graph node
/// Designed to work with ruvector-graph's adjacency structures
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node ID (u32 for ruvector compatibility)
    pub id: u32,
    /// String identifier for external reference
    pub external_id: String,
    /// Domain
    pub domain: Domain,
    /// Associated vector embedding index (position in the engine's vector
    /// store; None for nodes with no embedding)
    pub vector_idx: Option<usize>,
    /// Node weight (for weighted min-cut)
    pub weight: f64,
    /// Attributes (free-form numeric annotations)
    pub attributes: HashMap<String, f64>,
}
/// RuVector-native graph edge
/// Compatible with ruvector-mincut's edge format
///
/// Edges are stored once per pair; min-cut treats them as undirected by
/// symmetrizing the adjacency matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Source node ID
    pub source: u32,
    /// Target node ID
    pub target: u32,
    /// Edge weight (capacity for min-cut)
    pub weight: f64,
    /// Edge type
    pub edge_type: EdgeType,
    /// Timestamp when edge was created/updated
    pub timestamp: DateTime<Utc>,
}
/// Types of edges in the discovery graph
///
/// The type records how an edge was derived; CrossDomain edges are the
/// ones counted for bridge-formation pattern detection.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum EdgeType {
    /// Correlation-based (e.g. temperature correlation)
    Correlation,
    /// Similarity-based (e.g. vector cosine similarity)
    Similarity,
    /// Citation/reference link
    Citation,
    /// Causal relationship
    Causal,
    /// Cross-domain bridge
    CrossDomain,
}
/// Configuration for the native discovery engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NativeEngineConfig {
    /// Minimum edge weight to include (absolute correlation cutoff)
    pub min_edge_weight: f64,
    /// Vector similarity threshold (cosine; compared against f32 similarity)
    pub similarity_threshold: f64,
    /// Min-cut sensitivity (lower = more sensitive to breaks)
    pub mincut_sensitivity: f64,
    /// Enable cross-domain discovery
    pub cross_domain: bool,
    /// Window size for temporal analysis (seconds)
    pub window_seconds: i64,
    /// HNSW parameters (reserved for the production HNSW index)
    pub hnsw_m: usize,
    pub hnsw_ef_construction: usize,
    pub hnsw_ef_search: usize,
    /// Vector dimension
    pub dimension: usize,
    /// Batch size for processing
    pub batch_size: usize,
    /// Checkpoint interval (records)
    pub checkpoint_interval: u64,
    /// Number of parallel workers
    pub parallel_workers: usize,
}
impl Default for NativeEngineConfig {
    /// Sensible defaults for a 384-dimensional embedding pipeline with a
    /// 30-day temporal window.
    fn default() -> Self {
        Self {
            min_edge_weight: 0.3,
            similarity_threshold: 0.7,
            mincut_sensitivity: 0.15,
            cross_domain: true,
            window_seconds: 86400 * 30, // 30 days
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 50,
            dimension: 384,
            batch_size: 1000,
            checkpoint_interval: 10_000,
            parallel_workers: 4,
        }
    }
}
/// The main RuVector-native discovery engine
///
/// This engine uses RuVector's core algorithms:
/// - Vector similarity via HNSW index
/// - Graph coherence via Stoer-Wagner min-cut
/// - Temporal windowing for streaming analysis
pub struct NativeDiscoveryEngine {
    config: NativeEngineConfig,
    /// Vector storage (would use ruvector-core in production)
    vectors: Vec<SemanticVector>,
    /// Graph nodes, keyed by node id
    nodes: HashMap<u32, GraphNode>,
    /// Graph edges (adjacency list format for ruvector-mincut)
    edges: Vec<GraphEdge>,
    /// Historical coherence values for change detection
    /// (timestamp, mincut value, full snapshot) per detect_patterns() run
    coherence_history: Vec<(DateTime<Utc>, f64, CoherenceSnapshot)>,
    /// Next node ID (monotonically increasing, never reused)
    next_node_id: u32,
    /// Domain-specific subgraph indices (node ids per domain)
    domain_nodes: HashMap<Domain, Vec<u32>>,
}
/// Snapshot of coherence state for historical comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceSnapshot {
    /// Min-cut value
    pub mincut_value: f64,
    /// Number of nodes
    pub node_count: usize,
    /// Number of edges
    pub edge_count: usize,
    /// Partition sizes after min-cut (smaller side first is NOT guaranteed)
    pub partition_sizes: (usize, usize),
    /// Boundary nodes (nodes on the cut)
    pub boundary_nodes: Vec<u32>,
    /// Average edge weight
    pub avg_edge_weight: f64,
}
/// A detected pattern or anomaly
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveredPattern {
    /// Pattern ID (type-prefixed, suffixed with a unix timestamp)
    pub id: String,
    /// Pattern type
    pub pattern_type: PatternType,
    /// Confidence score (0-1)
    pub confidence: f64,
    /// Affected nodes
    pub affected_nodes: Vec<u32>,
    /// Timestamp of detection
    pub detected_at: DateTime<Utc>,
    /// Description (human-readable summary)
    pub description: String,
    /// Evidence supporting the detection
    pub evidence: Vec<Evidence>,
    /// Cross-domain connections if applicable
    pub cross_domain_links: Vec<CrossDomainLink>,
}
/// Types of discoverable patterns
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum PatternType {
    /// Network coherence break (min-cut dropped)
    CoherenceBreak,
    /// Network consolidation (min-cut increased)
    Consolidation,
    /// Emerging cluster (new dense subgraph)
    EmergingCluster,
    /// Dissolving cluster
    DissolvingCluster,
    /// Bridge formation (cross-domain connection)
    BridgeFormation,
    /// Anomalous node (outlier in vector space)
    AnomalousNode,
    /// Temporal shift (pattern change over time)
    TemporalShift,
    /// Cascade (change propagating through network)
    Cascade,
}
/// Evidence supporting a pattern detection
///
/// A named numeric measurement plus a human-readable explanation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    // Machine-readable kind, e.g. "mincut_delta" or "edge_count".
    pub evidence_type: String,
    // The measured value backing the claim.
    pub value: f64,
    // Human-readable description of what the value means.
    pub description: String,
}
/// Cross-domain link discovered
///
/// Describes a bundle of edges connecting two different domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossDomainLink {
    pub source_domain: Domain,
    pub target_domain: Domain,
    // Node ids on each side of the link.
    pub source_nodes: Vec<u32>,
    pub target_nodes: Vec<u32>,
    // Edge weight (or average weight when aggregating several edges).
    pub link_strength: f64,
    // Free-form tag, e.g. "boundary_crossing" or "semantic_bridge".
    pub link_type: String,
}
impl NativeDiscoveryEngine {
/// Create a new engine with the given configuration
///
/// The engine starts empty: no vectors, nodes, edges, or coherence
/// history; node ids begin at 0.
pub fn new(config: NativeEngineConfig) -> Self {
    Self {
        vectors: Vec::new(),
        nodes: HashMap::new(),
        edges: Vec::new(),
        coherence_history: Vec::new(),
        domain_nodes: HashMap::new(),
        next_node_id: 0,
        config,
    }
}
/// Add a vector to the engine
/// In production, this would use ruvector-core's vector storage
///
/// Creates a graph node for the vector, registers it in the per-domain
/// index, and eagerly links it to all sufficiently-similar existing
/// vectors. Returns the new node's id.
pub fn add_vector(&mut self, vector: SemanticVector) -> u32 {
    // Node ids are monotonically increasing and never reused.
    let node_id = self.next_node_id;
    self.next_node_id += 1;
    // The node references its embedding by position in `vectors`.
    let vector_idx = self.vectors.len();
    self.vectors.push(vector.clone());
    let node = GraphNode {
        id: node_id,
        external_id: vector.id.clone(),
        domain: vector.domain,
        vector_idx: Some(vector_idx),
        weight: 1.0,
        attributes: HashMap::new(),
    };
    self.nodes.insert(node_id, node);
    self.domain_nodes.entry(vector.domain).or_default().push(node_id);
    // Auto-connect to similar vectors
    self.connect_similar_vectors(node_id);
    node_id
}
/// Connect a node to similar vectors using cosine similarity
/// In production, this would use ruvector-hnsw for O(log n) search
///
/// Compares the new node's embedding against every existing node (O(n)
/// brute force) and adds one edge per pair at or above the configured
/// similarity threshold. Cross-domain pairs get an `EdgeType::CrossDomain`
/// edge; same-domain pairs get `EdgeType::Similarity`.
fn connect_similar_vectors(&mut self, node_id: u32) {
    // Clone the node up front so we can iterate `self.nodes` immutably
    // while pushing to `self.edges` afterwards.
    let node = match self.nodes.get(&node_id) {
        Some(n) => n.clone(),
        None => return,
    };
    let vector_idx = match node.vector_idx {
        Some(idx) => idx,
        None => return,
    };
    let source_vec = &self.vectors[vector_idx].embedding;
    // Find similar vectors (brute force - would use HNSW in production)
    for (other_id, other_node) in &self.nodes {
        if *other_id == node_id {
            continue;
        }
        if let Some(other_idx) = other_node.vector_idx {
            let other_vec = &self.vectors[other_idx].embedding;
            let similarity = cosine_similarity(source_vec, other_vec);
            if similarity >= self.config.similarity_threshold as f32 {
                // Determine edge type
                let edge_type = if node.domain != other_node.domain {
                    EdgeType::CrossDomain
                } else {
                    EdgeType::Similarity
                };
                // Stored once per pair; min-cut symmetrizes later.
                self.edges.push(GraphEdge {
                    source: node_id,
                    target: *other_id,
                    weight: similarity as f64,
                    edge_type,
                    timestamp: Utc::now(),
                });
            }
        }
    }
}
/// Add a correlation-based edge
///
/// The stored weight is the absolute correlation, so positive and
/// negative correlations of equal magnitude are treated alike. Pairs
/// whose |r| falls below `config.min_edge_weight` are silently dropped.
pub fn add_correlation_edge(&mut self, source: u32, target: u32, correlation: f64) {
    let strength = correlation.abs();
    if strength < self.config.min_edge_weight {
        return;
    }
    self.edges.push(GraphEdge {
        source,
        target,
        weight: strength,
        edge_type: EdgeType::Correlation,
        timestamp: Utc::now(),
    });
}
/// Compute current coherence using Stoer-Wagner min-cut
///
/// The min-cut value represents the "weakest link" in the network.
/// A drop in min-cut indicates the network is becoming fragmented.
///
/// An empty graph (no nodes or no edges) yields an all-zero snapshot.
pub fn compute_coherence(&self) -> CoherenceSnapshot {
    if self.nodes.is_empty() || self.edges.is_empty() {
        return CoherenceSnapshot {
            mincut_value: 0.0,
            node_count: self.nodes.len(),
            edge_count: self.edges.len(),
            partition_sizes: (0, 0),
            boundary_nodes: vec![],
            avg_edge_weight: 0.0,
        };
    }
    // Build adjacency matrix for min-cut
    // In production, this would call ruvector-mincut directly
    let mincut_result = self.stoer_wagner_mincut();
    // NOTE(review): edges is known non-empty here (early return above), so
    // the guard always takes the else branch; kept for defensive symmetry.
    let avg_edge_weight = if self.edges.is_empty() {
        0.0
    } else {
        self.edges.iter().map(|e| e.weight).sum::<f64>() / self.edges.len() as f64
    };
    CoherenceSnapshot {
        mincut_value: mincut_result.0,
        node_count: self.nodes.len(),
        edge_count: self.edges.len(),
        partition_sizes: mincut_result.1,
        boundary_nodes: mincut_result.2,
        avg_edge_weight,
    }
}
/// Stoer-Wagner minimum cut algorithm
/// Returns (min_cut_value, partition_sizes, boundary_nodes)
///
/// O(n^3) dense-matrix implementation. Each phase runs a maximum-adjacency
/// search; the "cut of the phase" separates the last-added vertex `t` from
/// the rest, and the global minimum over all phases is the min-cut.
/// `merged[i]` tracks which original vertices have been contracted into
/// supernode `i`. Graphs with fewer than 2 nodes return a zero cut.
fn stoer_wagner_mincut(&self) -> (f64, (usize, usize), Vec<u32>) {
    let n = self.nodes.len();
    if n < 2 {
        return (0.0, (n, 0), vec![]);
    }
    // Build adjacency matrix; parallel edges accumulate, and the matrix is
    // symmetrized so stored edges behave as undirected.
    let node_ids: Vec<u32> = self.nodes.keys().copied().collect();
    let id_to_idx: HashMap<u32, usize> = node_ids.iter()
        .enumerate()
        .map(|(i, &id)| (id, i))
        .collect();
    let mut adj = vec![vec![0.0; n]; n];
    for edge in &self.edges {
        if let (Some(&i), Some(&j)) = (id_to_idx.get(&edge.source), id_to_idx.get(&edge.target)) {
            adj[i][j] += edge.weight;
            adj[j][i] += edge.weight;
        }
    }
    // Stoer-Wagner algorithm
    let mut best_cut = f64::INFINITY;
    let mut best_partition = (0, 0);
    let mut best_boundary = vec![];
    let mut active: Vec<bool> = vec![true; n];
    // merged[i] = original vertex indices contracted into supernode i.
    let mut merged: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
    for phase in 0..(n - 1) {
        // Maximum adjacency search
        let mut in_a = vec![false; n];
        let mut key = vec![0.0; n];
        // Find first active node
        let start = (0..n).find(|&i| active[i]).unwrap();
        in_a[start] = true;
        // Update keys
        for j in 0..n {
            if active[j] && !in_a[j] {
                key[j] = adj[start][j];
            }
        }
        let mut s = start;
        let mut t = start;
        for _ in 1..=(n - 1 - phase) {
            // Find max key not in A
            let mut max_key = f64::NEG_INFINITY;
            let mut max_node = 0;
            for j in 0..n {
                if active[j] && !in_a[j] && key[j] > max_key {
                    max_key = key[j];
                    max_node = j;
                }
            }
            // s and t end up as the last two vertices added.
            s = t;
            t = max_node;
            in_a[t] = true;
            // Update keys
            for j in 0..n {
                if active[j] && !in_a[j] {
                    key[j] += adj[t][j];
                }
            }
        }
        // Cut of the phase: weight separating t from the rest.
        let cut_weight = key[t];
        if cut_weight < best_cut {
            best_cut = cut_weight;
            // Partition is: merged[t] vs everything else
            let partition_a: Vec<usize> = merged[t].clone();
            let partition_b: Vec<usize> = (0..n)
                .filter(|&i| active[i] && i != t)
                .flat_map(|i| merged[i].iter().copied())
                .collect();
            best_partition = (partition_a.len(), partition_b.len());
            // Boundary nodes are those in the smaller partition with edges to other
            best_boundary = partition_a.iter()
                .map(|&i| node_ids[i])
                .collect();
        }
        // Merge s and t: contract t into s and fold t's edges into s's row.
        active[t] = false;
        let to_merge: Vec<usize> = merged[t].clone();
        merged[s].extend(to_merge);
        for i in 0..n {
            if active[i] && i != s {
                adj[s][i] += adj[t][i];
                adj[i][s] += adj[i][t];
            }
        }
    }
    (best_cut, best_partition, best_boundary)
}
/// Detect patterns by comparing current state to history
///
/// Computes a fresh coherence snapshot, compares it against the most
/// recent history entry (coherence break / consolidation / partition
/// imbalance), optionally runs cross-domain detection, and then appends
/// the new snapshot to the history. The first call never emits
/// history-based patterns because there is nothing to compare against.
pub fn detect_patterns(&mut self) -> Vec<DiscoveredPattern> {
    let mut patterns = Vec::new();
    let current = self.compute_coherence();
    let now = Utc::now();
    // Compare to previous state
    if let Some((prev_time, prev_mincut, prev_snapshot)) = self.coherence_history.last() {
        let mincut_delta = current.mincut_value - prev_mincut;
        // Relative change; falls back to the absolute delta when the
        // previous min-cut was zero (avoids division by zero).
        let relative_change = if *prev_mincut > 0.0 {
            mincut_delta.abs() / prev_mincut
        } else {
            mincut_delta.abs()
        };
        // Detect coherence break: min-cut dropped by more than the
        // configured sensitivity.
        if mincut_delta < -self.config.mincut_sensitivity {
            patterns.push(DiscoveredPattern {
                id: format!("coherence_break_{}", now.timestamp()),
                pattern_type: PatternType::CoherenceBreak,
                // Confidence scales with relative change, floored at 0.5.
                confidence: (relative_change.min(1.0) * 0.5 + 0.5),
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Network coherence dropped from {:.3} to {:.3} ({:.1}% decrease)",
                    prev_mincut, current.mincut_value, relative_change * 100.0
                ),
                evidence: vec![
                    Evidence {
                        evidence_type: "mincut_delta".to_string(),
                        value: mincut_delta,
                        description: "Change in min-cut value".to_string(),
                    },
                    Evidence {
                        evidence_type: "boundary_size".to_string(),
                        value: current.boundary_nodes.len() as f64,
                        description: "Number of nodes on the cut".to_string(),
                    },
                ],
                cross_domain_links: self.find_cross_domain_at_boundary(&current.boundary_nodes),
            });
        }
        // Detect consolidation: min-cut rose by more than the sensitivity.
        if mincut_delta > self.config.mincut_sensitivity {
            patterns.push(DiscoveredPattern {
                id: format!("consolidation_{}", now.timestamp()),
                pattern_type: PatternType::Consolidation,
                confidence: (relative_change.min(1.0) * 0.5 + 0.5),
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Network coherence increased from {:.3} to {:.3} ({:.1}% increase)",
                    prev_mincut, current.mincut_value, relative_change * 100.0
                ),
                evidence: vec![
                    Evidence {
                        evidence_type: "mincut_delta".to_string(),
                        value: mincut_delta,
                        description: "Change in min-cut value".to_string(),
                    },
                ],
                cross_domain_links: vec![],
            });
        }
        // Detect partition imbalance (emerging cluster): the min-cut
        // partition became noticeably more lopsided than last time.
        let (part_a, part_b) = current.partition_sizes;
        let imbalance = (part_a as f64 - part_b as f64).abs() / (part_a + part_b) as f64;
        let (prev_a, prev_b) = prev_snapshot.partition_sizes;
        let prev_imbalance = if prev_a + prev_b > 0 {
            (prev_a as f64 - prev_b as f64).abs() / (prev_a + prev_b) as f64
        } else {
            0.0
        };
        if imbalance > prev_imbalance + 0.2 {
            patterns.push(DiscoveredPattern {
                id: format!("emerging_cluster_{}", now.timestamp()),
                pattern_type: PatternType::EmergingCluster,
                confidence: 0.7,
                affected_nodes: current.boundary_nodes.clone(),
                detected_at: now,
                description: format!(
                    "Partition imbalance increased: {} vs {} nodes (was {} vs {})",
                    part_a, part_b, prev_a, prev_b
                ),
                evidence: vec![],
                cross_domain_links: vec![],
            });
        }
    }
    // Cross-domain pattern detection
    if self.config.cross_domain {
        patterns.extend(self.detect_cross_domain_patterns());
    }
    // Store current state in history
    self.coherence_history.push((now, current.mincut_value, current));
    patterns
}
/// Find cross-domain links at boundary nodes
///
/// Scans all cross-domain edges, keeps those touching at least one
/// boundary node, and resolves both endpoints to their domains. One
/// `CrossDomainLink` is emitted per qualifying edge; edges with a missing
/// endpoint node are skipped.
fn find_cross_domain_at_boundary(&self, boundary: &[u32]) -> Vec<CrossDomainLink> {
    self.edges
        .iter()
        .filter(|edge| edge.edge_type == EdgeType::CrossDomain)
        .filter(|edge| boundary.contains(&edge.source) || boundary.contains(&edge.target))
        .filter_map(|edge| {
            let src_node = self.nodes.get(&edge.source)?;
            let tgt_node = self.nodes.get(&edge.target)?;
            Some(CrossDomainLink {
                source_domain: src_node.domain,
                target_domain: tgt_node.domain,
                source_nodes: vec![edge.source],
                target_nodes: vec![edge.target],
                link_strength: edge.weight,
                link_type: "boundary_crossing".to_string(),
            })
        })
        .collect()
}
/// Detect patterns that span multiple domains
///
/// Groups cross-domain edges by unordered domain pair and reports a
/// `BridgeFormation` pattern for every pair with at least 3 connections
/// whose average weight exceeds the similarity threshold.
fn detect_cross_domain_patterns(&self) -> Vec<DiscoveredPattern> {
    let mut patterns = Vec::new();
    // Count cross-domain edges by domain pair
    let mut cross_counts: HashMap<(Domain, Domain), Vec<&GraphEdge>> = HashMap::new();
    for edge in &self.edges {
        if edge.edge_type == EdgeType::CrossDomain {
            if let (Some(src), Some(tgt)) =
                (self.nodes.get(&edge.source), self.nodes.get(&edge.target))
            {
                // Canonicalize the pair (smaller domain first) so A->B and
                // B->A edges land in the same bucket.
                let key = if src.domain < tgt.domain {
                    (src.domain, tgt.domain)
                } else {
                    (tgt.domain, src.domain)
                };
                cross_counts.entry(key).or_default().push(edge);
            }
        }
    }
    // Report significant cross-domain bridges
    for ((domain_a, domain_b), edges) in cross_counts {
        if edges.len() >= 3 {
            let avg_strength = edges.iter().map(|e| e.weight).sum::<f64>() / edges.len() as f64;
            if avg_strength > self.config.similarity_threshold as f64 {
                patterns.push(DiscoveredPattern {
                    id: format!("bridge_{:?}_{:?}_{}", domain_a, domain_b, Utc::now().timestamp()),
                    pattern_type: PatternType::BridgeFormation,
                    confidence: avg_strength,
                    affected_nodes: edges.iter()
                        .flat_map(|e| vec![e.source, e.target])
                        .collect(),
                    detected_at: Utc::now(),
                    description: format!(
                        "Cross-domain bridge detected: {:?} ↔ {:?} ({} connections, avg strength {:.3})",
                        domain_a, domain_b, edges.len(), avg_strength
                    ),
                    evidence: vec![
                        Evidence {
                            evidence_type: "edge_count".to_string(),
                            value: edges.len() as f64,
                            description: "Number of cross-domain connections".to_string(),
                        },
                    ],
                    cross_domain_links: vec![CrossDomainLink {
                        source_domain: domain_a,
                        target_domain: domain_b,
                        source_nodes: edges.iter().map(|e| e.source).collect(),
                        target_nodes: edges.iter().map(|e| e.target).collect(),
                        link_strength: avg_strength,
                        link_type: "semantic_bridge".to_string(),
                    }],
                });
            }
        }
    }
    patterns
}
/// Get domain-specific coherence
///
/// Returns the mean weight of edges whose endpoints both lie in `domain`.
/// `None` when the domain is unknown or has fewer than two nodes
/// (coherence is undefined); `Some(0.0)` when the domain has nodes but no
/// internal edges.
pub fn domain_coherence(&self, domain: Domain) -> Option<f64> {
    use std::collections::HashSet;

    let domain_node_ids = self.domain_nodes.get(&domain)?;
    if domain_node_ids.len() < 2 {
        return None;
    }
    // Build a set once for O(1) membership tests; the previous
    // `Vec::contains` per edge endpoint was O(edges * domain_nodes).
    let members: HashSet<u32> = domain_node_ids.iter().copied().collect();
    let mut internal_weight = 0.0;
    let mut edge_count = 0usize;
    for edge in &self.edges {
        if members.contains(&edge.source) && members.contains(&edge.target) {
            internal_weight += edge.weight;
            edge_count += 1;
        }
    }
    if edge_count == 0 {
        return Some(0.0);
    }
    Some(internal_weight / edge_count as f64)
}
/// Get statistics about the current state
pub fn stats(&self) -> EngineStats {
let mut domain_counts = HashMap::new();
for domain in self.domain_nodes.keys() {
domain_counts.insert(*domain, self.domain_nodes[domain].len());
}
let mut cross_domain_edges = 0;
for edge in &self.edges {
if edge.edge_type == EdgeType::CrossDomain {
cross_domain_edges += 1;
}
}
EngineStats {
total_nodes: self.nodes.len(),
total_edges: self.edges.len(),
total_vectors: self.vectors.len(),
domain_counts,
cross_domain_edges,
history_length: self.coherence_history.len(),
}
}
/// Get all detected patterns from the latest detection run
///
/// NOTE(review): currently a stub that always returns an empty vec —
/// `detect_patterns()` returns its results directly and nothing is cached
/// here yet.
pub fn get_patterns(&self) -> Vec<DiscoveredPattern> {
    // For now, return an empty vec. In production, this would store
    // patterns from the last detect_patterns() call
    vec![]
}
/// Export the current graph structure
pub fn export_graph(&self) -> GraphExport {
GraphExport {
nodes: self.nodes.values().cloned().collect(),
edges: self.edges.clone(),
domains: self.domain_nodes.clone(),
}
}
/// Get the coherence history
///
/// Converts the internal `(timestamp, mincut, snapshot)` tuples into
/// owned `CoherenceHistoryEntry` values, in insertion order.
pub fn get_coherence_history(&self) -> Vec<CoherenceHistoryEntry> {
    self.coherence_history
        .iter()
        .cloned()
        .map(|(timestamp, mincut_value, snapshot)| CoherenceHistoryEntry {
            timestamp,
            mincut_value,
            snapshot,
        })
        .collect()
}
}
/// Engine statistics
///
/// Returned by `NativeDiscoveryEngine::stats`; a point-in-time summary of
/// graph and history sizes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineStats {
    pub total_nodes: usize,
    pub total_edges: usize,
    pub total_vectors: usize,
    // Number of nodes per domain.
    pub domain_counts: HashMap<Domain, usize>,
    // Count of edges whose type is EdgeType::CrossDomain.
    pub cross_domain_edges: usize,
    // Number of stored coherence snapshots.
    pub history_length: usize,
}
/// Exported graph structure
///
/// Fully-owned, serializable snapshot of the engine's graph state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExport {
    pub nodes: Vec<GraphNode>,
    pub edges: Vec<GraphEdge>,
    // Per-domain node-id index.
    pub domains: HashMap<Domain, Vec<u32>>,
}
/// Coherence history entry
///
/// One record per `detect_patterns` run: when it ran, the min-cut value
/// observed, and the full snapshot behind it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceHistoryEntry {
    pub timestamp: DateTime<Utc>,
    pub mincut_value: f64,
    pub snapshot: CoherenceSnapshot,
}
// Note: cosine_similarity is imported from crate::utils
// Implement ordering for Domain to use in HashMap keys
impl PartialOrd for Domain {
    /// Delegates to the total order defined by `Ord` (standard pattern).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Domain {
    /// Orders variants by declaration position (discriminant cast to u8),
    /// so the ordering changes if variants are reordered in the enum.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (*self as u8).cmp(&(*other as u8))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Identical vectors have cosine similarity 1; orthogonal vectors 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 0.001);
    }
    /// Adding vectors creates matching nodes and stored embeddings.
    #[test]
    fn test_engine_basic() {
        let config = NativeEngineConfig::default();
        let mut engine = NativeDiscoveryEngine::new(config);
        // Add some vectors
        let v1 = SemanticVector {
            id: "climate_1".to_string(),
            embedding: vec![1.0, 0.5, 0.2],
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        };
        let v2 = SemanticVector {
            id: "climate_2".to_string(),
            embedding: vec![0.9, 0.6, 0.3],
            domain: Domain::Climate,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        };
        engine.add_vector(v1);
        engine.add_vector(v2);
        let stats = engine.stats();
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.total_vectors, 2);
    }
}

View File

@@ -0,0 +1,841 @@
//! Semantic Scholar API Integration
//!
//! This module provides an async client for fetching academic papers from Semantic Scholar,
//! converting responses to SemanticVector format for RuVector discovery.
//!
//! # Semantic Scholar API Details
//! - Base URL: https://api.semanticscholar.org/graph/v1
//! - Free tier: 100 requests per 5 minutes without API key
//! - With API key: Higher limits (contact Semantic Scholar)
//! - Returns JSON responses
//!
//! # Example
//! ```rust,ignore
//! use ruvector_data_framework::semantic_scholar::SemanticScholarClient;
//!
//! let client = SemanticScholarClient::new(None); // No API key
//!
//! // Search papers by keywords
//! let vectors = client.search_papers("machine learning", 10).await?;
//!
//! // Get paper details
//! let paper = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await?;
//!
//! // Get citations
//! let citations = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 20).await?;
//!
//! // Search by field of study
//! let cs_papers = client.search_by_field("Computer Science", 50).await?;
//! ```
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};
/// Rate limiting configuration for Semantic Scholar API
const S2_RATE_LIMIT_MS: u64 = 3000; // 3 seconds between requests (100 req / 5 min = ~20 req/min = 3s/req)
const S2_WITH_KEY_RATE_LIMIT_MS: u64 = 200; // More aggressive with API key
// Retry policy for transient failures.
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 2000;
// Default dimensionality for SimpleEmbedder vectors.
const DEFAULT_EMBEDDING_DIM: usize = 384;
// ============================================================================
// Semantic Scholar API Response Structures
// ============================================================================
/// Search response from Semantic Scholar
///
/// Top-level envelope of the paper-search endpoint; all pagination fields
/// are tolerated missing via `#[serde(default)]`.
#[derive(Debug, Deserialize)]
struct SearchResponse {
    // Total number of matches reported by the API.
    #[serde(default)]
    total: Option<i32>,
    // Offset of this page within the full result set.
    #[serde(default)]
    offset: Option<i32>,
    // Offset of the next page, if any.
    #[serde(default)]
    next: Option<i32>,
    // Papers on this page.
    #[serde(default)]
    data: Vec<PaperData>,
}
/// Paper data structure
///
/// Mirrors the Semantic Scholar paper object; every field except the id is
/// optional/defaulted because the API only returns requested fields.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PaperData {
    #[serde(rename = "paperId")]
    paper_id: String,
    #[serde(default)]
    title: Option<String>,
    // `abstract` is a Rust keyword, hence the rename.
    #[serde(rename = "abstract", default)]
    abstract_text: Option<String>,
    #[serde(default)]
    year: Option<i32>,
    #[serde(rename = "citationCount", default)]
    citation_count: Option<i32>,
    #[serde(rename = "referenceCount", default)]
    reference_count: Option<i32>,
    #[serde(rename = "influentialCitationCount", default)]
    influential_citation_count: Option<i32>,
    #[serde(default)]
    authors: Vec<AuthorData>,
    #[serde(rename = "fieldsOfStudy", default)]
    fields_of_study: Vec<String>,
    #[serde(default)]
    venue: Option<String>,
    #[serde(rename = "publicationVenue", default)]
    publication_venue: Option<PublicationVenue>,
    #[serde(default)]
    url: Option<String>,
    #[serde(rename = "openAccessPdf", default)]
    open_access_pdf: Option<OpenAccessPdf>,
}
/// Author information
///
/// Embedded in `PaperData`; both fields may be absent in API responses.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct AuthorData {
    #[serde(rename = "authorId", default)]
    author_id: Option<String>,
    #[serde(default)]
    name: Option<String>,
}
/// Publication venue details
///
/// Structured venue object attached to a paper (distinct from the plain
/// `venue` string field).
#[derive(Debug, Clone, Deserialize, Serialize)]
struct PublicationVenue {
    #[serde(default)]
    name: Option<String>,
    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type", default)]
    venue_type: Option<String>,
}
/// Open access PDF information
///
/// Present only when the paper has a freely available PDF.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct OpenAccessPdf {
    // Direct link to the PDF, when available.
    #[serde(default)]
    url: Option<String>,
    // Open-access status label as reported by the API.
    #[serde(default)]
    status: Option<String>,
}
/// Citation/reference response
///
/// Paginated envelope shared by the citations and references endpoints.
#[derive(Debug, Deserialize)]
struct CitationResponse {
    #[serde(default)]
    offset: Option<i32>,
    #[serde(default)]
    next: Option<i32>,
    #[serde(default)]
    data: Vec<CitationData>,
}
/// Citation data wrapper
///
/// Exactly one of the two fields is populated depending on which endpoint
/// was queried (citations vs references).
#[derive(Debug, Deserialize)]
struct CitationData {
    #[serde(rename = "citingPaper", default)]
    citing_paper: Option<PaperData>,
    #[serde(rename = "citedPaper", default)]
    cited_paper: Option<PaperData>,
}
/// Author details response
///
/// Returned by the author-details endpoint; only `authorId` is required.
#[derive(Debug, Deserialize)]
struct AuthorResponse {
    #[serde(rename = "authorId")]
    author_id: String,
    #[serde(default)]
    name: Option<String>,
    #[serde(rename = "paperCount", default)]
    paper_count: Option<i32>,
    #[serde(rename = "citationCount", default)]
    citation_count: Option<i32>,
    #[serde(rename = "hIndex", default)]
    h_index: Option<i32>,
    // Papers authored, when the papers field was requested.
    #[serde(default)]
    papers: Vec<PaperData>,
}
// ============================================================================
// Semantic Scholar Client
// ============================================================================
/// Client for Semantic Scholar API
///
/// Provides methods to search for academic papers, retrieve citations and references,
/// filter by fields of study, and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces rate limits:
/// - Without API key: 100 requests per 5 minutes (3 seconds between requests)
/// - With API key: Higher limits (200ms between requests)
///
/// # API Key
/// Set the `SEMANTIC_SCHOLAR_API_KEY` environment variable to use authenticated requests.
pub struct SemanticScholarClient {
    // Shared HTTP client (30s timeout, RuVector user agent).
    client: Client,
    // Text embedder used to turn papers into SemanticVectors.
    embedder: Arc<SimpleEmbedder>,
    // API root, e.g. "https://api.semanticscholar.org/graph/v1".
    base_url: String,
    // Optional key; presence selects the faster rate limit.
    api_key: Option<String>,
    // Minimum delay enforced between consecutive requests.
    rate_limit_delay: Duration,
}
impl SemanticScholarClient {
/// Create a new Semantic Scholar API client
///
/// # Arguments
/// * `api_key` - Optional API key. If None, checks SEMANTIC_SCHOLAR_API_KEY env var
///
/// # Example
/// ```rust,ignore
/// // Without API key
/// let client = SemanticScholarClient::new(None);
///
/// // With API key
/// let client = SemanticScholarClient::new(Some("your-api-key".to_string()));
/// ```
pub fn new(api_key: Option<String>) -> Self {
    // Delegates with the default 384-dimensional embedder.
    Self::with_embedding_dim(api_key, DEFAULT_EMBEDDING_DIM)
}
/// Create a new client with custom embedding dimension
///
/// # Arguments
/// * `api_key` - Optional API key
/// * `embedding_dim` - Dimension for text embeddings (default: 384)
pub fn with_embedding_dim(api_key: Option<String>, embedding_dim: usize) -> Self {
// Try API key from parameter, then environment variable
let api_key = api_key.or_else(|| env::var("SEMANTIC_SCHOLAR_API_KEY").ok());
let rate_limit_delay = if api_key.is_some() {
Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS)
} else {
Duration::from_millis(S2_RATE_LIMIT_MS)
};
Self {
client: Client::builder()
.user_agent("RuVector-Discovery/1.0")
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client"),
embedder: Arc::new(SimpleEmbedder::new(embedding_dim)),
base_url: "https://api.semanticscholar.org/graph/v1".to_string(),
api_key,
rate_limit_delay,
}
}
/// Search papers by keywords
///
/// # Arguments
/// * `query` - Search query (keywords, title, etc.)
/// * `limit` - Maximum number of results to return (max 100 per request)
///
/// # Example
/// ```rust,ignore
/// let vectors = client.search_papers("deep learning transformers", 50).await?;
/// ```
pub async fn search_papers(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let limit = limit.min(100); // API limit
let encoded_query = urlencoding::encode(query);
let url = format!(
"{}/paper/search?query={}&limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
self.base_url, encoded_query, limit
);
let response: SearchResponse = self.fetch_json(&url).await?;
let mut vectors = Vec::new();
for paper in response.data {
if let Some(vector) = self.paper_to_vector(paper) {
vectors.push(vector);
}
}
Ok(vectors)
}
/// Get a single paper by Semantic Scholar paper ID
///
/// # Arguments
/// * `paper_id` - Semantic Scholar paper ID (e.g., "649def34f8be52c8b66281af98ae884c09aef38b")
///
/// # Example
/// ```rust,ignore
/// let paper = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await?;
/// ```
pub async fn get_paper(&self, paper_id: &str) -> Result<Option<SemanticVector>> {
let url = format!(
"{}/paper/{}?fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
self.base_url, paper_id
);
let paper: PaperData = self.fetch_json(&url).await?;
Ok(self.paper_to_vector(paper))
}
/// Get papers that cite this paper
///
/// # Arguments
/// * `paper_id` - Semantic Scholar paper ID
/// * `limit` - Maximum number of citations to return (max 1000)
///
/// # Example
/// ```rust,ignore
/// let citations = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 50).await?;
/// ```
pub async fn get_citations(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let limit = limit.min(1000); // API limit
let url = format!(
"{}/paper/{}/citations?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
self.base_url, paper_id, limit
);
let response: CitationResponse = self.fetch_json(&url).await?;
let mut vectors = Vec::new();
for citation in response.data {
if let Some(citing_paper) = citation.citing_paper {
if let Some(vector) = self.paper_to_vector(citing_paper) {
vectors.push(vector);
}
}
}
Ok(vectors)
}
/// Get papers this paper references
///
/// # Arguments
/// * `paper_id` - Semantic Scholar paper ID
/// * `limit` - Maximum number of references to return (max 1000)
///
/// # Example
/// ```rust,ignore
/// let references = client.get_references("649def34f8be52c8b66281af98ae884c09aef38b", 50).await?;
/// ```
pub async fn get_references(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
let limit = limit.min(1000); // API limit
let url = format!(
"{}/paper/{}/references?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
self.base_url, paper_id, limit
);
let response: CitationResponse = self.fetch_json(&url).await?;
let mut vectors = Vec::new();
for reference in response.data {
if let Some(cited_paper) = reference.cited_paper {
if let Some(vector) = self.paper_to_vector(cited_paper) {
vectors.push(vector);
}
}
}
Ok(vectors)
}
/// Search papers by field of study
///
/// # Arguments
/// * `field_of_study` - Field name (e.g., "Computer Science", "Medicine", "Biology", "Physics", "Economics")
/// * `limit` - Maximum number of results to return
///
/// # Example
/// ```rust,ignore
/// let cs_papers = client.search_by_field("Computer Science", 100).await?;
/// let medical_papers = client.search_by_field("Medicine", 50).await?;
/// ```
pub async fn search_by_field(&self, field_of_study: &str, limit: usize) -> Result<Vec<SemanticVector>> {
// Search for papers in this field, sorted by citation count
let query = format!("fieldsOfStudy:{}", field_of_study);
self.search_papers(&query, limit).await
}
/// Get author details and their papers
///
/// # Arguments
/// * `author_id` - Semantic Scholar author ID
///
/// # Example
/// ```rust,ignore
/// let author_papers = client.get_author("1741101").await?;
/// ```
pub async fn get_author(&self, author_id: &str) -> Result<Vec<SemanticVector>> {
let url = format!(
"{}/author/{}?fields=authorId,name,paperCount,citationCount,hIndex,papers.paperId,papers.title,papers.abstract,papers.year,papers.citationCount,papers.fieldsOfStudy",
self.base_url, author_id
);
let author: AuthorResponse = self.fetch_json(&url).await?;
let mut vectors = Vec::new();
for paper in author.papers {
if let Some(vector) = self.paper_to_vector(paper) {
vectors.push(vector);
}
}
Ok(vectors)
}
/// Search recent papers published after a minimum year
///
/// # Arguments
/// * `query` - Search query
/// * `year_min` - Minimum publication year (e.g., 2020)
///
/// # Example
/// ```rust,ignore
/// // Get papers about "climate change" published since 2020
/// let recent = client.search_recent("climate change", 2020).await?;
/// ```
pub async fn search_recent(&self, query: &str, year_min: i32) -> Result<Vec<SemanticVector>> {
let all_results = self.search_papers(query, 100).await?;
// Filter by year
Ok(all_results
.into_iter()
.filter(|v| {
v.metadata
.get("year")
.and_then(|y| y.parse::<i32>().ok())
.map(|year| year >= year_min)
.unwrap_or(false)
})
.collect())
}
/// Build citation graph for a paper
///
/// Returns a tuple of (paper, citations, references) as SemanticVectors
///
/// # Arguments
/// * `paper_id` - Semantic Scholar paper ID
/// * `max_citations` - Maximum citations to retrieve
/// * `max_references` - Maximum references to retrieve
///
/// # Example
/// ```rust,ignore
/// let (paper, citations, references) = client.build_citation_graph(
/// "649def34f8be52c8b66281af98ae884c09aef38b",
/// 50,
/// 50
/// ).await?;
/// ```
pub async fn build_citation_graph(
&self,
paper_id: &str,
max_citations: usize,
max_references: usize,
) -> Result<(Option<SemanticVector>, Vec<SemanticVector>, Vec<SemanticVector>)> {
// Fetch paper, citations, and references in parallel
let paper_result = self.get_paper(paper_id);
let citations_result = self.get_citations(paper_id, max_citations);
let references_result = self.get_references(paper_id, max_references);
// Wait for all with proper spacing for rate limiting
let paper = paper_result.await?;
sleep(self.rate_limit_delay).await;
let citations = citations_result.await?;
sleep(self.rate_limit_delay).await;
let references = references_result.await?;
Ok((paper, citations, references))
}
/// Convert PaperData to SemanticVector
fn paper_to_vector(&self, paper: PaperData) -> Option<SemanticVector> {
let title = paper.title.clone().unwrap_or_default();
let abstract_text = paper.abstract_text.clone().unwrap_or_default();
// Skip papers without title
if title.is_empty() {
return None;
}
// Generate embedding from title + abstract
let combined_text = format!("{} {}", title, abstract_text);
let embedding = self.embedder.embed_text(&combined_text);
// Convert year to timestamp
let timestamp = paper.year
.and_then(|y| NaiveDate::from_ymd_opt(y, 1, 1))
.map(|d| DateTime::from_naive_utc_and_offset(d.and_hms_opt(0, 0, 0).unwrap(), Utc))
.unwrap_or_else(Utc::now);
// Build metadata
let mut metadata = HashMap::new();
metadata.insert("paper_id".to_string(), paper.paper_id.clone());
metadata.insert("title".to_string(), title);
if !abstract_text.is_empty() {
metadata.insert("abstract".to_string(), abstract_text);
}
if let Some(year) = paper.year {
metadata.insert("year".to_string(), year.to_string());
}
if let Some(count) = paper.citation_count {
metadata.insert("citationCount".to_string(), count.to_string());
}
if let Some(count) = paper.reference_count {
metadata.insert("referenceCount".to_string(), count.to_string());
}
if let Some(count) = paper.influential_citation_count {
metadata.insert("influentialCitationCount".to_string(), count.to_string());
}
// Authors
let authors = paper
.authors
.iter()
.filter_map(|a| a.name.as_ref())
.cloned()
.collect::<Vec<_>>()
.join(", ");
if !authors.is_empty() {
metadata.insert("authors".to_string(), authors);
}
// Fields of study
if !paper.fields_of_study.is_empty() {
metadata.insert("fieldsOfStudy".to_string(), paper.fields_of_study.join(", "));
}
// Venue
if let Some(venue) = paper.venue.or_else(|| paper.publication_venue.and_then(|pv| pv.name)) {
metadata.insert("venue".to_string(), venue);
}
// URL
if let Some(url) = paper.url {
metadata.insert("url".to_string(), url);
} else {
metadata.insert(
"url".to_string(),
format!("https://www.semanticscholar.org/paper/{}", paper.paper_id),
);
}
// Open access PDF
if let Some(pdf) = paper.open_access_pdf.and_then(|p| p.url) {
metadata.insert("pdf_url".to_string(), pdf);
}
metadata.insert("source".to_string(), "semantic_scholar".to_string());
Some(SemanticVector {
id: format!("s2:{}", paper.paper_id),
embedding,
domain: Domain::Research,
timestamp,
metadata,
})
}
/// Fetch JSON from URL with rate limiting and retry logic
async fn fetch_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
// Rate limiting
sleep(self.rate_limit_delay).await;
let response = self.fetch_with_retry(url).await?;
let json = response.json::<T>().await?;
Ok(json)
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
let mut request = self.client.get(url);
// Add API key header if available
if let Some(ref api_key) = self.api_key {
request = request.header("x-api-key", api_key);
}
match request.send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
retries += 1;
let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); // Exponential backoff
tracing::warn!(
"Rate limited by Semantic Scholar, retrying in {}ms",
delay
);
sleep(Duration::from_millis(delay)).await;
continue;
}
if !response.status().is_success() {
return Err(FrameworkError::Network(
reqwest::Error::from(response.error_for_status().unwrap_err()),
));
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); // Exponential backoff
tracing::warn!("Request failed, retrying ({}/{}) in {}ms", retries, MAX_RETRIES, delay);
sleep(Duration::from_millis(delay)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
impl Default for SemanticScholarClient {
fn default() -> Self {
Self::new(None)
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    // Without a key the conservative public rate limit must apply.
    #[test]
    fn test_client_creation() {
        let client = SemanticScholarClient::new(None);
        assert_eq!(client.base_url, "https://api.semanticscholar.org/graph/v1");
        assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_RATE_LIMIT_MS));
    }
    // Supplying a key switches the client to the faster authenticated rate.
    #[test]
    fn test_client_with_api_key() {
        let client = SemanticScholarClient::new(Some("test-key".to_string()));
        assert_eq!(client.api_key, Some("test-key".to_string()));
        assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS));
    }
    // The embedder must honor a non-default embedding dimension.
    #[test]
    fn test_custom_embedding_dim() {
        let client = SemanticScholarClient::with_embedding_dim(None, 512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }
    // Full conversion: a fully-populated PaperData should surface every
    // field in the vector's metadata map.
    #[test]
    fn test_paper_to_vector() {
        let client = SemanticScholarClient::new(None);
        let paper = PaperData {
            paper_id: "649def34f8be52c8b66281af98ae884c09aef38b".to_string(),
            title: Some("Attention Is All You Need".to_string()),
            abstract_text: Some("The dominant sequence transduction models...".to_string()),
            year: Some(2017),
            citation_count: Some(50000),
            reference_count: Some(35),
            influential_citation_count: Some(5000),
            authors: vec![
                AuthorData {
                    author_id: Some("1741101".to_string()),
                    name: Some("Ashish Vaswani".to_string()),
                },
                AuthorData {
                    author_id: Some("1699545".to_string()),
                    name: Some("Noam Shazeer".to_string()),
                },
            ],
            fields_of_study: vec!["Computer Science".to_string(), "Mathematics".to_string()],
            venue: Some("NeurIPS".to_string()),
            publication_venue: None,
            url: Some("https://arxiv.org/abs/1706.03762".to_string()),
            open_access_pdf: Some(OpenAccessPdf {
                url: Some("https://arxiv.org/pdf/1706.03762.pdf".to_string()),
                status: Some("GREEN".to_string()),
            }),
        };
        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());
        let v = vector.unwrap();
        assert_eq!(v.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("paper_id").unwrap(), "649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.metadata.get("title").unwrap(), "Attention Is All You Need");
        assert_eq!(v.metadata.get("year").unwrap(), "2017");
        assert_eq!(v.metadata.get("citationCount").unwrap(), "50000");
        assert_eq!(v.metadata.get("referenceCount").unwrap(), "35");
        assert_eq!(v.metadata.get("authors").unwrap(), "Ashish Vaswani, Noam Shazeer");
        assert_eq!(v.metadata.get("fieldsOfStudy").unwrap(), "Computer Science, Mathematics");
        assert_eq!(v.metadata.get("venue").unwrap(), "NeurIPS");
        assert!(v.metadata.contains_key("pdf_url"));
    }
    // A paper with only an id and title still converts; the URL falls back
    // to the canonical semanticscholar.org paper page.
    #[test]
    fn test_paper_to_vector_minimal() {
        let client = SemanticScholarClient::new(None);
        let paper = PaperData {
            paper_id: "test123".to_string(),
            title: Some("Minimal Paper".to_string()),
            abstract_text: None,
            year: None,
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };
        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());
        let v = vector.unwrap();
        assert_eq!(v.id, "s2:test123");
        assert_eq!(v.metadata.get("title").unwrap(), "Minimal Paper");
        assert!(v.metadata.get("url").unwrap().contains("semanticscholar.org"));
    }
    #[test]
    fn test_paper_without_title() {
        let client = SemanticScholarClient::new(None);
        let paper = PaperData {
            paper_id: "test456".to_string(),
            title: None,
            abstract_text: Some("Has abstract but no title".to_string()),
            year: Some(2020),
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };
        // Papers without titles should be skipped
        let vector = client.paper_to_vector(paper);
        assert!(vector.is_none());
    }
    // ---- Integration tests below hit the live API and are #[ignore]d ----
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API in tests
    async fn test_search_papers_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_papers("machine learning", 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("s2:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("paper_id"));
        }
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_get_paper_integration() {
        let client = SemanticScholarClient::new(None);
        // Well-known paper: "Attention Is All You Need"
        let result = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await;
        assert!(result.is_ok());
        let paper = result.unwrap();
        assert!(paper.is_some());
        let p = paper.unwrap();
        assert_eq!(p.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert!(p.metadata.get("title").unwrap().contains("Attention"));
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_get_citations_integration() {
        let client = SemanticScholarClient::new(None);
        // Get citations for "Attention Is All You Need"
        let result = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 10).await;
        assert!(result.is_ok());
        let citations = result.unwrap();
        assert!(citations.len() <= 10);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_search_by_field_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_by_field("Computer Science", 5).await;
        assert!(results.is_ok());
        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
    }
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting Semantic Scholar API
    async fn test_build_citation_graph_integration() {
        let client = SemanticScholarClient::new(None);
        let result = client.build_citation_graph(
            "649def34f8be52c8b66281af98ae884c09aef38b",
            5,
            5
        ).await;
        assert!(result.is_ok());
        let (paper, citations, references) = result.unwrap();
        assert!(paper.is_some());
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,703 @@
//! Real-time Streaming Data Ingestion
//!
//! Provides async stream processing with windowed analysis, real-time pattern
//! detection, backpressure handling, and comprehensive metrics collection.
//!
//! ## Features
//! - Async stream processing for continuous data ingestion
//! - Tumbling and sliding window analysis
//! - Real-time pattern detection with callbacks
//! - Automatic backpressure handling
//! - Throughput and latency metrics
//!
//! ## Example
//! ```rust,ignore
//! use futures::stream;
//! use std::time::Duration;
//!
//! let config = StreamingConfig {
//! window_size: Duration::from_secs(60),
//! slide_interval: Duration::from_secs(30),
//! max_buffer_size: 10000,
//! ..Default::default()
//! };
//!
//! let mut engine = StreamingEngine::new(config);
//!
//! // Set pattern callback
//! engine.set_pattern_callback(|pattern| {
//! println!("Pattern detected: {:?}", pattern);
//! });
//!
//! // Ingest stream
//! let stream = stream::iter(vectors);
//! engine.ingest_stream(stream).await?;
//!
//! // Get metrics
//! let metrics = engine.metrics();
//! println!("Processed: {} vectors, {} patterns",
//! metrics.vectors_processed, metrics.patterns_detected);
//! ```
use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::{RwLock, Semaphore};
use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine, SignificantPattern};
use crate::ruvector_native::SemanticVector;
use crate::Result;
/// Configuration for the streaming engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Discovery engine configuration
    pub discovery_config: OptimizedConfig,
    /// Window size for temporal analysis
    pub window_size: StdDuration,
    /// Slide interval for sliding windows (if None, use tumbling windows)
    pub slide_interval: Option<StdDuration>,
    /// Maximum buffer size before applying backpressure
    /// (used to size the engine's internal semaphore)
    pub max_buffer_size: usize,
    // NOTE(review): `processing_timeout` is not consulted anywhere in this
    // module — confirm whether timeouts are enforced elsewhere or the field
    // is vestigial.
    /// Timeout for processing a single vector (None = no timeout)
    pub processing_timeout: Option<StdDuration>,
    /// Batch size for parallel processing
    pub batch_size: usize,
    /// Enable automatic pattern detection
    pub auto_detect_patterns: bool,
    /// Pattern detection interval (check every N vectors)
    pub detection_interval: usize,
    /// Maximum concurrent processing tasks
    pub max_concurrency: usize,
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
discovery_config: OptimizedConfig::default(),
window_size: StdDuration::from_secs(60),
slide_interval: Some(StdDuration::from_secs(30)),
max_buffer_size: 10000,
processing_timeout: Some(StdDuration::from_secs(5)),
batch_size: 100,
auto_detect_patterns: true,
detection_interval: 100,
max_concurrency: 4,
}
}
}
/// Streaming metrics for monitoring performance
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamingMetrics {
    /// Total vectors processed
    pub vectors_processed: u64,
    /// Total patterns detected
    pub patterns_detected: u64,
    /// Average latency in milliseconds
    /// (computed over the most recent latency samples; see `record_latency`)
    pub avg_latency_ms: f64,
    /// Throughput (vectors per second)
    /// (refreshed by `calculate_throughput` from processed count / uptime)
    pub throughput_per_sec: f64,
    /// Current window count
    pub windows_processed: u64,
    /// Total bytes processed (if available)
    pub bytes_processed: u64,
    /// Backpressure events (times buffer was full)
    pub backpressure_events: u64,
    /// Processing errors
    pub errors: u64,
    /// Peak vectors in buffer
    pub peak_buffer_size: usize,
    /// Start time
    pub start_time: Option<DateTime<Utc>>,
    /// Last update time
    pub last_update: Option<DateTime<Utc>>,
}
impl StreamingMetrics {
    /// Elapsed seconds between `start_time` and `last_update`; 0.0 when
    /// either timestamp has not been recorded yet.
    pub fn uptime_secs(&self) -> f64 {
        match (self.start_time, self.last_update) {
            (Some(begin), Some(end)) => (end - begin).num_milliseconds() as f64 / 1000.0,
            _ => 0.0,
        }
    }
    /// Refresh `throughput_per_sec` from the processed-vector count and the
    /// current uptime; a non-positive uptime leaves the value unchanged.
    pub fn calculate_throughput(&mut self) {
        let elapsed = self.uptime_secs();
        if elapsed <= 0.0 {
            return;
        }
        self.throughput_per_sec = self.vectors_processed as f64 / elapsed;
    }
}
/// Time window for analysis
#[derive(Debug, Clone)]
struct TimeWindow {
    // Inclusive lower bound of the window.
    start: DateTime<Utc>,
    // Exclusive upper bound (start + duration).
    end: DateTime<Utc>,
    // Vectors whose timestamps fall inside [start, end).
    vectors: Vec<SemanticVector>,
}
impl TimeWindow {
    /// Build an empty window covering the half-open interval
    /// [start, start + duration).
    fn new(start: DateTime<Utc>, duration: ChronoDuration) -> Self {
        let end = start + duration;
        Self {
            start,
            end,
            vectors: Vec::new(),
        }
    }
    /// True when `timestamp` lies within [start, end).
    fn contains(&self, timestamp: DateTime<Utc>) -> bool {
        (self.start..self.end).contains(&timestamp)
    }
    /// Append a vector to this window's buffer.
    fn add_vector(&mut self, vector: SemanticVector) {
        self.vectors.push(vector);
    }
    /// A window is complete once `now` has reached its end bound.
    fn is_complete(&self, now: DateTime<Utc>) -> bool {
        self.end <= now
    }
}
/// Streaming engine for real-time data ingestion and pattern detection
pub struct StreamingEngine {
    /// Configuration
    config: StreamingConfig,
    /// Underlying discovery engine (wrapped in Arc<RwLock> for async access)
    engine: Arc<RwLock<OptimizedDiscoveryEngine>>,
    /// Pattern callback (invoked for each significant pattern detected)
    on_pattern: Arc<RwLock<Option<Box<dyn Fn(SignificantPattern) + Send + Sync>>>>,
    /// Metrics
    metrics: Arc<RwLock<StreamingMetrics>>,
    /// Current windows (for sliding window analysis)
    windows: Arc<RwLock<Vec<TimeWindow>>>,
    /// Backpressure semaphore (sized to `config.max_buffer_size` permits)
    semaphore: Arc<Semaphore>,
    /// Latency tracking (rolling buffer of recent per-vector latencies, ms)
    latencies: Arc<RwLock<Vec<f64>>>,
}
impl StreamingEngine {
    /// Create a new streaming engine
    ///
    /// Builds an `OptimizedDiscoveryEngine` from `config.discovery_config`
    /// and stamps `metrics.start_time` with the current time.
    pub fn new(config: StreamingConfig) -> Self {
        let discovery_config = config.discovery_config.clone();
        let max_buffer = config.max_buffer_size;
        let mut metrics = StreamingMetrics::default();
        metrics.start_time = Some(Utc::now());
        Self {
            config,
            engine: Arc::new(RwLock::new(OptimizedDiscoveryEngine::new(discovery_config))),
            on_pattern: Arc::new(RwLock::new(None)),
            metrics: Arc::new(RwLock::new(metrics)),
            windows: Arc::new(RwLock::new(Vec::new())),
            semaphore: Arc::new(Semaphore::new(max_buffer)),
            latencies: Arc::new(RwLock::new(Vec::with_capacity(1000))),
        }
    }
    /// Set the pattern detection callback
    ///
    /// The callback is invoked synchronously for each pattern flagged as
    /// significant during detection passes; replacing an existing callback
    /// is allowed.
    pub async fn set_pattern_callback<F>(&mut self, callback: F)
    where
        F: Fn(SignificantPattern) + Send + Sync + 'static,
    {
        let mut on_pattern = self.on_pattern.write().await;
        *on_pattern = Some(Box::new(callback));
    }
    /// Ingest a stream of vectors with windowed analysis
    ///
    /// Consumes the stream to completion: each vector is routed into active
    /// time windows, batched into the discovery engine, and periodically
    /// checked for patterns. Returns after the stream ends and remaining
    /// batches/windows have been flushed.
    pub async fn ingest_stream<S>(&mut self, stream: S) -> Result<()>
    where
        S: Stream<Item = SemanticVector> + Send,
    {
        let mut stream = Box::pin(stream);
        let mut vector_count = 0_u64;
        let mut current_batch = Vec::with_capacity(self.config.batch_size);
        // Initialize first window
        let window_duration = ChronoDuration::from_std(self.config.window_size)
            .map_err(|e| crate::FrameworkError::Config(format!("Invalid window size: {}", e)))?;
        let mut last_window_start = Utc::now();
        self.create_window(last_window_start, window_duration).await;
        while let Some(vector) = stream.next().await {
            // Backpressure handling
            // NOTE(review): the permit is released when `_permit` drops at the
            // end of this loop iteration, and the loop handles one vector at a
            // time, so at most one permit is ever held — the semaphore does not
            // currently bound in-flight work. Confirm whether permits were
            // intended to be held until batch processing completes.
            let _permit = self.semaphore.acquire().await.map_err(|e| {
                crate::FrameworkError::Ingestion(format!("Backpressure semaphore error: {}", e))
            })?;
            let start = Instant::now();
            // Check if we need to create a new window (sliding)
            if let Some(slide_interval) = self.config.slide_interval {
                let slide_duration = ChronoDuration::from_std(slide_interval)
                    .map_err(|e| crate::FrameworkError::Config(format!("Invalid slide interval: {}", e)))?;
                let now = Utc::now();
                if (now - last_window_start) >= slide_duration {
                    self.create_window(now, window_duration).await;
                    last_window_start = now;
                }
            }
            // Add vector to appropriate windows
            self.add_to_windows(vector.clone()).await;
            current_batch.push(vector);
            vector_count += 1;
            // Process batch
            if current_batch.len() >= self.config.batch_size {
                self.process_batch(&current_batch).await?;
                current_batch.clear();
            }
            // Pattern detection (every `detection_interval` vectors)
            if self.config.auto_detect_patterns && vector_count % self.config.detection_interval as u64 == 0 {
                self.detect_patterns().await?;
            }
            // Close completed windows
            self.close_completed_windows().await?;
            // Record latency (microsecond precision, stored as fractional ms)
            let latency_ms = start.elapsed().as_micros() as f64 / 1000.0;
            self.record_latency(latency_ms).await;
            // Update metrics
            // NOTE(review): assignment (not `+=`) means a second call to
            // `ingest_stream` on the same engine restarts this counter at the
            // new stream's local count — confirm intended.
            let mut metrics = self.metrics.write().await;
            metrics.vectors_processed = vector_count;
            metrics.last_update = Some(Utc::now());
        }
        // Process remaining batch
        if !current_batch.is_empty() {
            self.process_batch(&current_batch).await?;
        }
        // Final pattern detection
        if self.config.auto_detect_patterns {
            self.detect_patterns().await?;
        }
        // Close all remaining windows
        self.close_all_windows().await?;
        // Calculate final metrics
        let mut metrics = self.metrics.write().await;
        metrics.calculate_throughput();
        Ok(())
    }
    /// Process a batch of vectors in parallel
    ///
    /// Splits `vectors` into `batch_size` chunks and feeds each to the
    /// discovery engine under a concurrency limit of `max_concurrency`.
    /// Note: callers pass batches of at most `batch_size` vectors, so this
    /// typically produces a single chunk; the tasks also serialize on the
    /// engine's write lock.
    async fn process_batch(&self, vectors: &[SemanticVector]) -> Result<()> {
        let batch_size = self.config.batch_size;
        let chunks: Vec<_> = vectors.chunks(batch_size).collect();
        // Process chunks with controlled concurrency
        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
        let mut tasks = Vec::new();
        for chunk in chunks {
            let chunk_vec = chunk.to_vec();
            let engine = self.engine.clone();
            let sem = semaphore.clone();
            let task = tokio::spawn(async move {
                let _permit = sem.acquire().await.ok()?;
                let mut engine_guard = engine.write().await;
                #[cfg(feature = "parallel")]
                {
                    engine_guard.add_vectors_batch(chunk_vec);
                }
                #[cfg(not(feature = "parallel"))]
                {
                    for vector in chunk_vec {
                        engine_guard.add_vector(vector);
                    }
                }
                Some(())
            });
            tasks.push(task);
        }
        // Wait for all tasks to complete; join failures (panics) are logged
        // and counted as errors rather than aborting ingestion.
        for task in tasks {
            if let Err(e) = task.await {
                tracing::warn!("Batch processing task failed: {}", e);
                let mut metrics = self.metrics.write().await;
                metrics.errors += 1;
            }
        }
        Ok(())
    }
    /// Create a new time window starting at `start` with the given duration.
    async fn create_window(&self, start: DateTime<Utc>, duration: ChronoDuration) {
        let window = TimeWindow::new(start, duration);
        let mut windows = self.windows.write().await;
        windows.push(window);
    }
    /// Add vector to all active windows
    ///
    /// With sliding windows a vector may land in several overlapping windows,
    /// so it is cloned into each one whose span contains its timestamp.
    async fn add_to_windows(&self, vector: SemanticVector) {
        let timestamp = vector.timestamp;
        let mut windows = self.windows.write().await;
        for window in windows.iter_mut() {
            if window.contains(timestamp) {
                window.add_vector(vector.clone());
            }
        }
    }
    /// Close completed windows and analyze them
    async fn close_completed_windows(&self) -> Result<()> {
        let now = Utc::now();
        let mut windows = self.windows.write().await;
        // Find completed windows
        let (completed, active): (Vec<_>, Vec<_>) = windows
            .drain(..)
            .partition(|w| w.is_complete(now));
        *windows = active;
        drop(windows); // Release lock before processing
        // Process completed windows
        for window in completed {
            self.process_window(window).await?;
            let mut metrics = self.metrics.write().await;
            metrics.windows_processed += 1;
        }
        Ok(())
    }
    /// Close all remaining windows
    ///
    /// Flushes every window regardless of completeness; used at end of
    /// stream. Unlike `close_completed_windows`, this does not bump
    /// `windows_processed`.
    async fn close_all_windows(&self) -> Result<()> {
        let mut windows = self.windows.write().await;
        let all_windows: Vec<_> = windows.drain(..).collect();
        drop(windows);
        for window in all_windows {
            self.process_window(window).await?;
        }
        Ok(())
    }
    /// Process a completed window
    ///
    /// NOTE(review): window vectors were already fed to the engine as they
    /// arrived in `ingest_stream`; re-adding them here may ingest the same
    /// vectors twice — confirm whether the engine deduplicates by vector id.
    async fn process_window(&self, window: TimeWindow) -> Result<()> {
        if window.vectors.is_empty() {
            return Ok(());
        }
        tracing::debug!(
            "Processing window: {} vectors from {} to {}",
            window.vectors.len(),
            window.start,
            window.end
        );
        // Add vectors to engine
        self.process_batch(&window.vectors).await?;
        // Detect patterns for this window
        if self.config.auto_detect_patterns {
            self.detect_patterns().await?;
        }
        Ok(())
    }
    /// Detect patterns and trigger callbacks
    ///
    /// The callback fires only for significant patterns, while the metric
    /// counts every pattern returned by the pass.
    async fn detect_patterns(&self) -> Result<()> {
        let patterns = {
            let mut engine = self.engine.write().await;
            engine.detect_patterns_with_significance()
        };
        let pattern_count = patterns.len();
        // Trigger callback for each significant pattern
        let on_pattern = self.on_pattern.read().await;
        if let Some(callback) = on_pattern.as_ref() {
            for pattern in patterns {
                if pattern.is_significant {
                    callback(pattern);
                }
            }
        }
        // Update metrics
        // NOTE(review): if the engine reports the same pattern on successive
        // passes it is counted again here — verify against
        // `detect_patterns_with_significance` semantics.
        let mut metrics = self.metrics.write().await;
        metrics.patterns_detected += pattern_count as u64;
        Ok(())
    }
    /// Record latency measurement
    ///
    /// Maintains a rolling buffer of the most recent 1000 samples and
    /// refreshes `avg_latency_ms` from it.
    async fn record_latency(&self, latency_ms: f64) {
        let mut latencies = self.latencies.write().await;
        latencies.push(latency_ms);
        // Keep only last 1000 measurements
        let len = latencies.len();
        if len > 1000 {
            latencies.drain(0..len - 1000);
        }
        // Update average latency
        let avg = latencies.iter().sum::<f64>() / latencies.len() as f64;
        let mut metrics = self.metrics.write().await;
        metrics.avg_latency_ms = avg;
    }
    /// Get current metrics
    ///
    /// Returns a snapshot clone with throughput freshly recomputed.
    pub async fn metrics(&self) -> StreamingMetrics {
        let mut metrics = self.metrics.read().await.clone();
        metrics.calculate_throughput();
        metrics
    }
    /// Get engine statistics
    pub async fn engine_stats(&self) -> crate::optimized::OptimizedStats {
        let engine = self.engine.read().await;
        engine.stats()
    }
    /// Reset metrics
    ///
    /// Clears all counters and latency samples, and restarts the uptime
    /// clock from now.
    pub async fn reset_metrics(&self) {
        let mut metrics = self.metrics.write().await;
        *metrics = StreamingMetrics::default();
        metrics.start_time = Some(Utc::now());
        let mut latencies = self.latencies.write().await;
        latencies.clear();
    }
}
/// Builder for StreamingEngine with fluent API
pub struct StreamingEngineBuilder {
    // Accumulated configuration; starts at `StreamingConfig::default()`.
    config: StreamingConfig,
}
impl StreamingEngineBuilder {
    /// Create a new builder seeded with `StreamingConfig::default()`
    pub fn new() -> Self {
        Self {
            config: StreamingConfig::default(),
        }
    }
    /// Set window size (default: 60 s)
    pub fn window_size(mut self, duration: StdDuration) -> Self {
        self.config.window_size = duration;
        self
    }
    /// Set slide interval (for sliding windows; default: 30 s)
    pub fn slide_interval(mut self, duration: StdDuration) -> Self {
        self.config.slide_interval = Some(duration);
        self
    }
    /// Use tumbling windows (no overlap; clears any slide interval)
    pub fn tumbling_windows(mut self) -> Self {
        self.config.slide_interval = None;
        self
    }
    /// Set max buffer size before backpressure (default: 10000)
    pub fn max_buffer_size(mut self, size: usize) -> Self {
        self.config.max_buffer_size = size;
        self
    }
    /// Set batch size for engine ingestion (default: 100)
    pub fn batch_size(mut self, size: usize) -> Self {
        self.config.batch_size = size;
        self
    }
    /// Set max concurrency for batch processing tasks (default: 4)
    pub fn max_concurrency(mut self, concurrency: usize) -> Self {
        self.config.max_concurrency = concurrency;
        self
    }
    /// Set detection interval in vectors (default: every 100)
    pub fn detection_interval(mut self, interval: usize) -> Self {
        self.config.detection_interval = interval;
        self
    }
    /// Set discovery config for the underlying engine
    pub fn discovery_config(mut self, config: OptimizedConfig) -> Self {
        self.config.discovery_config = config;
        self
    }
    /// Build the streaming engine (consumes the builder)
    pub fn build(self) -> StreamingEngine {
        StreamingEngine::new(self.config)
    }
}
impl Default for StreamingEngineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use crate::ruvector_native::Domain;
    use std::collections::HashMap;
    /// Build a minimal 4-dimensional vector stamped "now" for engine tests.
    fn create_test_vector(id: &str, domain: Domain) -> SemanticVector {
        SemanticVector {
            id: id.to_string(),
            embedding: vec![0.1, 0.2, 0.3, 0.4],
            domain,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }
    /// A freshly constructed engine reports zeroed counters.
    #[tokio::test]
    async fn test_streaming_engine_creation() {
        let config = StreamingConfig::default();
        let engine = StreamingEngine::new(config);
        let metrics = engine.metrics().await;
        assert_eq!(metrics.vectors_processed, 0);
        assert_eq!(metrics.patterns_detected, 0);
    }
    /// Ingesting a small stream with detection enabled drives the pattern
    /// callback. Only the processed count is asserted: the callback bumps
    /// its counter from a spawned task, so asserting on it would be racy.
    #[tokio::test]
    async fn test_pattern_callback() {
        let config = StreamingConfig {
            auto_detect_patterns: true,
            detection_interval: 2,
            ..Default::default()
        };
        let mut engine = StreamingEngine::new(config);
        let pattern_count = Arc::new(RwLock::new(0_u64));
        let pc = pattern_count.clone();
        engine.set_pattern_callback(move |_pattern| {
            let pc = pc.clone();
            tokio::spawn(async move {
                let mut count = pc.write().await;
                *count += 1;
            });
        }).await;
        // Create a stream of vectors
        let vectors = vec![
            create_test_vector("v1", Domain::Climate),
            create_test_vector("v2", Domain::Climate),
            create_test_vector("v3", Domain::Finance),
        ];
        let vector_stream = stream::iter(vectors);
        engine.ingest_stream(vector_stream).await.unwrap();
        let metrics = engine.metrics().await;
        assert!(metrics.vectors_processed >= 3);
    }
    /// Sliding-window configuration still processes every vector exactly once.
    #[tokio::test]
    async fn test_windowed_processing() {
        let config = StreamingConfig {
            window_size: StdDuration::from_millis(100),
            slide_interval: Some(StdDuration::from_millis(50)),
            auto_detect_patterns: false,
            ..Default::default()
        };
        let mut engine = StreamingEngine::new(config);
        let vectors = vec![
            create_test_vector("v1", Domain::Climate),
            create_test_vector("v2", Domain::Climate),
        ];
        let vector_stream = stream::iter(vectors);
        engine.ingest_stream(vector_stream).await.unwrap();
        let metrics = engine.metrics().await;
        assert_eq!(metrics.vectors_processed, 2);
    }
    /// The fluent builder yields an engine with fresh (zeroed) metrics.
    #[tokio::test]
    async fn test_builder() {
        let engine = StreamingEngineBuilder::new()
            .window_size(StdDuration::from_secs(30))
            .slide_interval(StdDuration::from_secs(15))
            .max_buffer_size(5000)
            .batch_size(50)
            .build();
        let metrics = engine.metrics().await;
        assert_eq!(metrics.vectors_processed, 0);
    }
    /// Throughput is derived from processed count over elapsed uptime;
    /// a 10-second-old start time should yield ~10s uptime (with slack).
    #[tokio::test]
    async fn test_metrics_calculation() {
        let mut metrics = StreamingMetrics {
            vectors_processed: 1000,
            start_time: Some(Utc::now() - ChronoDuration::seconds(10)),
            last_update: Some(Utc::now()),
            ..Default::default()
        };
        metrics.calculate_throughput();
        assert!(metrics.throughput_per_sec > 0.0);
        assert!(metrics.uptime_secs() >= 9.0 && metrics.uptime_secs() <= 11.0);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,171 @@
//! Shared utility functions for the RuVector Data Framework
//!
//! This module contains common utilities used across multiple modules,
//! including vector operations and mathematical functions.
/// Compute cosine similarity between two vectors
///
/// Returns a value in [-1, 1]:
/// - 1 = identical direction
/// - 0 = orthogonal
/// - -1 = opposite direction
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector (must be same length as `a`)
///
/// # Returns
///
/// Cosine similarity score, or 0.0 if the vectors are empty, have
/// different lengths, or either has (near-)zero magnitude.
///
/// # Example
///
/// ```
/// use ruvector_data_framework::utils::cosine_similarity;
///
/// let a = vec![1.0, 0.0, 0.0];
/// let b = vec![1.0, 0.0, 0.0];
/// assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
///
/// let c = vec![0.0, 1.0, 0.0];
/// assert!(cosine_similarity(&a, &c).abs() < 1e-6);
/// ```
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched or empty inputs are defined to have zero similarity.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single sequential pass accumulating the dot product and both squared
    // norms; the element order matches a plain index loop, so results are
    // bit-identical to a manual implementation.
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    // Guard against division by (near-)zero magnitude.
    let denom = (norm_a * norm_b).sqrt();
    if denom > 1e-10 {
        dot / denom
    } else {
        0.0
    }
}
/// Compute Euclidean (L2) distance between two vectors
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector (must be same length as `a`)
///
/// # Returns
///
/// Euclidean distance, or 0.0 if the vectors are empty or have different
/// lengths.
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    // Degenerate inputs are defined to be at distance zero.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Accumulate squared component differences, then take the root.
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc.sqrt()
}
/// Normalize a vector to unit length (L2 normalization)
///
/// Vectors with (near-)zero magnitude are left untouched to avoid
/// dividing by zero.
///
/// # Arguments
///
/// * `v` - Vector to normalize (modified in place)
#[inline]
pub fn normalize_vector(v: &mut [f32]) {
    // Compute the L2 norm with a sequential accumulation (same order as a
    // sum over squares, so results match a sum-then-sqrt implementation).
    let mut sum_sq = 0.0f32;
    for &x in v.iter() {
        sum_sq += x * x;
    }
    let norm = sum_sq.sqrt();
    if norm <= 1e-10 {
        return;
    }
    for x in v.iter_mut() {
        *x /= norm;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Tolerance used for all floating-point comparisons below.
    const EPS: f32 = 1e-6;
    #[test]
    fn test_cosine_similarity_identical() {
        // A vector compared against itself has similarity 1.
        let v = [1.0f32, 0.0, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < EPS);
    }
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < EPS);
    }
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < EPS);
    }
    #[test]
    fn test_cosine_similarity_empty() {
        let none: [f32; 0] = [];
        assert_eq!(cosine_similarity(&none, &none), 0.0);
    }
    #[test]
    fn test_cosine_similarity_different_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]), 0.0);
    }
    #[test]
    fn test_euclidean_distance() {
        // Classic 3-4-5 right triangle.
        assert!((euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < EPS);
    }
    #[test]
    fn test_normalize_vector() {
        let mut v = [3.0f32, 4.0];
        normalize_vector(&mut v);
        assert!((v[0] - 0.6).abs() < EPS);
        assert!((v[1] - 0.8).abs() < EPS);
    }
}

View File

@@ -0,0 +1,576 @@
//! ASCII Art Visualization for Discovery Framework
//!
//! Provides terminal-based graph visualization with ANSI colors, domain clustering,
//! coherence heatmaps, and pattern timeline displays.
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use crate::optimized::{OptimizedDiscoveryEngine, SignificantPattern};
use crate::ruvector_native::{Domain, PatternType};
/// ANSI color codes for domains
const COLOR_CLIMATE: &str = "\x1b[34m"; // Blue
const COLOR_FINANCE: &str = "\x1b[32m"; // Green
const COLOR_RESEARCH: &str = "\x1b[33m"; // Yellow
const COLOR_MEDICAL: &str = "\x1b[36m"; // Cyan
const COLOR_CROSS: &str = "\x1b[35m"; // Magenta
// Terminal attribute reset/modifier SGR sequences.
const COLOR_RESET: &str = "\x1b[0m";
const COLOR_BRIGHT: &str = "\x1b[1m";
const COLOR_DIM: &str = "\x1b[2m";
/// Box-drawing characters
const BOX_H: char = '─';
const BOX_V: char = '│';
const BOX_TL: char = '┌';
const BOX_TR: char = '┐';
const BOX_BL: char = '└';
const BOX_BR: char = '┘';
// NOTE(review): the connector glyphs below appear unused in this module's
// visible code — confirm before removing.
const BOX_CROSS: char = '┼';
const BOX_T_DOWN: char = '┬';
const BOX_T_UP: char = '┴';
const BOX_T_RIGHT: char = '├';
const BOX_T_LEFT: char = '┤';
/// Map a domain to the ANSI escape sequence used to colorize it in the
/// terminal output. Core domains use the named 8-color constants; the
/// extended domains use 256-color palette entries directly.
fn domain_color(domain: Domain) -> &'static str {
    match domain {
        Domain::CrossDomain => COLOR_CROSS,
        Domain::Climate => COLOR_CLIMATE,
        Domain::Finance => COLOR_FINANCE,
        Domain::Research => COLOR_RESEARCH,
        Domain::Medical => COLOR_MEDICAL,
        Domain::Economic => "\x1b[38;5;214m",       // orange
        Domain::Genomics => "\x1b[38;5;46m",        // bright green
        Domain::Physics => "\x1b[38;5;33m",         // blue
        Domain::Seismic => "\x1b[38;5;130m",        // brown
        Domain::Ocean => "\x1b[38;5;39m",           // cyan
        Domain::Space => "\x1b[38;5;141m",          // purple
        Domain::Transportation => "\x1b[38;5;208m", // orange
        Domain::Geospatial => "\x1b[38;5;118m",     // light green
        Domain::Government => "\x1b[38;5;243m",     // gray
    }
}
/// Map a domain to the single character drawn for its nodes on the ASCII
/// canvas. A few letters are mnemonic stand-ins where the initial letter
/// was already taken (A = Astronomy/Aerospace, L = Location, V = goVernment).
fn domain_char(domain: Domain) -> char {
    match domain {
        Domain::CrossDomain => 'X',
        Domain::Climate => 'C',
        Domain::Finance => 'F',
        Domain::Research => 'R',
        Domain::Medical => 'M',
        Domain::Economic => 'E',
        Domain::Genomics => 'G',
        Domain::Physics => 'P',
        Domain::Seismic => 'S',
        Domain::Ocean => 'O',
        Domain::Space => 'A',
        Domain::Transportation => 'T',
        Domain::Geospatial => 'L',
        Domain::Government => 'V',
    }
}
/// Render the graph as ASCII art with colored domain nodes
///
/// # Arguments
/// * `engine` - The discovery engine containing the graph
/// * `width` - Canvas width in characters
/// * `height` - Canvas height in characters
///
/// # Returns
/// A string containing the ASCII art representation
///
/// NOTE(review): assumes `width >= 2`; smaller widths underflow the
/// `width - 2` border arithmetic — confirm callers always pass sane sizes.
pub fn render_graph_ascii(engine: &OptimizedDiscoveryEngine, width: usize, height: usize) -> String {
    let stats = engine.stats();
    let mut output = String::new();
    // Draw title box
    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_TL));
    output.push_str(&BOX_H.to_string().repeat(width - 2));
    output.push_str(&format!("{}{}\n", BOX_TR, COLOR_RESET));
    let title = format!(" Discovery Graph ({} nodes, {} edges) ", stats.total_nodes, stats.total_edges);
    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_V));
    output.push_str(&format!("{:^width$}", title, width = width - 2));
    output.push_str(&format!("{}{}\n", BOX_V, COLOR_RESET));
    output.push_str(&format!("{}{}", COLOR_BRIGHT, BOX_BL));
    output.push_str(&BOX_H.to_string().repeat(width - 2));
    output.push_str(&format!("{}{}\n\n", BOX_BR, COLOR_RESET));
    // If no nodes, show empty message
    if stats.total_nodes == 0 {
        output.push_str(&format!("{} (empty graph){}\n", COLOR_DIM, COLOR_RESET));
        return output;
    }
    // Create a simple layout by domain
    let mut domain_positions: HashMap<Domain, Vec<(usize, usize)>> = HashMap::new();
    // Layout domains in quadrants
    let mid_x = width / 2;
    let mid_y = height / 2;
    // Assign domain regions. Only the three "core" domains get a dedicated
    // anchor; any other domain falls back to the Research anchor via the
    // `unwrap_or` below, so extended domains overlap that cluster.
    let domain_regions = [
        (Domain::Climate, 10, 2), // Top-left
        (Domain::Finance, mid_x + 10, 2), // Top-right
        (Domain::Research, 10, mid_y + 2), // Bottom-left
    ];
    for (domain, count) in &stats.domain_counts {
        let (_, base_x, base_y) = domain_regions.iter()
            .find(|(d, _, _)| d == domain)
            .unwrap_or(&(Domain::Research, 10, 2));
        let mut positions = Vec::new();
        // Arrange nodes in a cluster: roughly square grid, 3 columns apart
        // and 2 rows apart; nodes that would fall outside the drawable
        // margin are silently dropped.
        let nodes_per_row = ((*count as f64).sqrt().ceil() as usize).max(1);
        for i in 0..*count {
            let row = i / nodes_per_row;
            let col = i % nodes_per_row;
            let x = base_x + col * 3;
            let y = base_y + row * 2;
            if x < width - 5 && y < height - 2 {
                positions.push((x, y));
            }
        }
        domain_positions.insert(*domain, positions);
    }
    // Create canvas: one string cell per character so cells can carry ANSI
    // color codes without breaking column alignment.
    let mut canvas: Vec<Vec<String>> = vec![vec![" ".to_string(); width]; height];
    // Draw nodes
    for (domain, positions) in &domain_positions {
        let color = domain_color(*domain);
        let ch = domain_char(*domain);
        for (x, y) in positions {
            if *x < width && *y < height {
                canvas[*y][*x] = format!("{}{}{}", color, ch, COLOR_RESET);
            }
        }
    }
    // Draw edges (simplified - show connections between domains)
    if stats.cross_domain_edges > 0 {
        // Draw some connecting lines: one L-shaped line between the first
        // node of every ordered pair of distinct domains. Only blank cells
        // are overwritten so lines never clobber nodes.
        for (domain_a, positions_a) in &domain_positions {
            for (domain_b, positions_b) in &domain_positions {
                if domain_a == domain_b {
                    continue;
                }
                // Draw one connection line
                if let (Some(pos_a), Some(pos_b)) = (positions_a.first(), positions_b.first()) {
                    let (x1, y1) = pos_a;
                    let (x2, y2) = pos_b;
                    // Simple line drawing (horizontal then vertical)
                    let color = COLOR_DIM;
                    // Horizontal part
                    let (min_x, max_x) = if x1 < x2 { (*x1, *x2) } else { (*x2, *x1) };
                    for x in min_x..=max_x {
                        if x < width && *y1 < height && canvas[*y1][x] == " " {
                            canvas[*y1][x] = format!("{}{}{}", color, BOX_H, COLOR_RESET);
                        }
                    }
                    // Vertical part
                    let (min_y, max_y) = if y1 < y2 { (*y1, *y2) } else { (*y2, *y1) };
                    for y in min_y..=max_y {
                        if *x2 < width && y < height && canvas[y][*x2] == " " {
                            canvas[y][*x2] = format!("{}{}{}", color, BOX_V, COLOR_RESET);
                        }
                    }
                }
            }
        }
    }
    // Render canvas to string
    for row in canvas {
        for cell in row {
            output.push_str(&cell);
        }
        output.push('\n');
    }
    output.push('\n');
    // Legend
    output.push_str(&format!("{}Legend:{}\n", COLOR_BRIGHT, COLOR_RESET));
    output.push_str(&format!(" {}C{} = Climate ", COLOR_CLIMATE, COLOR_RESET));
    output.push_str(&format!("{}F{} = Finance ", COLOR_FINANCE, COLOR_RESET));
    output.push_str(&format!("{}R{} = Research\n", COLOR_RESEARCH, COLOR_RESET));
    output.push_str(&format!(" Cross-domain bridges: {}\n", stats.cross_domain_edges));
    output
}
/// Render a domain connectivity matrix
///
/// Shows the strength of connections between different domains.
/// The diagonal shows per-domain node counts; off-diagonal cells are a
/// placeholder (always 0) until real edge iteration is implemented — see
/// the inline note below.
pub fn render_domain_matrix(engine: &OptimizedDiscoveryEngine) -> String {
    let stats = engine.stats();
    let mut output = String::new();
    output.push_str(&format!("\n{}{}Domain Connectivity Matrix{}{}\n",
        COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
    output.push_str(&format!("{}\n", BOX_H.to_string().repeat(50)));
    // Calculate connections between domains
    let domains = [Domain::Climate, Domain::Finance, Domain::Research];
    let mut matrix: HashMap<(Domain, Domain), usize> = HashMap::new();
    // Initialize matrix
    for &d1 in &domains {
        for &d2 in &domains {
            matrix.insert((d1, d2), 0);
        }
    }
    // This is a placeholder - in real implementation, we'd iterate through edges
    // and count connections between domains
    output.push_str(&format!(" {}Climate{} {}Finance{} {}Research{}\n",
        COLOR_CLIMATE, COLOR_RESET,
        COLOR_FINANCE, COLOR_RESET,
        COLOR_RESEARCH, COLOR_RESET));
    for &domain_a in &domains {
        let color_a = domain_color(domain_a);
        output.push_str(&format!("{}{:9}{} ", color_a, format!("{:?}", domain_a), COLOR_RESET));
        for &domain_b in &domains {
            // Always the initialized 0 for off-diagonal cells (placeholder).
            let count = matrix.get(&(domain_a, domain_b)).unwrap_or(&0);
            let display = if domain_a == domain_b {
                format!("{}[{:3}]{}", COLOR_BRIGHT, stats.domain_counts.get(&domain_a).unwrap_or(&0), COLOR_RESET)
            } else {
                format!(" {:3} ", count)
            };
            output.push_str(&display);
        }
        output.push('\n');
    }
    output.push_str(&format!("\n{}Note:{} Diagonal = node count, Off-diagonal = cross-domain edges\n",
        COLOR_DIM, COLOR_RESET));
    output.push_str(&format!("Total cross-domain edges: {}\n", stats.cross_domain_edges));
    output
}
/// Render coherence timeline as ASCII sparkline/chart
///
/// # Arguments
/// * `history` - Time series of (timestamp, coherence_value) pairs
///
/// Produces a 10-row bar chart: values are sampled down to at most 60
/// columns, normalized into the chart height, and drawn bottom-up with a
/// filled block per covered cell. A header reports the value range and
/// point count; a footer shows the first/last dates and total duration.
pub fn render_coherence_timeline(history: &[(DateTime<Utc>, f64)]) -> String {
    let mut output = String::new();
    output.push_str(&format!("\n{}{}Coherence Timeline{}{}\n",
        COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
    output.push_str(&format!("{}\n", BOX_H.to_string().repeat(70)));
    if history.is_empty() {
        output.push_str(&format!("{} (no coherence history){}\n", COLOR_DIM, COLOR_RESET));
        return output;
    }
    let values: Vec<f64> = history.iter().map(|(_, v)| *v).collect();
    let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    output.push_str(&format!(" Coherence range: {:.4} - {:.4}\n", min_val, max_val));
    output.push_str(&format!(" Data points: {}\n\n", history.len()));
    // ASCII sparkline
    let chart_height = 10;
    let chart_width = 60.min(history.len());
    // Sample data if too many points
    let step = if history.len() > chart_width {
        history.len() / chart_width
    } else {
        1
    };
    let sampled: Vec<f64> = history.iter()
        .step_by(step)
        .take(chart_width)
        .map(|(_, v)| *v)
        .collect();
    // Normalize values to chart height; a flat series draws at mid-height.
    let range = max_val - min_val;
    let normalized: Vec<usize> = if range > 1e-10 {
        sampled.iter()
            .map(|v| {
                let normalized = ((v - min_val) / range * (chart_height - 1) as f64) as usize;
                normalized.min(chart_height - 1)
            })
            .collect()
    } else {
        vec![chart_height / 2; sampled.len()]
    };
    // Draw chart rows from top to bottom; each cell is exactly one visible
    // character so columns stay aligned.
    for row in (0..chart_height).rev() {
        let value = min_val + (row as f64 / (chart_height - 1) as f64) * range;
        output.push_str(&format!("{:6.3} {} ", value, BOX_V));
        for &height in &normalized {
            // BUG FIX: the cell glyphs were missing — the format! calls
            // emitted only a color code immediately followed by a reset
            // (zero visible characters), leaving the chart body blank and
            // misaligned. Restore a filled block for covered cells and a
            // dim dot as the cap marker one row above each bar.
            let ch = if height >= row {
                format!("{}█{}", COLOR_CLIMATE, COLOR_RESET)
            } else if height + 1 == row {
                format!("{}·{}", COLOR_DIM, COLOR_RESET)
            } else {
                " ".to_string()
            };
            output.push_str(&ch);
        }
        output.push('\n');
    }
    // X-axis
    output.push_str(" ");
    output.push_str(&BOX_BL.to_string());
    output.push_str(&BOX_H.to_string().repeat(chart_width));
    output.push('\n');
    // Time labels: first date left-aligned, last date right-aligned within
    // the remaining chart width.
    if let (Some(first), Some(last)) = (history.first(), history.last()) {
        let duration = last.0.signed_duration_since(first.0);
        let width_val = if chart_width > 12 { chart_width - 12 } else { 0 };
        output.push_str(&format!(" {} {:>width$}\n",
            first.0.format("%Y-%m-%d"),
            last.0.format("%Y-%m-%d"),
            width = width_val));
        output.push_str(&format!(" {}Duration: {}{}\n",
            COLOR_DIM,
            if duration.num_days() > 0 {
                format!("{} days", duration.num_days())
            } else if duration.num_hours() > 0 {
                format!("{} hours", duration.num_hours())
            } else {
                format!("{} minutes", duration.num_minutes())
            },
            COLOR_RESET));
    }
    output
}
/// Render a summary of discovered patterns
///
/// # Arguments
/// * `patterns` - List of significant patterns to summarize
///
/// The summary contains overall counts, a per-type histogram, and the
/// five highest-confidence patterns; statistically significant patterns
/// (p < 0.05) are marked with `*`.
pub fn render_pattern_summary(patterns: &[SignificantPattern]) -> String {
    let mut output = String::new();
    output.push_str(&format!("\n{}{}Pattern Discovery Summary{}{}\n",
        COLOR_BRIGHT, BOX_TL, BOX_TR, COLOR_RESET));
    output.push_str(&format!("{}\n", BOX_H.to_string().repeat(80)));
    if patterns.is_empty() {
        output.push_str(&format!("{} No patterns discovered yet{}\n", COLOR_DIM, COLOR_RESET));
        return output;
    }
    output.push_str(&format!(" Total patterns detected: {}\n", patterns.len()));
    // Count by type
    let mut type_counts: HashMap<PatternType, usize> = HashMap::new();
    let mut significant_count = 0;
    for pattern in patterns {
        *type_counts.entry(pattern.pattern.pattern_type).or_default() += 1;
        if pattern.is_significant {
            significant_count += 1;
        }
    }
    output.push_str(&format!(" Statistically significant: {} ({:.1}%)\n\n",
        significant_count,
        (significant_count as f64 / patterns.len() as f64) * 100.0));
    // Pattern type breakdown
    output.push_str(&format!("{}Pattern Types:{}\n", COLOR_BRIGHT, COLOR_RESET));
    for (pattern_type, count) in type_counts.iter() {
        let icon = match pattern_type {
            PatternType::CoherenceBreak => "⚠️ ",
            PatternType::Consolidation => "📈",
            PatternType::EmergingCluster => "🌟",
            PatternType::DissolvingCluster => "💫",
            PatternType::BridgeFormation => "🌉",
            PatternType::AnomalousNode => "🔴",
            // BUG FIX: this arm previously mapped to an empty string (the
            // glyph was evidently lost); give TemporalShift a visible icon.
            PatternType::TemporalShift => "⏱️ ",
            PatternType::Cascade => "🌊",
        };
        let bar_length = ((*count as f64 / patterns.len() as f64) * 30.0) as usize;
        // BUG FIX: the histogram bar was built with `"".repeat(..)` — an
        // empty string repeated is still empty, so no bar was ever drawn.
        // Use a filled block glyph scaled to the pattern-type share.
        let bar = "█".repeat(bar_length);
        output.push_str(&format!(" {} {:20} {:3} {}{}{}\n",
            icon,
            format!("{:?}", pattern_type),
            count,
            COLOR_CLIMATE,
            bar,
            COLOR_RESET));
    }
    output.push('\n');
    // Top patterns by confidence
    output.push_str(&format!("{}Top Patterns (by confidence):{}\n", COLOR_BRIGHT, COLOR_RESET));
    let mut sorted_patterns: Vec<_> = patterns.iter().collect();
    // Treat NaN confidences as equal instead of panicking mid-render
    // (partial_cmp returns None for NaN; the previous unwrap would panic).
    sorted_patterns.sort_by(|a, b| {
        b.pattern.confidence
            .partial_cmp(&a.pattern.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    for (i, pattern) in sorted_patterns.iter().take(5).enumerate() {
        let significance_marker = if pattern.is_significant {
            format!("{}*{}", COLOR_BRIGHT, COLOR_RESET)
        } else {
            " ".to_string()
        };
        // Color-code by confidence band: high, medium, low.
        let color = if pattern.pattern.confidence > 0.8 {
            COLOR_CLIMATE
        } else if pattern.pattern.confidence > 0.5 {
            COLOR_FINANCE
        } else {
            COLOR_DIM
        };
        output.push_str(&format!(" {}{}.{} {}{:?}{} (p={:.4}, effect={:.3}, conf={:.2})\n",
            significance_marker,
            i + 1,
            COLOR_RESET,
            color,
            pattern.pattern.pattern_type,
            COLOR_RESET,
            pattern.p_value,
            pattern.effect_size,
            pattern.pattern.confidence));
        output.push_str(&format!(" {}{}{}\n",
            COLOR_DIM,
            pattern.pattern.description,
            COLOR_RESET));
    }
    output.push_str(&format!("\n{}Note:{} * = statistically significant (p < 0.05)\n",
        COLOR_DIM, COLOR_RESET));
    output
}
/// Render a complete dashboard combining all visualizations
///
/// Stitches together, in order: a banner, quick stats pulled from the
/// engine, the ASCII graph (fixed 80x20 canvas), the domain matrix, the
/// coherence timeline, and the pattern summary.
pub fn render_dashboard(
    engine: &OptimizedDiscoveryEngine,
    patterns: &[SignificantPattern],
    coherence_history: &[(DateTime<Utc>, f64)],
) -> String {
    let mut output = String::new();
    // Title
    output.push_str(&format!("\n{}{}═══════════════════════════════════════════════════════════════════════════════{}\n",
        COLOR_BRIGHT, BOX_TL, COLOR_RESET));
    output.push_str(&format!("{}{} RuVector Discovery Framework - Live Dashboard {}\n",
        COLOR_BRIGHT, BOX_V, COLOR_RESET));
    output.push_str(&format!("{}{}═══════════════════════════════════════════════════════════════════════════════{}\n\n",
        COLOR_BRIGHT, BOX_BL, COLOR_RESET));
    // Stats overview
    let stats = engine.stats();
    output.push_str(&format!("{}Quick Stats:{}\n", COLOR_BRIGHT, COLOR_RESET));
    output.push_str(&format!(" Nodes: {} │ Edges: {} │ Vectors: {} │ Cross-domain: {}\n",
        stats.total_nodes,
        stats.total_edges,
        stats.total_vectors,
        stats.cross_domain_edges));
    output.push_str(&format!(" Patterns: {} │ Coherence samples: {} │ Cache hit rate: {:.1}%\n\n",
        patterns.len(),
        coherence_history.len(),
        stats.cache_hit_rate * 100.0));
    // Graph visualization
    output.push_str(&render_graph_ascii(engine, 80, 20));
    output.push('\n');
    // Domain matrix
    output.push_str(&render_domain_matrix(engine));
    output.push('\n');
    // Coherence timeline
    output.push_str(&render_coherence_timeline(coherence_history));
    output.push('\n');
    // Pattern summary
    output.push_str(&render_pattern_summary(patterns));
    output.push_str(&format!("\n{}{}═══════════════════════════════════════════════════════════════════════════════{}\n",
        COLOR_DIM, BOX_BL, COLOR_RESET));
    output
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimized::{OptimizedConfig, OptimizedDiscoveryEngine};
    use crate::ruvector_native::SemanticVector;
    use chrono::Utc;
    /// Core domains map to their dedicated ANSI color constants.
    #[test]
    fn test_domain_color() {
        assert_eq!(domain_color(Domain::Climate), COLOR_CLIMATE);
        assert_eq!(domain_color(Domain::Finance), COLOR_FINANCE);
    }
    /// Each core domain renders as its distinct single-character marker.
    #[test]
    fn test_domain_char() {
        assert_eq!(domain_char(Domain::Climate), 'C');
        assert_eq!(domain_char(Domain::Finance), 'F');
        assert_eq!(domain_char(Domain::Research), 'R');
    }
    /// An engine with no vectors renders the "(empty graph)" placeholder.
    #[test]
    fn test_render_empty_graph() {
        let config = OptimizedConfig::default();
        let engine = OptimizedDiscoveryEngine::new(config);
        let output = render_graph_ascii(&engine, 80, 20);
        assert!(output.contains("empty graph"));
    }
    /// No patterns yields the explicit "No patterns" message.
    #[test]
    fn test_render_pattern_summary_empty() {
        let output = render_pattern_summary(&[]);
        assert!(output.contains("No patterns"));
    }
    /// An empty history yields the "no coherence history" message.
    #[test]
    fn test_render_coherence_timeline_empty() {
        let output = render_coherence_timeline(&[]);
        assert!(output.contains("no coherence history"));
    }
    /// A populated history reports its title and point count.
    #[test]
    fn test_render_coherence_timeline_with_data() {
        let now = Utc::now();
        let history = vec![
            (now, 0.5),
            (now + chrono::Duration::hours(1), 0.6),
            (now + chrono::Duration::hours(2), 0.7),
        ];
        let output = render_coherence_timeline(&history);
        assert!(output.contains("Coherence Timeline"));
        assert!(output.contains("Data points: 3"));
    }
}

View File

@@ -0,0 +1,906 @@
//! Wikipedia and Wikidata API clients for knowledge graph building
//!
//! This module provides async clients for:
//! - Wikipedia: Article content, categories, links, and search
//! - Wikidata: Entity lookup, SPARQL queries, and structured knowledge
//!
//! Both clients convert responses into RuVector's DataRecord format with
//! semantic embeddings for vector search and graph analysis.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use chrono::Utc;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
use crate::{DataRecord, DataSource, FrameworkError, Relationship, Result};
use crate::api_clients::SimpleEmbedder;
/// Rate limiting configuration
const DEFAULT_RATE_LIMIT_DELAY_MS: u64 = 100;
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 1000;
// ============================================================================
// Wikipedia API Client
// ============================================================================
/// Wikipedia API search response
#[derive(Debug, Deserialize)]
struct WikiSearchResponse {
    query: WikiSearchQuery,
}
/// Inner `query` object of a `list=search` response.
#[derive(Debug, Deserialize)]
struct WikiSearchQuery {
    search: Vec<WikiSearchResult>,
}
/// One search hit: title, numeric page id, and an HTML snippet.
#[derive(Debug, Deserialize)]
struct WikiSearchResult {
    title: String,
    pageid: u64,
    snippet: String,
}
/// Wikipedia API page response
#[derive(Debug, Deserialize)]
struct WikiPageResponse {
    query: WikiPageQuery,
}
/// Inner `query` object of a page fetch; keys are stringified page ids.
#[derive(Debug, Deserialize)]
struct WikiPageQuery {
    pages: HashMap<String, WikiPage>,
}
/// A single page with intro extract, categories, and outgoing links.
/// `extract`, `categories`, and `links` default to empty when the API
/// omits them (they depend on which `prop` values were requested).
#[derive(Debug, Deserialize)]
struct WikiPage {
    pageid: u64,
    title: String,
    #[serde(default)]
    extract: String,
    #[serde(default)]
    categories: Vec<WikiCategory>,
    #[serde(default)]
    links: Vec<WikiLink>,
}
/// A category membership (title includes the "Category:" prefix).
#[derive(Debug, Deserialize)]
struct WikiCategory {
    title: String,
}
/// An outgoing wiki link (target article title).
#[derive(Debug, Deserialize)]
struct WikiLink {
    title: String,
}
/// Client for Wikipedia API
pub struct WikipediaClient {
    // Shared HTTP client (30 s timeout, custom user agent).
    client: Client,
    // Fully-formed MediaWiki endpoint, e.g. `https://en.wikipedia.org/w/api.php`.
    base_url: String,
    // Language code this client was built for; embedded in record ids/URLs.
    language: String,
    // Pause inserted between consecutive article fetches in `search`.
    rate_limit_delay: Duration,
    // Embedder that turns "title + extract" text into a semantic vector.
    embedder: Arc<SimpleEmbedder>,
}
impl WikipediaClient {
    /// Create a new Wikipedia client
    ///
    /// # Arguments
    /// * `language` - Wikipedia language code (e.g., "en", "de", "fr")
    ///
    /// # Errors
    /// Returns `FrameworkError::Network` when the underlying HTTP client
    /// cannot be constructed.
    pub fn new(language: String) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .user_agent("RuVector/1.0 (https://github.com/ruvnet/ruvector)")
            .build()
            .map_err(|e| FrameworkError::Network(e))?;
        let base_url = format!("https://{}.wikipedia.org/w/api.php", language);
        Ok(Self {
            client,
            base_url,
            language,
            rate_limit_delay: Duration::from_millis(DEFAULT_RATE_LIMIT_DELAY_MS),
            embedder: Arc::new(SimpleEmbedder::new(256)), // Larger dimension for richer content
        })
    }
    /// Search Wikipedia articles
    ///
    /// # Arguments
    /// * `query` - Search query
    /// * `limit` - Maximum number of results (max 500)
    ///
    /// For each hit the full article is fetched; hits whose article fetch
    /// fails are silently skipped (best-effort), and a rate-limit pause is
    /// inserted after each successful fetch.
    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<DataRecord>> {
        let url = format!(
            "{}?action=query&list=search&srsearch={}&srlimit={}&format=json",
            self.base_url,
            urlencoding::encode(query),
            limit.min(500)
        );
        let response = self.fetch_with_retry(&url).await?;
        let search_response: WikiSearchResponse = response.json().await?;
        let mut records = Vec::new();
        for result in search_response.query.search {
            // Get full article for each search result
            if let Ok(article) = self.get_article(&result.title).await {
                records.push(article);
                sleep(self.rate_limit_delay).await;
            }
        }
        Ok(records)
    }
    /// Get a Wikipedia article by title
    ///
    /// # Arguments
    /// * `title` - Article title
    ///
    /// Fetches the intro extract plus up to 50 categories and 50 links,
    /// and converts the page into a `DataRecord`.
    pub async fn get_article(&self, title: &str) -> Result<DataRecord> {
        let url = format!(
            "{}?action=query&prop=extracts|categories|links&titles={}&exintro=1&explaintext=1&format=json&cllimit=50&pllimit=50",
            self.base_url,
            urlencoding::encode(title)
        );
        let response = self.fetch_with_retry(&url).await?;
        let page_response: WikiPageResponse = response.json().await?;
        // Extract the page (should be only one); with a single title
        // requested the map holds a single entry, so iteration order is moot.
        let page = page_response
            .query
            .pages
            .values()
            .next()
            .ok_or_else(|| FrameworkError::Discovery("No page found".to_string()))?;
        self.page_to_record(page)
    }
    /// Get categories for an article
    ///
    /// # Arguments
    /// * `title` - Article title
    ///
    /// Returns an empty list when the page is missing from the response.
    pub async fn get_categories(&self, title: &str) -> Result<Vec<String>> {
        let url = format!(
            "{}?action=query&prop=categories&titles={}&cllimit=500&format=json",
            self.base_url,
            urlencoding::encode(title)
        );
        let response = self.fetch_with_retry(&url).await?;
        let page_response: WikiPageResponse = response.json().await?;
        let categories = page_response
            .query
            .pages
            .values()
            .next()
            .map(|page| page.categories.iter().map(|c| c.title.clone()).collect())
            .unwrap_or_default();
        Ok(categories)
    }
    /// Get links from an article
    ///
    /// # Arguments
    /// * `title` - Article title
    ///
    /// Returns an empty list when the page is missing from the response.
    pub async fn get_links(&self, title: &str) -> Result<Vec<String>> {
        let url = format!(
            "{}?action=query&prop=links&titles={}&pllimit=500&format=json",
            self.base_url,
            urlencoding::encode(title)
        );
        let response = self.fetch_with_retry(&url).await?;
        let page_response: WikiPageResponse = response.json().await?;
        let links = page_response
            .query
            .pages
            .values()
            .next()
            .map(|page| page.links.iter().map(|l| l.title.clone()).collect())
            .unwrap_or_default();
        Ok(links)
    }
    /// Convert Wikipedia page to DataRecord
    ///
    /// The embedding is computed over "title + extract"; categories become
    /// `in_category` relationships (weight 1.0) and the first 20 links
    /// become `links_to` relationships (weight 0.5).
    fn page_to_record(&self, page: &WikiPage) -> Result<DataRecord> {
        // Create embedding from title and extract
        let text = format!("{} {}", page.title, page.extract);
        let embedding = self.embedder.embed_text(&text);
        // Build relationships from categories
        let mut relationships = Vec::new();
        for category in &page.categories {
            relationships.push(Relationship {
                target_id: category.title.clone(),
                rel_type: "in_category".to_string(),
                weight: 1.0,
                properties: HashMap::new(),
            });
        }
        // Build relationships from links (limit to first 20)
        for link in page.links.iter().take(20) {
            relationships.push(Relationship {
                target_id: link.title.clone(),
                rel_type: "links_to".to_string(),
                weight: 0.5,
                properties: HashMap::new(),
            });
        }
        let mut data_map = serde_json::Map::new();
        data_map.insert("title".to_string(), serde_json::json!(page.title));
        data_map.insert("extract".to_string(), serde_json::json!(page.extract));
        data_map.insert("pageid".to_string(), serde_json::json!(page.pageid));
        data_map.insert("language".to_string(), serde_json::json!(self.language));
        data_map.insert(
            "url".to_string(),
            serde_json::json!(format!(
                "https://{}.wikipedia.org/wiki/{}",
                self.language,
                urlencoding::encode(&page.title)
            )),
        );
        Ok(DataRecord {
            id: format!("wikipedia_{}_{}", self.language, page.pageid),
            source: "wikipedia".to_string(),
            record_type: "article".to_string(),
            timestamp: Utc::now(),
            data: serde_json::Value::Object(data_map),
            embedding: Some(embedding),
            relationships,
        })
    }
    /// Fetch with retry logic
    ///
    /// Retries up to MAX_RETRIES times with linearly increasing backoff on
    /// transport errors and on HTTP 429 (Too Many Requests).
    /// NOTE(review): other non-success statuses (4xx/5xx) are returned as
    /// Ok, so callers surface them later as JSON-decode errors — confirm
    /// this is intended.
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}
#[async_trait]
impl DataSource for WikipediaClient {
    /// Stable identifier used by the ingestion pipeline.
    fn source_id(&self) -> &str {
        "wikipedia"
    }
    /// Fetch one batch of records.
    ///
    /// The `cursor` is repurposed as a search query here; when absent, a
    /// default query of "machine learning" is used. No continuation cursor
    /// is ever returned, so pagination stops after a single batch.
    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        // Default to searching for "machine learning" if no cursor provided
        let query = cursor.as_deref().unwrap_or("machine learning");
        let records = self.search(query, batch_size).await?;
        Ok((records, None))
    }
    /// Total record count is unknown for Wikipedia.
    async fn total_count(&self) -> Result<Option<u64>> {
        Ok(None)
    }
    /// Basic reachability probe against the API endpoint.
    async fn health_check(&self) -> Result<bool> {
        let response = self.client.get(&self.base_url).send().await?;
        Ok(response.status().is_success())
    }
}
// ============================================================================
// Wikidata API Client
// ============================================================================
/// Wikidata entity search response
#[derive(Debug, Deserialize)]
struct WikidataSearchResponse {
search: Vec<WikidataSearchResult>,
}
#[derive(Debug, Deserialize)]
struct WikidataSearchResult {
id: String,
label: String,
description: Option<String>,
}
/// Wikidata entity response
#[derive(Debug, Deserialize)]
struct WikidataEntityResponse {
entities: HashMap<String, WikidataEntityData>,
}
#[derive(Debug, Deserialize)]
struct WikidataEntityData {
id: String,
labels: HashMap<String, WikidataLabel>,
descriptions: HashMap<String, WikidataLabel>,
aliases: HashMap<String, Vec<WikidataLabel>>,
claims: HashMap<String, Vec<WikidataClaim>>,
}
/// A language-tagged text value (label, description, or alias).
#[derive(Debug, Deserialize)]
struct WikidataLabel {
    // The text itself.
    value: String,
}
/// A single claim (statement) attached to an entity property.
#[derive(Debug, Deserialize)]
struct WikidataClaim {
    // The main snak carrying the claim's value.
    mainsnak: WikidataSnak,
}
/// Snak (value holder) within a claim.
#[derive(Debug, Deserialize)]
struct WikidataSnak {
    // May be absent — presumably for valueless/"unknown value" snaks;
    // the API can omit this field (hence Option).
    datavalue: Option<WikidataValue>,
}
/// Arbitrary claim value, kept as raw JSON since datatypes vary widely.
#[derive(Debug, Deserialize)]
struct WikidataValue {
    // Raw JSON value (string, number, or structured object).
    value: serde_json::Value,
}
/// Wikidata SPARQL response
///
/// Top-level payload of the SPARQL query service (JSON results format).
#[derive(Debug, Deserialize)]
struct WikidataSparqlResponse {
    // Result section containing the bindings.
    results: WikidataSparqlResults,
}
/// Container for SPARQL result rows.
#[derive(Debug, Deserialize)]
struct WikidataSparqlResults {
    // One map per result row: SPARQL variable name -> bound value.
    bindings: Vec<HashMap<String, WikidataSparqlBinding>>,
}
/// A single bound value in a SPARQL result row.
#[derive(Debug, Deserialize)]
struct WikidataSparqlBinding {
    // The bound value as a string (URIs and literals alike).
    value: String,
}
/// Structured Wikidata entity
///
/// Normalized, English-language view of an entity, built either shallowly
/// from a search hit or fully from a `wbgetentities` payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WikidataEntity {
    /// Wikidata Q-identifier (e.g. "Q42")
    pub qid: String,
    /// Primary label (English; falls back to the QID when missing)
    pub label: String,
    /// Description (empty string when the entity has none)
    pub description: String,
    /// Alternative names (empty for search results)
    pub aliases: Vec<String>,
    /// Property claims (property ID -> serialized JSON values)
    pub claims: HashMap<String, Vec<String>>,
}
/// Client for Wikidata API and SPARQL endpoint
pub struct WikidataClient {
    // Shared HTTP client (30s timeout, RuVector user agent).
    client: Client,
    // MediaWiki action API endpoint (www.wikidata.org/w/api.php).
    api_url: String,
    // SPARQL query service endpoint (query.wikidata.org/sparql).
    sparql_url: String,
    // Intended delay between requests.
    // NOTE(review): stored but not consulted by the methods in this file's
    // impl — confirm throttling happens at a higher level.
    rate_limit_delay: Duration,
    // Embedder (constructed with dimension 256) used to derive vectors
    // from entity text.
    embedder: Arc<SimpleEmbedder>,
}
impl WikidataClient {
/// Create a new Wikidata client
pub fn new() -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.user_agent("RuVector/1.0 (https://github.com/ruvnet/ruvector)")
.build()
.map_err(|e| FrameworkError::Network(e))?;
Ok(Self {
client,
api_url: "https://www.wikidata.org/w/api.php".to_string(),
sparql_url: "https://query.wikidata.org/sparql".to_string(),
rate_limit_delay: Duration::from_millis(DEFAULT_RATE_LIMIT_DELAY_MS),
embedder: Arc::new(SimpleEmbedder::new(256)),
})
}
/// Search for Wikidata entities
///
/// # Arguments
/// * `query` - Search query
pub async fn search_entities(&self, query: &str) -> Result<Vec<WikidataEntity>> {
let url = format!(
"{}?action=wbsearchentities&search={}&language=en&format=json&limit=50",
self.api_url,
urlencoding::encode(query)
);
let response = self.fetch_with_retry(&url).await?;
let search_response: WikidataSearchResponse = response.json().await?;
let mut entities = Vec::new();
for result in search_response.search {
entities.push(WikidataEntity {
qid: result.id,
label: result.label,
description: result.description.unwrap_or_default(),
aliases: Vec::new(),
claims: HashMap::new(),
});
}
Ok(entities)
}
/// Get a Wikidata entity by QID
///
/// # Arguments
/// * `qid` - Wikidata Q-identifier (e.g., "Q42" for Douglas Adams)
pub async fn get_entity(&self, qid: &str) -> Result<WikidataEntity> {
let url = format!(
"{}?action=wbgetentities&ids={}&format=json",
self.api_url, qid
);
let response = self.fetch_with_retry(&url).await?;
let entity_response: WikidataEntityResponse = response.json().await?;
let entity_data = entity_response
.entities
.get(qid)
.ok_or_else(|| FrameworkError::Discovery(format!("Entity {} not found", qid)))?;
self.entity_data_to_entity(entity_data)
}
/// Execute a SPARQL query
///
/// # Arguments
/// * `query` - SPARQL query string
pub async fn sparql_query(&self, query: &str) -> Result<Vec<HashMap<String, String>>> {
let response = self
.client
.get(&self.sparql_url)
.query(&[("query", query), ("format", "json")])
.send()
.await?;
let sparql_response: WikidataSparqlResponse = response.json().await?;
let results = sparql_response
.results
.bindings
.into_iter()
.map(|binding| {
binding
.into_iter()
.map(|(k, v)| (k, v.value))
.collect::<HashMap<String, String>>()
})
.collect();
Ok(results)
}
/// Query climate change related entities
pub async fn query_climate_entities(&self) -> Result<Vec<DataRecord>> {
let query = r#"
SELECT ?item ?itemLabel ?itemDescription WHERE {
{
?item wdt:P31 wd:Q125977. # climate change
} UNION {
?item wdt:P279* wd:Q125977. # subclass of climate change
} UNION {
?item wdt:P921 wd:Q125977. # main subject climate change
}
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
}
LIMIT 100
"#;
self.sparql_to_records(query, "climate").await
}
/// Query pharmaceutical companies
pub async fn query_pharmaceutical_companies(&self) -> Result<Vec<DataRecord>> {
let query = r#"
SELECT ?item ?itemLabel ?itemDescription ?founded ?employees WHERE {
?item wdt:P31/wdt:P279* wd:Q507443. # pharmaceutical company
OPTIONAL { ?item wdt:P571 ?founded. }
OPTIONAL { ?item wdt:P1128 ?employees. }
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
}
LIMIT 100
"#;
self.sparql_to_records(query, "pharma").await
}
/// Query disease outbreaks
pub async fn query_disease_outbreaks(&self) -> Result<Vec<DataRecord>> {
let query = r#"
SELECT ?item ?itemLabel ?itemDescription ?disease ?diseaseLabel ?startTime ?location ?locationLabel WHERE {
?item wdt:P31 wd:Q3241045. # epidemic
OPTIONAL { ?item wdt:P828 ?disease. }
OPTIONAL { ?item wdt:P580 ?startTime. }
OPTIONAL { ?item wdt:P276 ?location. }
SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
}
LIMIT 100
"#;
self.sparql_to_records(query, "disease").await
}
/// Convert SPARQL results to DataRecords
async fn sparql_to_records(&self, query: &str, category: &str) -> Result<Vec<DataRecord>> {
let results = self.sparql_query(query).await?;
let mut records = Vec::new();
for result in results {
// Extract QID from URI
let item_uri = result.get("item").cloned().unwrap_or_default();
let qid = item_uri
.split('/')
.last()
.unwrap_or(&item_uri)
.to_string();
let label = result
.get("itemLabel")
.cloned()
.unwrap_or_else(|| qid.clone());
let description = result.get("itemDescription").cloned().unwrap_or_default();
// Create embedding from label and description
let text = format!("{} {}", label, description);
let embedding = self.embedder.embed_text(&text);
let mut data_map = serde_json::Map::new();
data_map.insert("qid".to_string(), serde_json::json!(qid));
data_map.insert("label".to_string(), serde_json::json!(label));
data_map.insert("description".to_string(), serde_json::json!(description));
data_map.insert("category".to_string(), serde_json::json!(category));
// Add all other SPARQL result fields
for (key, value) in result.iter() {
if !key.ends_with("Label") && key != "item" && key != "itemDescription" {
data_map.insert(key.clone(), serde_json::json!(value));
}
}
records.push(DataRecord {
id: format!("wikidata_{}", qid),
source: "wikidata".to_string(),
record_type: category.to_string(),
timestamp: Utc::now(),
data: serde_json::Value::Object(data_map),
embedding: Some(embedding),
relationships: Vec::new(),
});
}
Ok(records)
}
/// Convert entity data to WikidataEntity
fn entity_data_to_entity(&self, data: &WikidataEntityData) -> Result<WikidataEntity> {
let label = data
.labels
.get("en")
.map(|l| l.value.clone())
.unwrap_or_else(|| data.id.clone());
let description = data
.descriptions
.get("en")
.map(|d| d.value.clone())
.unwrap_or_default();
let aliases = data
.aliases
.get("en")
.map(|aliases| aliases.iter().map(|a| a.value.clone()).collect())
.unwrap_or_default();
let mut claims = HashMap::new();
for (property, claim_list) in &data.claims {
let values: Vec<String> = claim_list
.iter()
.filter_map(|claim| {
claim
.mainsnak
.datavalue
.as_ref()
.map(|dv| dv.value.to_string())
})
.collect();
if !values.is_empty() {
claims.insert(property.clone(), values);
}
}
Ok(WikidataEntity {
qid: data.id.clone(),
label,
description,
aliases,
claims,
})
}
/// Convert WikidataEntity to DataRecord
fn entity_to_record(&self, entity: &WikidataEntity) -> Result<DataRecord> {
// Create embedding from label, description, and aliases
let text = format!(
"{} {} {}",
entity.label,
entity.description,
entity.aliases.join(" ")
);
let embedding = self.embedder.embed_text(&text);
// Build relationships from claims
let mut relationships = Vec::new();
for (property, values) in &entity.claims {
for value in values {
// Try to extract QID if value is an entity reference
if let Some(qid) = value.strip_prefix("Q") {
if qid.chars().all(|c| c.is_ascii_digit()) {
relationships.push(Relationship {
target_id: value.clone(),
rel_type: property.clone(),
weight: 1.0,
properties: HashMap::new(),
});
}
}
}
}
let mut data_map = serde_json::Map::new();
data_map.insert("qid".to_string(), serde_json::json!(entity.qid));
data_map.insert("label".to_string(), serde_json::json!(entity.label));
data_map.insert(
"description".to_string(),
serde_json::json!(entity.description),
);
data_map.insert("aliases".to_string(), serde_json::json!(entity.aliases));
data_map.insert(
"url".to_string(),
serde_json::json!(format!(
"https://www.wikidata.org/wiki/{}",
entity.qid
)),
);
// Add claims as structured data
let claims_json: serde_json::Value = serde_json::to_value(&entity.claims)?;
data_map.insert("claims".to_string(), claims_json);
Ok(DataRecord {
id: format!("wikidata_{}", entity.qid),
source: "wikidata".to_string(),
record_type: "entity".to_string(),
timestamp: Utc::now(),
data: serde_json::Value::Object(data_map),
embedding: Some(embedding),
relationships,
})
}
/// Fetch with retry logic
async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
let mut retries = 0;
loop {
match self.client.get(url).send().await {
Ok(response) => {
if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
{
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
continue;
}
return Ok(response);
}
Err(_) if retries < MAX_RETRIES => {
retries += 1;
sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
}
Err(e) => return Err(FrameworkError::Network(e)),
}
}
}
}
impl Default for WikidataClient {
    /// Delegates to [`WikidataClient::new`].
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be constructed.
    fn default() -> Self {
        Self::new().expect("Failed to create WikidataClient")
    }
}
#[async_trait]
impl DataSource for WikidataClient {
    /// Stable identifier for this data source.
    fn source_id(&self) -> &str {
        "wikidata"
    }

    /// Fetch one batch of records.
    ///
    /// The cursor selects a canned SPARQL query ("climate", "pharma",
    /// "disease"); anything else falls back to an entity search for
    /// "artificial intelligence" converted to records (first 20 hits).
    /// No pagination cursor is produced.
    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        _batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        let records = match cursor.as_deref() {
            Some("climate") => self.query_climate_entities().await?,
            Some("pharma") => self.query_pharmaceutical_companies().await?,
            Some("disease") => self.query_disease_outbreaks().await?,
            _ => {
                // Default: convert the top entity-search hits for a broad topic,
                // short-circuiting on the first conversion error.
                let entities = self.search_entities("artificial intelligence").await?;
                entities
                    .iter()
                    .take(20)
                    .map(|entity| self.entity_to_record(entity))
                    .collect::<Result<Vec<_>>>()?
            }
        };
        Ok((records, None))
    }

    /// Wikidata does not expose a cheap total record count.
    async fn total_count(&self) -> Result<Option<u64>> {
        Ok(None)
    }

    /// Probe the API endpoint; healthy iff it answers with a 2xx status.
    async fn health_check(&self) -> Result<bool> {
        let status = self.client.get(&self.api_url).send().await?.status();
        Ok(status.is_success())
    }
}
// ============================================================================
// Example SPARQL Queries
// ============================================================================
/// Pre-defined SPARQL query templates
///
/// Ready-to-run queries for the Wikidata Query Service. Each is capped at
/// 100 rows and resolves English labels via the wikibase label service.
pub mod sparql_queries {
    /// Query for climate change related entities
    ///
    /// Items that are instances of, subclasses of, or have main subject
    /// Q125977 (climate change).
    pub const CLIMATE_CHANGE: &str = r#"
        SELECT ?item ?itemLabel ?itemDescription WHERE {
          {
            ?item wdt:P31 wd:Q125977. # instance of climate change
          } UNION {
            ?item wdt:P279* wd:Q125977. # subclass of climate change
          } UNION {
            ?item wdt:P921 wd:Q125977. # main subject climate change
          }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        LIMIT 100
    "#;

    /// Query for pharmaceutical companies
    ///
    /// Transitive instances of Q507443 with optional founding date,
    /// employee count, and headquarters, ordered by employee count.
    pub const PHARMACEUTICAL_COMPANIES: &str = r#"
        SELECT ?item ?itemLabel ?itemDescription ?founded ?employees ?headquarters ?headquartersLabel WHERE {
          ?item wdt:P31/wdt:P279* wd:Q507443. # pharmaceutical company
          OPTIONAL { ?item wdt:P571 ?founded. }
          OPTIONAL { ?item wdt:P1128 ?employees. }
          OPTIONAL { ?item wdt:P159 ?headquarters. }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        ORDER BY DESC(?employees)
        LIMIT 100
    "#;

    /// Query for disease outbreaks
    ///
    /// Instances of Q3241045 (epidemic) with optional disease, time span,
    /// location, and death toll, most recent first.
    pub const DISEASE_OUTBREAKS: &str = r#"
        SELECT ?item ?itemLabel ?itemDescription ?disease ?diseaseLabel ?startTime ?endTime ?location ?locationLabel ?deaths WHERE {
          ?item wdt:P31 wd:Q3241045. # epidemic
          OPTIONAL { ?item wdt:P828 ?disease. }
          OPTIONAL { ?item wdt:P580 ?startTime. }
          OPTIONAL { ?item wdt:P582 ?endTime. }
          OPTIONAL { ?item wdt:P276 ?location. }
          OPTIONAL { ?item wdt:P1120 ?deaths. }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        ORDER BY DESC(?startTime)
        LIMIT 100
    "#;

    /// Query for scientific research institutions
    ///
    /// Transitive instances of Q31855 (research institute) with optional
    /// country and founding date.
    pub const RESEARCH_INSTITUTIONS: &str = r#"
        SELECT ?item ?itemLabel ?itemDescription ?country ?countryLabel ?founded WHERE {
          ?item wdt:P31/wdt:P279* wd:Q31855. # research institute
          OPTIONAL { ?item wdt:P17 ?country. }
          OPTIONAL { ?item wdt:P571 ?founded. }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        LIMIT 100
    "#;

    /// Query for Nobel Prize winners in specific field
    ///
    /// People whose awards are (transitively) subclasses of Q7191
    /// (Nobel Prize), with optional year and field, most recent first.
    pub const NOBEL_LAUREATES: &str = r#"
        SELECT ?item ?itemLabel ?itemDescription ?award ?awardLabel ?year ?field ?fieldLabel WHERE {
          ?item wdt:P166 ?award.
          ?award wdt:P279* wd:Q7191. # Nobel Prize
          OPTIONAL { ?item wdt:P166 ?award. ?award wdt:P585 ?year. }
          OPTIONAL { ?award wdt:P101 ?field. }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
        }
        ORDER BY DESC(?year)
        LIMIT 100
    "#;
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a Wikipedia client with default settings must succeed.
    #[tokio::test]
    async fn test_wikipedia_client_creation() {
        assert!(WikipediaClient::new("en".to_string()).is_ok());
    }

    /// Constructing a Wikidata client with default settings must succeed.
    #[tokio::test]
    async fn test_wikidata_client_creation() {
        assert!(WikidataClient::new().is_ok());
    }

    /// A WikidataEntity must survive a serde JSON round-trip unchanged.
    #[test]
    fn test_wikidata_entity_serialization() {
        let mut claims = HashMap::new();
        claims.insert("P31".to_string(), vec!["Q5".to_string()]);
        let original = WikidataEntity {
            qid: "Q42".to_string(),
            label: "Douglas Adams".to_string(),
            description: "English writer and humorist".to_string(),
            aliases: vec!["Douglas Noel Adams".to_string()],
            claims,
        };
        let serialized = serde_json::to_string(&original).unwrap();
        let roundtrip: WikidataEntity = serde_json::from_str(&serialized).unwrap();
        assert_eq!(roundtrip.qid, "Q42");
        assert_eq!(roundtrip.label, "Douglas Adams");
    }

    /// Every canned SPARQL template must be non-empty.
    #[test]
    fn test_sparql_query_templates() {
        for template in [
            sparql_queries::CLIMATE_CHANGE,
            sparql_queries::PHARMACEUTICAL_COMPANIES,
            sparql_queries::DISEASE_OUTBREAKS,
        ] {
            assert!(!template.is_empty());
        }
    }
}