Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,402 @@
# Tiny Dancer Routing Module
Neural-powered dynamic agent routing with FastGRNN for intelligent AI agent selection.
## Overview
The Tiny Dancer routing module provides intelligent routing of requests to AI agents based on multiple optimization criteria including cost, latency, quality, and balanced performance. It uses a FastGRNN (Fast Gated Recurrent Neural Network) for adaptive decision-making.
## Architecture
### Components
1. **FastGRNN** (`fastgrnn.rs`)
- Lightweight gated recurrent neural network
- Real-time routing decisions with minimal compute
- Adaptive learning from routing patterns
2. **Agent Registry** (`agents.rs`)
- Thread-safe agent storage with DashMap
- Capability-based agent discovery
- Performance metrics tracking
3. **Router** (`router.rs`)
- Multi-objective optimization
- Constraint-based filtering
- Neural-enhanced confidence scoring
4. **PostgreSQL Operators** (`operators.rs`)
- SQL functions for agent management
- Routing query interface
- Statistics and monitoring
## PostgreSQL Functions
### Agent Registration
```sql
-- Register a simple agent
SELECT ruvector_register_agent(
'gpt-4', -- Agent name
'llm', -- Agent type
ARRAY['code_generation', 'reasoning'], -- Capabilities
0.03, -- Cost per request ($)
500.0, -- Average latency (ms)
0.95 -- Quality score (0-1)
);
-- Register with full configuration
SELECT ruvector_register_agent_full('{
"name": "claude-3-opus",
"agent_type": "LLM",
"capabilities": ["coding", "reasoning", "writing"],
"cost_model": {
"per_request": 0.025,
"per_token": 0.00005
},
"performance": {
"avg_latency_ms": 400.0,
"quality_score": 0.93,
"success_rate": 0.99,
"p95_latency_ms": 600.0,
"p99_latency_ms": 1000.0,
"total_requests": 0
},
"is_active": true,
"metadata": {}
}'::jsonb);
```
### Routing Requests
```sql
-- Basic routing (optimize for balanced performance)
SELECT ruvector_route(
embedding_vector, -- Request embedding (384-dim)
'balanced', -- Optimization target
NULL -- No constraints
)
FROM requests
WHERE id = 123;
-- Cost-optimized routing with constraints
SELECT ruvector_route(
embedding_vector,
'cost',
'{"max_latency_ms": 1000.0, "min_quality": 0.8}'::jsonb
)
FROM requests
WHERE id = 456;
-- Quality-optimized with capability requirements
SELECT ruvector_route(
embedding_vector,
'quality',
'{
"max_cost": 0.1,
"required_capabilities": ["code_generation", "debugging"],
"excluded_agents": ["slow-agent"]
}'::jsonb
);
-- Latency-optimized routing
SELECT ruvector_route(
embedding_vector,
'latency',
'{"max_latency_ms": 500.0}'::jsonb
);
```
### Agent Management
```sql
-- List all agents
SELECT * FROM ruvector_list_agents();
-- Get specific agent details
SELECT ruvector_get_agent('gpt-4');
-- Find agents by capability
SELECT * FROM ruvector_find_agents_by_capability('code_generation', 5);
-- Update agent performance metrics
SELECT ruvector_update_agent_metrics(
'gpt-4', -- Agent name
450.0, -- Observed latency (ms)
true, -- Success
0.92 -- Quality score (optional)
);
-- Deactivate an agent
SELECT ruvector_set_agent_active('gpt-4', false);
-- Remove an agent
SELECT ruvector_remove_agent('old-agent');
-- Get routing statistics
SELECT ruvector_routing_stats();
```
## Usage Examples
### Example 1: Multi-Model Routing System
```sql
-- Register various AI models
SELECT ruvector_register_agent('gpt-4', 'llm',
ARRAY['coding', 'reasoning', 'math'], 0.03, 500.0, 0.95);
SELECT ruvector_register_agent('gpt-3.5-turbo', 'llm',
ARRAY['general', 'fast'], 0.002, 200.0, 0.75);
SELECT ruvector_register_agent('claude-3-opus', 'llm',
ARRAY['coding', 'writing', 'analysis'], 0.025, 400.0, 0.93);
SELECT ruvector_register_agent('llama-2-70b', 'llm',
ARRAY['local', 'private'], 0.0, 800.0, 0.72);
-- Create routing view
CREATE VIEW intelligent_routing AS
SELECT
r.id,
r.query_text,
r.embedding,
route.agent_name,
route.confidence,
route.estimated_cost,
route.estimated_latency_ms,
route.expected_quality,
route.reasoning
FROM requests r,
LATERAL (
SELECT (ruvector_route(
r.embedding,
'balanced',
NULL
))::jsonb AS route_data
) route_query,
LATERAL jsonb_to_record(route_query.route_data) AS route(
agent_name text,
confidence float4,
estimated_cost float4,
estimated_latency_ms float4,
expected_quality float4,
similarity_score float4,
reasoning text
);
-- Query with automatic routing
SELECT * FROM intelligent_routing WHERE id = 123;
```
### Example 2: Cost-Aware Batch Processing
```sql
-- Process batch with cost constraints
CREATE TEMP TABLE batch_results AS
SELECT
r.id,
r.query_text,
routing.agent_name,
routing.estimated_cost,
routing.expected_quality
FROM requests r
CROSS JOIN LATERAL (
SELECT (ruvector_route(
r.embedding,
'cost',
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
))::jsonb->'agent_name' AS agent_name,
(ruvector_route(
r.embedding,
'cost',
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
))::jsonb->'estimated_cost' AS estimated_cost,
(ruvector_route(
r.embedding,
'cost',
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
))::jsonb->'expected_quality' AS expected_quality
) routing
WHERE r.processed = false
LIMIT 1000;
-- Calculate total estimated cost
SELECT
SUM((estimated_cost)::float) AS total_cost,
AVG((expected_quality)::float) AS avg_quality,
COUNT(*) AS total_requests
FROM batch_results;
```
### Example 3: Quality-First Routing
```sql
-- Route critical requests to highest quality agents
CREATE FUNCTION route_critical_request(
request_embedding float4[],
min_quality float4 DEFAULT 0.9
) RETURNS jsonb AS $$
SELECT ruvector_route(
request_embedding,
'quality',
jsonb_build_object(
'min_quality', min_quality,
'max_latency_ms', 2000.0,
'required_capabilities', ARRAY['reasoning', 'analysis']
)
);
$$ LANGUAGE SQL;
-- Use the function
SELECT route_critical_request(embedding_vector, 0.95)
FROM critical_requests
WHERE priority = 'high';
```
### Example 4: Real-time Performance Tracking
```sql
-- Update metrics after each request.
-- A trigger function takes no declared arguments and returns "trigger";
-- the inserted row is available inside the function as NEW.
CREATE FUNCTION record_agent_performance() RETURNS trigger AS $$
BEGIN
PERFORM ruvector_update_agent_metrics(
NEW.agent_name,
NEW.latency_ms,
NEW.success,
NEW.quality_score
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Trigger to auto-update metrics
CREATE TRIGGER update_agent_metrics_trigger
AFTER INSERT ON request_completions
FOR EACH ROW
EXECUTE FUNCTION record_agent_performance();
```
### Example 5: Capability-Based Routing
```sql
-- Create specialized routing functions
CREATE FUNCTION route_code_request(emb float4[]) RETURNS text AS $$
SELECT (ruvector_route(
emb,
'quality',
'{"required_capabilities": ["coding", "debugging"]}'::jsonb
))::jsonb->>'agent_name';
$$ LANGUAGE SQL;
CREATE FUNCTION route_writing_request(emb float4[]) RETURNS text AS $$
SELECT (ruvector_route(
emb,
'quality',
'{"required_capabilities": ["writing", "editing"]}'::jsonb
))::jsonb->>'agent_name';
$$ LANGUAGE SQL;
-- Use in application logic
SELECT
CASE
WHEN task_type = 'code' THEN route_code_request(embedding)
WHEN task_type = 'write' THEN route_writing_request(embedding)
ELSE (ruvector_route(embedding, 'balanced', NULL))::jsonb->>'agent_name'
END AS selected_agent
FROM tasks;
```
## Optimization Targets
### Cost
- Minimizes cost per request
- Considers both per-request and per-token costs
- Ideal for high-volume, cost-sensitive workloads
### Latency
- Minimizes response time
- Uses average latency metrics
- Best for real-time applications
### Quality
- Maximizes quality score
- Based on historical performance
- Recommended for critical tasks
### Balanced
- Multi-objective optimization
- Balances cost, latency, quality, and similarity
- Default for general-purpose routing
## Constraints
### max_cost
Maximum acceptable cost per request (in dollars)
### max_latency_ms
Maximum acceptable latency in milliseconds
### min_quality
Minimum required quality score (0-1 scale)
### required_capabilities
Array of required agent capabilities
### excluded_agents
Array of agent names to exclude from selection
## Performance Considerations
1. **Agent Registry**: Thread-safe with DashMap for concurrent access
2. **Embedding Similarity**: Uses fast cosine similarity for request matching
3. **FastGRNN**: Lightweight neural network for real-time inference
4. **Caching**: Consider caching routing decisions for identical requests
## Monitoring
```sql
-- View agent statistics
SELECT name, total_requests, avg_latency_ms, quality_score, success_rate
FROM ruvector_list_agents()
ORDER BY total_requests DESC;
-- Get overall routing statistics
SELECT ruvector_routing_stats();
-- Find underperforming agents
SELECT name, success_rate, quality_score
FROM ruvector_list_agents()
WHERE success_rate < 0.95
OR quality_score < 0.7;
```
## Best Practices
1. **Register Accurate Metrics**: Keep agent performance metrics up-to-date
2. **Use Constraints**: Always set appropriate constraints for production
3. **Monitor Performance**: Track actual vs. estimated metrics
4. **Update Regularly**: Use `ruvector_update_agent_metrics` after each request
5. **Capability Matching**: Ensure agents have accurate capability tags
6. **Cost Tracking**: Monitor total routing costs with statistics queries
## Integration with Other Modules
The routing module integrates seamlessly with:
- **Vector Search**: Use query embeddings for semantic routing
- **GNN**: Enhance routing with graph neural networks
- **Quantization**: Reduce embedding storage costs
- **HNSW Index**: Fast similarity search for agent selection
## Future Enhancements
- [ ] A/B testing framework for agent comparison
- [ ] Multi-armed bandit algorithms for exploration
- [ ] Reinforcement learning for adaptive routing
- [ ] Cost prediction models
- [ ] Load balancing across agent instances
- [ ] Geo-distributed agent routing

View File

@@ -0,0 +1,500 @@
// Agent Registry and Management
//
// Thread-safe registry for managing AI agents with capabilities and performance metrics.
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Type of AI agent
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AgentType {
/// Language model (GPT, Claude, etc.)
LLM,
/// Embedding model
Embedding,
/// Specialized task agent
Specialized,
/// Vision model
Vision,
/// Audio model
Audio,
/// Multimodal agent
Multimodal,
/// Custom agent type
Custom(String),
}
impl AgentType {
    /// Map a textual label onto an agent type.
    ///
    /// Built-in kinds are matched after lowercasing the label; anything else
    /// becomes `Custom`, keeping the label exactly as it was passed in.
    pub fn from_str(s: &str) -> Self {
        let label = s.to_lowercase();
        if label == "llm" {
            AgentType::LLM
        } else if label == "embedding" {
            AgentType::Embedding
        } else if label == "specialized" {
            AgentType::Specialized
        } else if label == "vision" {
            AgentType::Vision
        } else if label == "audio" {
            AgentType::Audio
        } else if label == "multimodal" {
            AgentType::Multimodal
        } else {
            AgentType::Custom(s.to_string())
        }
    }

    /// Textual label for this agent type.
    ///
    /// Built-in kinds yield their lowercase name; `Custom` yields the stored
    /// label unchanged.
    pub fn as_str(&self) -> &str {
        match *self {
            AgentType::Custom(ref label) => label,
            AgentType::Multimodal => "multimodal",
            AgentType::Audio => "audio",
            AgentType::Vision => "vision",
            AgentType::Specialized => "specialized",
            AgentType::Embedding => "embedding",
            AgentType::LLM => "llm",
        }
    }
}
/// Cost model for agent usage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
/// Cost per request
pub per_request: f32,
/// Cost per token (if applicable)
pub per_token: Option<f32>,
/// Fixed monthly cost
pub monthly_fixed: Option<f32>,
}
impl Default for CostModel {
fn default() -> Self {
Self {
per_request: 0.0,
per_token: None,
monthly_fixed: None,
}
}
}
/// Performance metrics for an agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
/// Average latency in milliseconds
pub avg_latency_ms: f32,
/// 95th percentile latency
pub p95_latency_ms: f32,
/// 99th percentile latency
pub p99_latency_ms: f32,
/// Quality score (0-1)
pub quality_score: f32,
/// Success rate (0-1)
pub success_rate: f32,
/// Total requests processed
pub total_requests: u64,
}
impl Default for PerformanceMetrics {
fn default() -> Self {
Self {
avg_latency_ms: 100.0,
p95_latency_ms: 200.0,
p99_latency_ms: 500.0,
quality_score: 0.8,
success_rate: 0.99,
total_requests: 0,
}
}
}
/// AI Agent definition with capabilities and metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
/// Unique agent name
pub name: String,
/// Agent type
pub agent_type: AgentType,
/// Capabilities (e.g., ["code_generation", "translation"])
pub capabilities: Vec<String>,
/// Cost model
pub cost_model: CostModel,
/// Performance metrics
pub performance: PerformanceMetrics,
/// Agent embedding for similarity matching (384-dim)
pub embedding: Option<Vec<f32>>,
/// Whether agent is currently active
pub is_active: bool,
/// Additional metadata
pub metadata: serde_json::Value,
}
impl Agent {
    /// Create a new, active agent with default cost model and metrics,
    /// no embedding and null metadata.
    pub fn new(name: String, agent_type: AgentType, capabilities: Vec<String>) -> Self {
        Self {
            name,
            agent_type,
            capabilities,
            cost_model: CostModel::default(),
            performance: PerformanceMetrics::default(),
            embedding: None,
            is_active: true,
            metadata: serde_json::Value::Null,
        }
    }

    /// Check if agent has a specific capability (ASCII case-insensitive).
    pub fn has_capability(&self, capability: &str) -> bool {
        self.capabilities
            .iter()
            .any(|c| c.eq_ignore_ascii_case(capability))
    }

    /// Calculate total cost for a request.
    ///
    /// Returns the per-request cost, plus `token_count * per_token` when both
    /// a token count is supplied and the cost model has a per-token price.
    pub fn calculate_cost(&self, token_count: Option<u32>) -> f32 {
        let base = self.cost_model.per_request;
        match (token_count, self.cost_model.per_token) {
            (Some(tokens), Some(per_token)) => base + tokens as f32 * per_token,
            _ => base,
        }
    }

    /// Update performance metrics with a new observation.
    ///
    /// Average latency, success rate and (when `quality` is given) quality
    /// score are maintained as cumulative running means weighted by
    /// `total_requests`. Note that the quality mean is weighted by the total
    /// request count even if earlier requests supplied no quality value.
    ///
    /// # Arguments
    /// * `latency_ms` - Observed latency of the request in milliseconds
    /// * `success` - Whether the request succeeded
    /// * `quality` - Optional quality score observed for this request
    pub fn update_metrics(&mut self, latency_ms: f32, success: bool, quality: Option<f32>) {
        // Count in f64: an f32 counter stops growing at 2^24 requests
        // (n + 1.0 == n), which would corrupt every mean from then on.
        let n = self.performance.total_requests as f64;
        let new_n = n + 1.0;
        let running_mean = |current: f32, sample: f32| -> f32 {
            ((f64::from(current) * n + f64::from(sample)) / new_n) as f32
        };

        // Cumulative mean of the observed latencies.
        self.performance.avg_latency_ms =
            running_mean(self.performance.avg_latency_ms, latency_ms);

        // Success rate as a running mean of 0/1 outcomes. The previous
        // implementation rebuilt an integer success count with
        // `(rate * n) as u64`, which truncates: a stored rate of 0.99 over
        // 10 requests became 9 successes, silently lowering the rate.
        let outcome = if success { 1.0 } else { 0.0 };
        self.performance.success_rate = running_mean(self.performance.success_rate, outcome);

        // Update quality score if provided.
        if let Some(q) = quality {
            self.performance.quality_score = running_mean(self.performance.quality_score, q);
        }

        self.performance.total_requests += 1;

        // Update percentiles (simplified approach): only observations well
        // above the new average move the tail estimates, and an estimate is
        // never left below the observation that triggered the update.
        if latency_ms > self.performance.avg_latency_ms * 1.5 {
            self.performance.p95_latency_ms =
                (self.performance.p95_latency_ms * 0.95 + latency_ms * 0.05).max(latency_ms);
        }
        if latency_ms > self.performance.avg_latency_ms * 2.0 {
            self.performance.p99_latency_ms =
                (self.performance.p99_latency_ms * 0.99 + latency_ms * 0.01).max(latency_ms);
        }
    }
}
/// Thread-safe agent registry
///
/// Agents are kept in a `DashMap` keyed by agent name and wrapped in an
/// `Arc`; the map is private and is only reached through the registry's
/// methods.
pub struct AgentRegistry {
/// Agents stored by name
agents: Arc<DashMap<String, Agent>>,
}
impl AgentRegistry {
/// Create a new agent registry
pub fn new() -> Self {
Self {
agents: Arc::new(DashMap::new()),
}
}
/// Register a new agent
pub fn register(&self, agent: Agent) -> Result<(), String> {
if self.agents.contains_key(&agent.name) {
return Err(format!("Agent '{}' already exists", agent.name));
}
self.agents.insert(agent.name.clone(), agent);
Ok(())
}
/// Update an existing agent
pub fn update(&self, agent: Agent) -> Result<(), String> {
if !self.agents.contains_key(&agent.name) {
return Err(format!("Agent '{}' not found", agent.name));
}
self.agents.insert(agent.name.clone(), agent);
Ok(())
}
/// Get an agent by name
pub fn get(&self, name: &str) -> Option<Agent> {
self.agents.get(name).map(|entry| entry.clone())
}
/// Remove an agent
pub fn remove(&self, name: &str) -> Option<Agent> {
self.agents.remove(name).map(|(_, agent)| agent)
}
/// List all active agents
pub fn list_active(&self) -> Vec<Agent> {
self.agents
.iter()
.filter(|entry| entry.is_active)
.map(|entry| entry.clone())
.collect()
}
/// List all agents
pub fn list_all(&self) -> Vec<Agent> {
self.agents.iter().map(|entry| entry.clone()).collect()
}
/// Find agents by capability
pub fn find_by_capability(&self, capability: &str, k: usize) -> Vec<Agent> {
let mut agents: Vec<Agent> = self
.agents
.iter()
.filter(|entry| entry.is_active && entry.has_capability(capability))
.map(|entry| entry.clone())
.collect();
// Sort by quality score (descending)
agents.sort_by(|a, b| {
b.performance
.quality_score
.partial_cmp(&a.performance.quality_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
agents.into_iter().take(k).collect()
}
/// Find agents by type
pub fn find_by_type(&self, agent_type: &AgentType) -> Vec<Agent> {
self.agents
.iter()
.filter(|entry| entry.is_active && &entry.agent_type == agent_type)
.map(|entry| entry.clone())
.collect()
}
/// Get agent count
pub fn count(&self) -> usize {
self.agents.len()
}
/// Get active agent count
pub fn count_active(&self) -> usize {
self.agents.iter().filter(|entry| entry.is_active).count()
}
/// Clear all agents
pub fn clear(&self) {
self.agents.clear();
}
}
impl Default for AgentRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that two `f32` values agree within a small absolute tolerance.
    ///
    /// Computed floats must not be compared with `assert_eq!`: for example
    /// `0.01 + 1000.0 * 0.0001` evaluated in `f32` is not bit-identical to
    /// the literal `0.11`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_agent_type_parsing() {
        assert_eq!(AgentType::from_str("llm"), AgentType::LLM);
        assert_eq!(AgentType::from_str("LLM"), AgentType::LLM);
        assert_eq!(AgentType::from_str("embedding"), AgentType::Embedding);
        assert_eq!(
            AgentType::from_str("custom"),
            AgentType::Custom("custom".to_string())
        );
    }

    #[test]
    fn test_agent_creation() {
        let agent = Agent::new(
            "gpt-4".to_string(),
            AgentType::LLM,
            vec!["code_generation".to_string(), "translation".to_string()],
        );
        assert_eq!(agent.name, "gpt-4");
        assert_eq!(agent.agent_type, AgentType::LLM);
        assert_eq!(agent.capabilities.len(), 2);
        assert!(agent.is_active);
    }

    #[test]
    fn test_agent_has_capability() {
        let agent = Agent::new(
            "test".to_string(),
            AgentType::LLM,
            vec!["code_generation".to_string()],
        );
        assert!(agent.has_capability("code_generation"));
        assert!(agent.has_capability("CODE_GENERATION"));
        assert!(!agent.has_capability("translation"));
    }

    #[test]
    fn test_agent_cost_calculation() {
        let mut agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);
        agent.cost_model.per_request = 0.01;
        agent.cost_model.per_token = Some(0.0001);
        assert_close(agent.calculate_cost(None), 0.01);
        assert_close(agent.calculate_cost(Some(1000)), 0.11); // 0.01 + 1000 * 0.0001
    }

    #[test]
    fn test_agent_update_metrics() {
        let mut agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);

        // Initial state
        assert_eq!(agent.performance.total_requests, 0);

        // Add first observation: with no prior requests it replaces the defaults
        agent.update_metrics(100.0, true, Some(0.9));
        assert_eq!(agent.performance.total_requests, 1);
        assert_close(agent.performance.avg_latency_ms, 100.0);
        assert_close(agent.performance.success_rate, 1.0);
        assert_close(agent.performance.quality_score, 0.9);

        // Add second observation: metrics become the mean of both
        agent.update_metrics(200.0, true, Some(0.8));
        assert_eq!(agent.performance.total_requests, 2);
        assert_close(agent.performance.avg_latency_ms, 150.0);
        assert_close(agent.performance.success_rate, 1.0);
        assert!((agent.performance.quality_score - 0.85).abs() < 0.01);
    }

    #[test]
    fn test_registry_register() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);
        assert!(registry.register(agent.clone()).is_ok());
        assert_eq!(registry.count(), 1);

        // Duplicate registration should fail
        assert!(registry.register(agent).is_err());
    }

    #[test]
    fn test_registry_get() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);
        registry.register(agent.clone()).unwrap();

        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.name, "test");
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_remove() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);
        registry.register(agent).unwrap();
        assert_eq!(registry.count(), 1);

        let removed = registry.remove("test").unwrap();
        assert_eq!(removed.name, "test");
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_registry_list_active() {
        let registry = AgentRegistry::new();
        let mut agent1 = Agent::new("active".to_string(), AgentType::LLM, vec![]);
        agent1.is_active = true;
        let mut agent2 = Agent::new("inactive".to_string(), AgentType::LLM, vec![]);
        agent2.is_active = false;
        registry.register(agent1).unwrap();
        registry.register(agent2).unwrap();

        let active = registry.list_active();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].name, "active");
    }

    #[test]
    fn test_registry_find_by_capability() {
        let registry = AgentRegistry::new();
        let agent1 = Agent::new(
            "agent1".to_string(),
            AgentType::LLM,
            vec!["coding".to_string()],
        );
        let agent2 = Agent::new(
            "agent2".to_string(),
            AgentType::LLM,
            vec!["translation".to_string()],
        );
        let agent3 = Agent::new(
            "agent3".to_string(),
            AgentType::LLM,
            vec!["coding".to_string(), "translation".to_string()],
        );
        registry.register(agent1).unwrap();
        registry.register(agent2).unwrap();
        registry.register(agent3).unwrap();

        let coders = registry.find_by_capability("coding", 10);
        assert_eq!(coders.len(), 2);
        let translators = registry.find_by_capability("translation", 10);
        assert_eq!(translators.len(), 2);
    }

    #[test]
    fn test_registry_find_by_type() {
        let registry = AgentRegistry::new();
        registry
            .register(Agent::new("llm1".to_string(), AgentType::LLM, vec![]))
            .unwrap();
        registry
            .register(Agent::new("llm2".to_string(), AgentType::LLM, vec![]))
            .unwrap();
        registry
            .register(Agent::new(
                "embed1".to_string(),
                AgentType::Embedding,
                vec![],
            ))
            .unwrap();

        let llms = registry.find_by_type(&AgentType::LLM);
        assert_eq!(llms.len(), 2);
        let embeddings = registry.find_by_type(&AgentType::Embedding);
        assert_eq!(embeddings.len(), 1);
    }

    #[test]
    fn test_registry_clear() {
        let registry = AgentRegistry::new();
        registry
            .register(Agent::new("test".to_string(), AgentType::LLM, vec![]))
            .unwrap();
        assert_eq!(registry.count(), 1);

        registry.clear();
        assert_eq!(registry.count(), 0);
    }
}

View File

@@ -0,0 +1,253 @@
// FastGRNN - Fast Gated Recurrent Neural Network
//
// Lightweight RNN for real-time routing decisions with minimal compute overhead.
// Based on "FastGRNN: A Fast, Accurate, Stable and Tiny Kilobyte Sized Gated Recurrent Neural Network"
use std::f32;
/// FastGRNN cell for sequence processing with gating mechanisms.
///
/// Every weight matrix is stored as a flat, row-major `Vec<f32>` with one row
/// per hidden unit: the input matrices hold `hidden_dim * input_dim` values
/// and the recurrent matrices hold `hidden_dim * hidden_dim` values.
#[derive(Clone)]
pub struct FastGRNN {
    /// Input dimension
    input_dim: usize,
    /// Hidden state dimension
    hidden_dim: usize,
    /// Gate weights for input
    w_gate: Vec<f32>,
    /// Gate weights for hidden state
    u_gate: Vec<f32>,
    /// Update weights for input
    w_update: Vec<f32>,
    /// Update weights for hidden state
    u_update: Vec<f32>,
    /// Bias for the gate, one value per hidden unit
    bias_gate: Vec<f32>,
    /// Bias for the update (candidate), one value per hidden unit
    bias_update: Vec<f32>,
    /// Zeta parameter for gate scaling
    zeta: f32,
    /// Nu parameter for update scaling
    nu: f32,
}

impl FastGRNN {
    /// Create a new FastGRNN cell with specified dimensions.
    ///
    /// All weights are set to the same small constant (a Xavier-style scale
    /// of `sqrt(2 / (input_dim + hidden_dim))` times 0.1) and biases to zero.
    /// The initialisation is deterministic, not random.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
        let weight = 0.1 * scale;
        Self {
            input_dim,
            hidden_dim,
            w_gate: vec![weight; input_dim * hidden_dim],
            u_gate: vec![weight; hidden_dim * hidden_dim],
            w_update: vec![weight; input_dim * hidden_dim],
            u_update: vec![weight; hidden_dim * hidden_dim],
            bias_gate: vec![0.0; hidden_dim],
            bias_update: vec![0.0; hidden_dim],
            zeta: 1.0,
            // nu must start at 0. `step` blends with the factor
            // clamp(zeta * g + nu, 0, 1) where g = sigmoid(..) lies in (0, 1);
            // with nu = 1 that factor is always 1, so the cell would return
            // the previous hidden state unchanged and never use its input.
            nu: 0.0,
        }
    }

    /// Create FastGRNN from pre-trained weights.
    ///
    /// The vectors are stored as given and their lengths are not validated
    /// here; `step` panics if a matrix or bias is shorter than the dimensions
    /// require.
    pub fn from_weights(
        input_dim: usize,
        hidden_dim: usize,
        w_gate: Vec<f32>,
        u_gate: Vec<f32>,
        w_update: Vec<f32>,
        u_update: Vec<f32>,
        bias_gate: Vec<f32>,
        bias_update: Vec<f32>,
        zeta: f32,
        nu: f32,
    ) -> Self {
        Self {
            input_dim,
            hidden_dim,
            w_gate,
            u_gate,
            w_update,
            u_update,
            bias_gate,
            bias_update,
            zeta,
            nu,
        }
    }

    /// Perform one step of FastGRNN computation
    ///
    /// # Arguments
    /// * `input` - Input vector of size input_dim
    /// * `hidden` - Previous hidden state of size hidden_dim
    ///
    /// # Returns
    /// New hidden state of size hidden_dim
    ///
    /// # Panics
    /// Panics if `input` or `hidden` does not have the expected length.
    pub fn step(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.input_dim, "Input dimension mismatch");
        assert_eq!(hidden.len(), self.hidden_dim, "Hidden dimension mismatch");

        // Gate: g = sigmoid(W_g * x + U_g * h + b_g)
        let gate = self.pre_activation(&self.w_gate, &self.u_gate, &self.bias_gate, input, hidden);
        // Candidate: c = tanh(W_u * x + U_u * h + b_u)
        let update = self.pre_activation(
            &self.w_update,
            &self.u_update,
            &self.bias_update,
            input,
            hidden,
        );

        // New hidden: h' = f ⊙ h + (1 - f) ⊙ c, with f = clip(zeta * g + nu, 0, 1)
        (0..self.hidden_dim)
            .map(|i| {
                let keep = (self.zeta * self.sigmoid(gate[i]) + self.nu)
                    .min(1.0)
                    .max(0.0);
                keep * hidden[i] + (1.0 - keep) * self.tanh(update[i])
            })
            .collect()
    }

    /// Process a single input and return hidden state (for single-step inference).
    ///
    /// The previous hidden state is taken to be all zeros.
    pub fn forward_single(&self, input: &[f32]) -> Vec<f32> {
        let hidden = vec![0.0; self.hidden_dim];
        self.step(input, &hidden)
    }

    /// Process a sequence of inputs, starting from a zero hidden state.
    ///
    /// Returns the hidden state after each input, in order.
    pub fn forward_sequence(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let mut hidden = vec![0.0; self.hidden_dim];
        let mut outputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            hidden = self.step(input, &hidden);
            outputs.push(hidden.clone());
        }
        outputs
    }

    /// Compute `W * x + U * h + b` for one gate, before the non-linearity.
    fn pre_activation(
        &self,
        w: &[f32],
        u: &[f32],
        bias: &[f32],
        input: &[f32],
        hidden: &[f32],
    ) -> Vec<f32> {
        let mut acc = vec![0.0; self.hidden_dim];
        self.matmul_add(w, input, &mut acc);
        self.matmul_add(u, hidden, &mut acc);
        for (i, value) in acc.iter_mut().enumerate() {
            *value += bias[i];
        }
        acc
    }

    /// Matrix-vector multiplication with accumulation: result += W * input
    ///
    /// `weights` is row-major with `result.len()` rows of `input.len()` columns.
    fn matmul_add(&self, weights: &[f32], input: &[f32], result: &mut [f32]) {
        let cols = input.len();
        if cols == 0 {
            // Nothing to accumulate (and `chunks_exact(0)` would panic).
            return;
        }
        // One up-front bounds check instead of one per element; panics if the
        // matrix is too short for the requested dimensions.
        let weights = &weights[..result.len() * cols];
        for (acc, row) in result.iter_mut().zip(weights.chunks_exact(cols)) {
            for (w, x) in row.iter().zip(input) {
                *acc += w * x;
            }
        }
    }

    /// Sigmoid activation function
    fn sigmoid(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Hyperbolic tangent activation function
    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }

    /// Get input dimension
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Get hidden dimension
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-dimensional sample input shared by several tests.
    fn sample_input() -> Vec<f32> {
        vec![1.0, 0.5, -0.5, 0.0]
    }

    /// The constructor records the requested dimensions.
    #[test]
    fn test_fastgrnn_creation() {
        let cell = FastGRNN::new(10, 5);
        assert_eq!(cell.input_dim(), 10);
        assert_eq!(cell.hidden_dim(), 5);
    }

    /// A single step yields a hidden state of the right size and magnitude.
    #[test]
    fn test_fastgrnn_step() {
        let cell = FastGRNN::new(4, 3);
        let previous = [0.1, 0.2, 0.3];
        let next = cell.step(&sample_input(), &previous);
        assert_eq!(next.len(), 3);
        // Check that output is bounded (due to tanh and sigmoid)
        next.iter()
            .for_each(|h| assert!(h.abs() <= 2.0, "Hidden state should be bounded"));
    }

    /// Single-step inference returns one value per hidden unit.
    #[test]
    fn test_fastgrnn_forward_single() {
        let cell = FastGRNN::new(4, 3);
        assert_eq!(cell.forward_single(&sample_input()).len(), 3);
    }

    /// Sequence processing returns one hidden state per input.
    #[test]
    fn test_fastgrnn_sequence() {
        let cell = FastGRNN::new(4, 3);
        let sequence = vec![
            sample_input(),
            vec![0.5, 1.0, 0.0, -0.5],
            vec![-0.5, 0.0, 1.0, 0.5],
        ];
        let states = cell.forward_sequence(&sequence);
        assert_eq!(states.len(), 3);
        assert_eq!(states[0].len(), 3);
    }

    /// Sigmoid is 0.5 at the origin and saturates towards 0 and 1.
    #[test]
    fn test_sigmoid() {
        let cell = FastGRNN::new(1, 1);
        let at_zero = cell.sigmoid(0.0);
        assert!((at_zero - 0.5).abs() < 1e-6);
        assert!(cell.sigmoid(10.0) > 0.99);
        assert!(cell.sigmoid(-10.0) < 0.01);
    }

    /// Tanh is 0 at the origin and saturates towards -1 and 1.
    #[test]
    fn test_tanh() {
        let cell = FastGRNN::new(1, 1);
        let at_zero = cell.tanh(0.0);
        assert!(at_zero.abs() < 1e-6);
        assert!(cell.tanh(10.0) > 0.99);
        assert!(cell.tanh(-10.0) < -0.99);
    }

    /// An input of the wrong length is rejected.
    #[test]
    #[should_panic(expected = "Input dimension mismatch")]
    fn test_wrong_input_dimension() {
        let cell = FastGRNN::new(4, 3);
        let short_input = vec![1.0, 0.5]; // Wrong size
        cell.step(&short_input, &[0.1, 0.2, 0.3]);
    }

    /// A hidden state of the wrong length is rejected.
    #[test]
    #[should_panic(expected = "Hidden dimension mismatch")]
    fn test_wrong_hidden_dimension() {
        let cell = FastGRNN::new(4, 3);
        let short_hidden = vec![0.1, 0.2]; // Wrong size
        cell.step(&sample_input(), &short_hidden);
    }
}

View File

@@ -0,0 +1,24 @@
// Tiny Dancer Routing Module
//
// Neural-powered dynamic agent routing with FastGRNN for adaptive decision-making.
pub mod agents;
pub mod fastgrnn;
pub mod operators;
pub mod router;
pub use agents::{Agent, AgentRegistry, AgentType};
pub use fastgrnn::FastGRNN;
pub use router::{OptimizationTarget, Router, RoutingConstraints, RoutingDecision};
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the re-exported types are reachable from the module root
    /// and can be constructed.
    #[test]
    fn test_module_exports() {
        let _registry: AgentRegistry = AgentRegistry::new();
        let _router: Router = Router::new();
    }
}

View File

@@ -0,0 +1,610 @@
// PostgreSQL Operators for Tiny Dancer Routing
//
// SQL functions for agent registration, routing, and management.
use pgrx::prelude::*;
use pgrx::JsonB;
use serde_json::json;
use std::sync::OnceLock;
use super::agents::{Agent, AgentRegistry, AgentType};
use super::router::{OptimizationTarget, Router, RoutingConstraints};
// Global agent registry and router.
//
// Both are process-wide statics created lazily on first use. The registry is
// held in an `Arc` so the global router can share the very same instance;
// previously the router was built over a fresh, empty registry and therefore
// never saw any registered agent.
// NOTE(review): statics are per process, so each PostgreSQL backend process
// presumably gets its own registry — confirm this is the intended scope.
static AGENT_REGISTRY: OnceLock<std::sync::Arc<AgentRegistry>> = OnceLock::new();
static ROUTER: OnceLock<Router> = OnceLock::new();

/// Shared handle to the global registry, created on first use.
fn shared_registry() -> &'static std::sync::Arc<AgentRegistry> {
    AGENT_REGISTRY.get_or_init(|| std::sync::Arc::new(AgentRegistry::new()))
}

/// Initialize the global router (and the registry it routes over).
///
/// The router is created once and is backed by the global agent registry.
fn init_router() -> &'static Router {
    ROUTER.get_or_init(|| Router::with_registry(std::sync::Arc::clone(shared_registry())))
}

/// Get the global agent registry, creating it on first use.
fn get_registry() -> &'static AgentRegistry {
    shared_registry().as_ref()
}
/// Register a new AI agent in the global registry.
///
/// The agent starts from the defaults of `Agent::new`; only the per-request
/// cost, average latency and quality score are overridden.
///
/// # Arguments
/// * `name` - Unique agent identifier
/// * `agent_type` - Type of agent (llm, embedding, specialized, etc.)
/// * `capabilities` - Array of capability strings
/// * `cost_per_request` - Cost per request in dollars
/// * `avg_latency_ms` - Average latency in milliseconds
/// * `quality_score` - Quality score (0-1)
///
/// # Returns
/// `true` on success; an error if the name is already registered.
///
/// # Example
/// ```sql
/// SELECT ruvector_register_agent(
///     'gpt-4',
///     'llm',
///     ARRAY['code_generation', 'translation'],
///     0.03,
///     500.0,
///     0.95
/// );
/// ```
#[pg_extern]
fn ruvector_register_agent(
    name: String,
    agent_type: String,
    capabilities: Vec<String>,
    cost_per_request: f32,
    avg_latency_ms: f32,
    quality_score: f32,
) -> Result<bool, String> {
    let registry = get_registry();

    let kind = AgentType::from_str(&agent_type);
    let mut candidate = Agent::new(name, kind, capabilities);
    candidate.performance.quality_score = quality_score;
    candidate.performance.avg_latency_ms = avg_latency_ms;
    candidate.cost_model.per_request = cost_per_request;

    registry.register(candidate).map(|()| true)
}
/// Register an agent with full configuration
///
/// # Arguments
/// * `config` - JSONB configuration with all agent properties
///
/// # Example
/// ```sql
/// SELECT ruvector_register_agent_full('{
/// "name": "gpt-4",
/// "agent_type": "llm",
/// "capabilities": ["code_generation", "translation"],
/// "cost_model": {
/// "per_request": 0.03,
/// "per_token": 0.00006
/// },
/// "performance": {
/// "avg_latency_ms": 500.0,
/// "quality_score": 0.95,
/// "success_rate": 0.99
/// }
/// }'::jsonb);
/// ```
#[pg_extern]
fn ruvector_register_agent_full(config: JsonB) -> Result<bool, String> {
let registry = get_registry();
let agent: Agent = serde_json::from_value(config.0)
.map_err(|e| format!("Invalid agent configuration: {}", e))?;
registry.register(agent)?;
Ok(true)
}
/// Fold one observed request into an agent's performance metrics.
///
/// The agent is fetched from the global registry as a copy, updated, and
/// written back.
///
/// # Arguments
/// * `name` - Agent name
/// * `latency_ms` - Observed latency
/// * `success` - Whether the request succeeded
/// * `quality` - Optional quality score for this request
///
/// # Returns
/// `true` on success; an error if the agent does not exist.
///
/// # Example
/// ```sql
/// SELECT ruvector_update_agent_metrics('gpt-4', 450.0, true, 0.92);
/// ```
#[pg_extern]
fn ruvector_update_agent_metrics(
    name: String,
    latency_ms: f32,
    success: bool,
    quality: Option<f32>,
) -> Result<bool, String> {
    let registry = get_registry();
    match registry.get(&name) {
        Some(mut snapshot) => {
            snapshot.update_metrics(latency_ms, success, quality);
            registry.update(snapshot).map(|()| true)
        }
        None => Err(format!("Agent '{}' not found", name)),
    }
}
/// Delete an agent from the global registry.
///
/// # Returns
/// `true` when the agent was removed; an error if no agent has that name.
///
/// # Example
/// ```sql
/// SELECT ruvector_remove_agent('gpt-4');
/// ```
#[pg_extern]
fn ruvector_remove_agent(name: String) -> Result<bool, String> {
    match get_registry().remove(&name) {
        Some(_removed) => Ok(true),
        None => Err(format!("Agent '{}' not found", name)),
    }
}
/// Enable or disable an agent
///
/// # Errors
/// Returns an error when no agent with that name is registered, or when the
/// registry refuses the update.
///
/// # Example
/// ```sql
/// SELECT ruvector_set_agent_active('gpt-4', false);
/// ```
#[pg_extern]
fn ruvector_set_agent_active(name: String, is_active: bool) -> Result<bool, String> {
    let registry = get_registry();

    let mut agent = match registry.get(&name) {
        Some(found) => found,
        None => return Err(format!("Agent '{}' not found", name)),
    };

    // Flip the flag on the fetched copy, then write the copy back.
    agent.is_active = is_active;
    registry.update(agent)?;
    Ok(true)
}
/// Route a request to the best agent
///
/// # Arguments
/// * `request_embedding` - Request embedding vector (384-dim)
/// * `optimize_for` - Optimization target: 'cost', 'latency', 'quality', 'balanced'
/// * `constraints` - Optional JSONB constraints object
///
/// # Example
/// ```sql
/// SELECT ruvector_route(
/// embedding,
/// 'balanced',
/// '{"max_cost": 0.1, "min_quality": 0.8}'::jsonb
/// )
/// FROM request_embeddings
/// WHERE id = 123;
/// ```
#[pg_extern]
fn ruvector_route(
request_embedding: Vec<f32>,
optimize_for: default!(String, "'balanced'"),
constraints: default!(Option<JsonB>, "NULL"),
) -> Result<JsonB, String> {
init_router(); // Ensure router is initialized
let target = OptimizationTarget::from_str(&optimize_for);
let routing_constraints = if let Some(JsonB(json_val)) = constraints {
serde_json::from_value(json_val).map_err(|e| format!("Invalid constraints: {}", e))?
} else {
RoutingConstraints::default()
};
// Get router with proper registry
let registry = get_registry();
let router = Router::with_registry(std::sync::Arc::new(AgentRegistry::new()));
// Copy agents from global registry to router's registry
for agent in registry.list_all() {
router.registry().register(agent).ok();
}
let decision = router.route(&request_embedding, &routing_constraints, target)?;
let result = json!({
"agent_name": decision.agent_name,
"confidence": decision.confidence,
"estimated_cost": decision.estimated_cost,
"estimated_latency_ms": decision.estimated_latency_ms,
"expected_quality": decision.expected_quality,
"similarity_score": decision.similarity_score,
"reasoning": decision.reasoning,
"alternatives": decision.alternatives,
});
Ok(JsonB(result))
}
/// List all registered agents
///
/// Returns one row per agent with its name, type, capabilities, per-request
/// cost, average latency, quality score, success rate, total request count
/// and active flag.
///
/// # Example
/// ```sql
/// SELECT * FROM ruvector_list_agents();
/// ```
#[pg_extern]
fn ruvector_list_agents() -> TableIterator<
    'static,
    (
        name!(name, String),
        name!(agent_type, String),
        name!(capabilities, Vec<String>),
        name!(cost_per_request, f32),
        name!(avg_latency_ms, f32),
        name!(quality_score, f32),
        name!(success_rate, f32),
        name!(total_requests, i64),
        name!(is_active, bool),
    ),
> {
    let registry = get_registry();
    let agents = registry.list_all();

    TableIterator::new(
        agents
            .into_iter()
            .map(|agent| {
                // The counter is unsigned but the SQL column is a signed
                // bigint. Saturate at `i64::MAX` rather than letting an `as`
                // cast wrap a very large count into a negative number.
                let total_requests =
                    agent.performance.total_requests.min(i64::MAX as u64) as i64;

                (
                    agent.name,
                    agent.agent_type.as_str().to_string(),
                    agent.capabilities,
                    agent.cost_model.per_request,
                    agent.performance.avg_latency_ms,
                    agent.performance.quality_score,
                    agent.performance.success_rate,
                    total_requests,
                    agent.is_active,
                )
            })
            .collect::<Vec<_>>(),
    )
}
/// Get detailed information about a specific agent
///
/// # Example
/// ```sql
/// SELECT ruvector_get_agent('gpt-4');
/// ```
#[pg_extern]
fn ruvector_get_agent(name: String) -> Result<JsonB, String> {
let registry = get_registry();
let agent = registry
.get(&name)
.ok_or_else(|| format!("Agent '{}' not found", name))?;
let result = serde_json::to_value(&agent).map_err(|e| format!("Serialization error: {}", e))?;
Ok(JsonB(result))
}
/// Find agents by capability
///
/// # Example
/// ```sql
/// SELECT * FROM ruvector_find_agents_by_capability('code_generation', 5);
/// ```
#[pg_extern]
fn ruvector_find_agents_by_capability(
capability: String,
limit: default!(i32, 10),
) -> TableIterator<
'static,
(
name!(name, String),
name!(quality_score, f32),
name!(avg_latency_ms, f32),
name!(cost_per_request, f32),
),
> {
let registry = get_registry();
let agents = registry.find_by_capability(&capability, limit as usize);
TableIterator::new(
agents
.into_iter()
.map(|agent| {
(
agent.name,
agent.performance.quality_score,
agent.performance.avg_latency_ms,
agent.cost_model.per_request,
)
})
.collect::<Vec<_>>(),
)
}
/// Summarize the registry as a JSONB object
///
/// The object has the keys `total_agents`, `active_agents`, `total_requests`
/// (summed over all agents) and `average_quality` (mean quality score over
/// all agents, or 0.0 when the registry is empty).
///
/// # Example
/// ```sql
/// SELECT ruvector_routing_stats();
/// ```
#[pg_extern]
fn ruvector_routing_stats() -> JsonB {
    let registry = get_registry();
    let total_agents = registry.count();
    let active_agents = registry.count_active();
    let agents = registry.list_all();

    let total_requests = agents
        .iter()
        .map(|a| a.performance.total_requests)
        .sum::<u64>();

    // Guard the division so an empty registry reports 0.0 instead of NaN.
    let avg_quality: f32 = match agents.len() {
        0 => 0.0,
        agent_count => {
            let quality_sum: f32 = agents.iter().map(|a| a.performance.quality_score).sum();
            quality_sum / agent_count as f32
        }
    };

    JsonB(json!({
        "total_agents": total_agents,
        "active_agents": active_agents,
        "total_requests": total_requests,
        "average_quality": avg_quality,
    }))
}
/// Remove every agent from the registry (for testing)
///
/// Always returns `true`.
#[pg_extern]
fn ruvector_clear_agents() -> bool {
    get_registry().clear();
    true
}
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
    use super::*;

    /// Registers an agent through the SQL-facing entry point, converting the
    /// borrowed fixture data into the owned values that function expects.
    fn register(
        name: &str,
        agent_type: &str,
        capabilities: &[&str],
        cost: f32,
        latency_ms: f32,
        quality: f32,
    ) -> Result<bool, String> {
        ruvector_register_agent(
            name.to_string(),
            agent_type.to_string(),
            capabilities.iter().map(|c| c.to_string()).collect(),
            cost,
            latency_ms,
            quality,
        )
    }

    /// Registers the standard single-agent fixture used by most tests.
    fn register_test_agent() -> Result<bool, String> {
        register("test-agent", "llm", &["coding"], 0.05, 200.0, 0.85)
    }

    #[pg_test]
    fn test_register_agent() {
        ruvector_clear_agents();

        let result = register_test_agent();
        assert!(result.is_ok());
        assert!(result.unwrap());

        // Verify agent was registered
        let agent = ruvector_get_agent("test-agent".to_string());
        assert!(agent.is_ok());
    }

    #[pg_test]
    fn test_register_duplicate_agent() {
        ruvector_clear_agents();
        register_test_agent().unwrap();

        // Try to register again
        let result = register_test_agent();
        assert!(result.is_err());
    }

    #[pg_test]
    fn test_update_agent_metrics() {
        ruvector_clear_agents();
        register_test_agent().unwrap();

        let result =
            ruvector_update_agent_metrics("test-agent".to_string(), 150.0, true, Some(0.9));
        assert!(result.is_ok());
    }

    #[pg_test]
    fn test_remove_agent() {
        ruvector_clear_agents();
        register_test_agent().unwrap();

        let result = ruvector_remove_agent("test-agent".to_string());
        assert!(result.is_ok());

        // Verify agent was removed
        let agent = ruvector_get_agent("test-agent".to_string());
        assert!(agent.is_err());
    }

    #[pg_test]
    fn test_set_agent_active() {
        ruvector_clear_agents();
        register_test_agent().unwrap();

        let result = ruvector_set_agent_active("test-agent".to_string(), false);
        assert!(result.is_ok());

        // Round-trip through the JSON view to observe the stored flag.
        let agent_json = ruvector_get_agent("test-agent".to_string()).unwrap();
        let agent: Agent = serde_json::from_value(agent_json.0).unwrap();
        assert!(!agent.is_active);
    }

    #[pg_test]
    fn test_list_agents() {
        ruvector_clear_agents();
        register("agent1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();
        register("agent2", "embedding", &["similarity"], 0.01, 50.0, 0.90).unwrap();

        let agents: Vec<_> = ruvector_list_agents().collect();
        assert_eq!(agents.len(), 2);
    }

    #[pg_test]
    fn test_find_agents_by_capability() {
        ruvector_clear_agents();
        register("coder1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();
        register("coder2", "llm", &["coding", "translation"], 0.08, 250.0, 0.90).unwrap();
        register("translator", "llm", &["translation"], 0.03, 150.0, 0.80).unwrap();

        let coders: Vec<_> = ruvector_find_agents_by_capability("coding".to_string(), 10).collect();
        assert_eq!(coders.len(), 2);
    }

    #[pg_test]
    fn test_routing_stats() {
        ruvector_clear_agents();
        register("agent1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();

        let stats = ruvector_routing_stats();
        let stats_obj: serde_json::Value = stats.0;
        assert_eq!(stats_obj["total_agents"], 1);
        assert_eq!(stats_obj["active_agents"], 1);
    }

    #[pg_test]
    fn test_route_basic() {
        ruvector_clear_agents();
        // Same latency, different cost and quality: only cost should decide.
        register("cheap-agent", "llm", &["coding"], 0.01, 200.0, 0.70).unwrap();
        register("expensive-agent", "llm", &["coding"], 0.10, 200.0, 0.95).unwrap();

        let embedding = vec![0.1; 384];

        // Route optimizing for cost
        let result = ruvector_route(embedding, "cost".to_string(), None);
        assert!(result.is_ok());

        let decision = result.unwrap().0;
        assert_eq!(decision["agent_name"], "cheap-agent");
    }
}

View File

@@ -0,0 +1,628 @@
// Neural-Powered Agent Router
//
// Dynamic routing with FastGRNN and multi-objective optimization.
use super::agents::{Agent, AgentRegistry};
use super::fastgrnn::FastGRNN;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Optimization target for routing decisions
///
/// Selects which agent attribute the router's scoring emphasizes. In every
/// mode the request/agent similarity also contributes to the score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationTarget {
/// Minimize cost (favors a low per-request cost)
Cost,
/// Minimize latency (favors a low average latency)
Latency,
/// Maximize quality (favors a high quality score)
Quality,
/// Balanced optimization (cost, latency, quality and similarity weighted equally)
Balanced,
}
impl OptimizationTarget {
/// Parse from string
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"cost" => OptimizationTarget::Cost,
"latency" => OptimizationTarget::Latency,
"quality" => OptimizationTarget::Quality,
"balanced" => OptimizationTarget::Balanced,
_ => OptimizationTarget::Balanced,
}
}
/// Convert to string
pub fn as_str(&self) -> &str {
match self {
OptimizationTarget::Cost => "cost",
OptimizationTarget::Latency => "latency",
OptimizationTarget::Quality => "quality",
OptimizationTarget::Balanced => "balanced",
}
}
}
/// Constraints for routing decisions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConstraints {
/// Maximum acceptable cost
pub max_cost: Option<f32>,
/// Maximum acceptable latency in ms
pub max_latency_ms: Option<f32>,
/// Minimum required quality score (0-1)
pub min_quality: Option<f32>,
/// Required capabilities
pub required_capabilities: Vec<String>,
/// Excluded agent names
pub excluded_agents: Vec<String>,
}
impl Default for RoutingConstraints {
fn default() -> Self {
Self {
max_cost: None,
max_latency_ms: None,
min_quality: None,
required_capabilities: Vec::new(),
excluded_agents: Vec::new(),
}
}
}
impl RoutingConstraints {
    /// Start from the unconstrained default.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the cost ceiling.
    pub fn with_max_cost(self, cost: f32) -> Self {
        Self {
            max_cost: Some(cost),
            ..self
        }
    }

    /// Set the latency ceiling in milliseconds.
    pub fn with_max_latency(self, latency_ms: f32) -> Self {
        Self {
            max_latency_ms: Some(latency_ms),
            ..self
        }
    }

    /// Set the quality floor.
    pub fn with_min_quality(self, quality: f32) -> Self {
        Self {
            min_quality: Some(quality),
            ..self
        }
    }

    /// Append one required capability; earlier ones are kept.
    pub fn with_capability(mut self, capability: String) -> Self {
        self.required_capabilities.push(capability);
        self
    }

    /// Append one agent name to the exclusion list; earlier ones are kept.
    pub fn with_excluded_agent(mut self, agent_name: String) -> Self {
        self.excluded_agents.push(agent_name);
        self
    }
}
/// Routing decision result
///
/// Produced by `Router::route`: the chosen agent, the figures copied from it,
/// and a short list of runner-up candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
/// Selected agent name
pub agent_name: String,
/// Confidence score (0-1); derived from the FastGRNN hidden state when a
/// model is attached, otherwise taken from the selected agent's score
pub confidence: f32,
/// Estimated cost (the selected agent's per-request cost)
pub estimated_cost: f32,
/// Estimated latency in ms (the selected agent's average latency)
pub estimated_latency_ms: f32,
/// Expected quality (the selected agent's quality score)
pub expected_quality: f32,
/// Similarity score to request (cosine similarity to the agent's embedding,
/// or 0.5 when the agent has no embedding)
pub similarity_score: f32,
/// Reasoning for the decision (human-readable sentence)
pub reasoning: String,
/// Alternative agents considered (at most the three next-best candidates)
pub alternatives: Vec<AlternativeAgent>,
}
/// Alternative agent option
///
/// A candidate that passed the constraints but ranked below the selected agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeAgent {
/// Agent name
pub name: String,
/// Score the router assigned to this agent for the requested target
pub score: f32,
/// Why it wasn't selected (human-readable comparison with the selected agent)
pub reason: String,
}
/// Neural-powered agent router
///
/// Scores the active agents of a registry against a request and picks the
/// best one. The FastGRNN model is optional and is only used to derive the
/// decision's confidence.
pub struct Router {
/// Agent registry (shared via `Arc`)
registry: Arc<AgentRegistry>,
/// FastGRNN model for neural routing; `None` until `init_grnn` or `set_grnn` is called
grnn: Option<FastGRNN>,
/// Embedding dimension; used as the model input size by `init_grnn`
/// (both constructors set it to 384)
embedding_dim: usize,
}
impl Router {
    /// Create a new router with its own empty registry.
    ///
    /// No FastGRNN model is attached; the embedding dimension is 384.
    pub fn new() -> Self {
        Self {
            registry: Arc::new(AgentRegistry::new()),
            grnn: None,
            embedding_dim: 384, // Default embedding size
        }
    }

    /// Create a router that uses the given (shared) registry.
    ///
    /// No FastGRNN model is attached; the embedding dimension is 384.
    pub fn with_registry(registry: Arc<AgentRegistry>) -> Self {
        Self {
            registry,
            grnn: None,
            embedding_dim: 384,
        }
    }

    /// Attach a newly constructed FastGRNN model whose input size is the
    /// router's embedding dimension and whose hidden size is `hidden_dim`.
    pub fn init_grnn(&mut self, hidden_dim: usize) {
        self.grnn = Some(FastGRNN::new(self.embedding_dim, hidden_dim));
    }

    /// Attach an already constructed FastGRNN model, replacing any previous one.
    pub fn set_grnn(&mut self, grnn: FastGRNN) {
        self.grnn = Some(grnn);
    }

    /// Route a request to the best agent.
    ///
    /// Active agents are filtered by capability, exclusion list and the
    /// numeric constraints, then scored for `target`. The highest score wins,
    /// and up to three runners-up are reported as alternatives.
    ///
    /// # Arguments
    /// * `request_embedding` - Embedding of the request
    /// * `constraints` - Filters an agent must pass to be considered
    /// * `target` - What the scoring should emphasize
    ///
    /// # Errors
    /// Returns an error when no active agent passes the capability/exclusion
    /// filters, or when none of the remaining agents meets the numeric
    /// constraints.
    pub fn route(
        &self,
        request_embedding: &[f32],
        constraints: &RoutingConstraints,
        target: OptimizationTarget,
    ) -> Result<RoutingDecision, String> {
        // Get candidate agents
        let candidates = self.get_candidates(constraints)?;
        if candidates.is_empty() {
            return Err("No agents match the constraints".to_string());
        }

        // Score the candidates that pass the numeric constraints. Filtering
        // first avoids scoring agents that are about to be discarded, and
        // consuming the vector avoids cloning every agent.
        let mut scored_candidates: Vec<(Agent, f32, f32)> = candidates
            .into_iter()
            .filter(|agent| self.meets_constraints(agent, constraints))
            .map(|agent| {
                let similarity = match &agent.embedding {
                    Some(agent_emb) => cosine_similarity(request_embedding, agent_emb),
                    None => 0.5, // Default similarity if no embedding
                };
                let score = self.score_agent(&agent, request_embedding, target, similarity);
                (agent, score, similarity)
            })
            .collect();

        if scored_candidates.is_empty() {
            return Err("No agents meet the specified constraints".to_string());
        }

        // Sort by score (descending). The sort is stable, so agents with
        // equal or incomparable scores keep their original relative order.
        scored_candidates
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Select best agent
        let (best_agent, best_score, similarity) = &scored_candidates[0];

        // Confidence is reported in the 0-1 range. `max` runs before `min`
        // so that a NaN input collapses to 0.0 rather than 1.0: `f32::max`
        // and `f32::min` return the other operand when one side is NaN.
        let confidence = if let Some(ref grnn) = self.grnn {
            let hidden = grnn.forward_single(request_embedding);
            // Use hidden state magnitude as confidence
            let magnitude: f32 = hidden.iter().map(|&h| h * h).sum::<f32>().sqrt();
            (magnitude / hidden.len() as f32).max(0.0).min(1.0)
        } else {
            // Without a model, fall back to the winning score. It can be
            // negative when the similarity is negative, so clamp it too.
            (*best_score).max(0.0).min(1.0)
        };

        // Build alternatives list
        let alternatives: Vec<AlternativeAgent> = scored_candidates
            .iter()
            .skip(1)
            .take(3)
            .map(|(agent, score, _)| AlternativeAgent {
                name: agent.name.clone(),
                score: *score,
                reason: self.compare_to_best(agent, best_agent, target),
            })
            .collect();

        // Generate reasoning
        let reasoning = self.generate_reasoning(best_agent, target, *similarity);

        Ok(RoutingDecision {
            agent_name: best_agent.name.clone(),
            confidence,
            estimated_cost: best_agent.cost_model.per_request,
            estimated_latency_ms: best_agent.performance.avg_latency_ms,
            expected_quality: best_agent.performance.quality_score,
            similarity_score: *similarity,
            reasoning,
            alternatives,
        })
    }

    /// Get candidate agents: the active agents that have every required
    /// capability and are not on the exclusion list.
    ///
    /// The numeric constraints are checked separately by `meets_constraints`.
    fn get_candidates(&self, constraints: &RoutingConstraints) -> Result<Vec<Agent>, String> {
        let mut agents = self.registry.list_active();

        // Filter by required capabilities
        if !constraints.required_capabilities.is_empty() {
            agents.retain(|agent| {
                constraints
                    .required_capabilities
                    .iter()
                    .all(|cap| agent.has_capability(cap))
            });
        }

        // Filter excluded agents
        if !constraints.excluded_agents.is_empty() {
            agents.retain(|agent| !constraints.excluded_agents.contains(&agent.name));
        }

        Ok(agents)
    }

    /// Check the numeric constraints (cost, latency, quality) for one agent.
    ///
    /// A constraint that is `None` is not checked. Limits are inclusive: an
    /// agent exactly at a limit passes.
    fn meets_constraints(&self, agent: &Agent, constraints: &RoutingConstraints) -> bool {
        // Check cost constraint
        if let Some(max_cost) = constraints.max_cost {
            if agent.cost_model.per_request > max_cost {
                return false;
            }
        }

        // Check latency constraint
        if let Some(max_latency) = constraints.max_latency_ms {
            if agent.performance.avg_latency_ms > max_latency {
                return false;
            }
        }

        // Check quality constraint
        if let Some(min_quality) = constraints.min_quality {
            if agent.performance.quality_score < min_quality {
                return false;
            }
        }

        true
    }

    /// Score an agent for a given target; higher is better.
    ///
    /// Cost and latency are mapped to a score with `1 / (1 + x)`, with
    /// latency taken in seconds. Single-target modes weight the target at 0.7
    /// and the similarity at 0.3; balanced mode weights all four at 0.25.
    fn score_agent(
        &self,
        agent: &Agent,
        _request_embedding: &[f32],
        target: OptimizationTarget,
        similarity: f32,
    ) -> f32 {
        match target {
            OptimizationTarget::Cost => {
                // Lower cost = higher score
                let cost_score = 1.0 / (1.0 + agent.cost_model.per_request);
                cost_score * 0.7 + similarity * 0.3
            }
            OptimizationTarget::Latency => {
                // Lower latency = higher score
                let latency_score = 1.0 / (1.0 + agent.performance.avg_latency_ms / 1000.0);
                latency_score * 0.7 + similarity * 0.3
            }
            OptimizationTarget::Quality => {
                // Higher quality = higher score
                agent.performance.quality_score * 0.7 + similarity * 0.3
            }
            OptimizationTarget::Balanced => {
                // Balanced scoring
                let cost_score = 1.0 / (1.0 + agent.cost_model.per_request);
                let latency_score = 1.0 / (1.0 + agent.performance.avg_latency_ms / 1000.0);
                let quality_score = agent.performance.quality_score;
                cost_score * 0.25 + latency_score * 0.25 + quality_score * 0.25 + similarity * 0.25
            }
        }
    }

    /// Explain how a runner-up compares with the selected agent on the
    /// target metric.
    ///
    /// Similarity also feeds the score, so a runner-up can beat the winner on
    /// the target metric itself. That case gets its own wording instead of a
    /// negative difference such as "$-0.0100 more expensive".
    fn compare_to_best(&self, agent: &Agent, best: &Agent, target: OptimizationTarget) -> String {
        match target {
            OptimizationTarget::Cost => {
                let diff = agent.cost_model.per_request - best.cost_model.per_request;
                if diff < 0.0 {
                    format!("${:.4} cheaper, but lower overall score", -diff)
                } else {
                    format!("${:.4} more expensive", diff)
                }
            }
            OptimizationTarget::Latency => {
                let diff = agent.performance.avg_latency_ms - best.performance.avg_latency_ms;
                if diff < 0.0 {
                    format!("{:.1}ms faster, but lower overall score", -diff)
                } else {
                    format!("{:.1}ms slower", diff)
                }
            }
            OptimizationTarget::Quality => {
                let diff = best.performance.quality_score - agent.performance.quality_score;
                if diff < 0.0 {
                    format!("{:.2} higher quality, but lower overall score", -diff)
                } else {
                    format!("{:.2} lower quality", diff)
                }
            }
            OptimizationTarget::Balanced => "Lower overall score".to_string(),
        }
    }

    /// Build the human-readable sentence explaining why `agent` was selected.
    fn generate_reasoning(
        &self,
        agent: &Agent,
        target: OptimizationTarget,
        similarity: f32,
    ) -> String {
        let target_reason = match target {
            OptimizationTarget::Cost => {
                format!("lowest cost (${:.4}/request)", agent.cost_model.per_request)
            }
            OptimizationTarget::Latency => format!(
                "fastest response ({:.1}ms avg)",
                agent.performance.avg_latency_ms
            ),
            OptimizationTarget::Quality => format!(
                "highest quality (score: {:.2})",
                agent.performance.quality_score
            ),
            OptimizationTarget::Balanced => "best overall balance".to_string(),
        };

        format!(
            "Selected {} for {} with {:.1}% similarity to request",
            agent.name,
            target_reason,
            similarity * 100.0
        )
    }

    /// Get registry reference
    pub fn registry(&self) -> &Arc<AgentRegistry> {
        &self.registry
    }
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
/// Cosine of the angle between two vectors, limited to the range -1 to 1.
///
/// Returns 0.0 when the vectors differ in length or when either one has zero
/// magnitude (which includes empty vectors).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio slightly outside the mathematical range.
    let cosine = dot / (mag_a * mag_b);
    cosine.max(-1.0).min(1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::agents::AgentType;

    /// Builds an LLM agent with the "test" capability, the given metrics and
    /// a constant 384-dimensional embedding.
    fn create_test_agent(name: &str, cost: f32, latency: f32, quality: f32) -> Agent {
        let mut agent = Agent::new(name.to_string(), AgentType::LLM, vec!["test".to_string()]);
        agent.cost_model.per_request = cost;
        agent.performance.avg_latency_ms = latency;
        agent.performance.quality_score = quality;
        agent.embedding = Some(vec![0.1; 384]);
        agent
    }

    /// Builds a router whose registry holds `agents`, registered in order.
    fn router_with(agents: Vec<Agent>) -> Router {
        let router = Router::new();
        for agent in agents {
            router.registry().register(agent).unwrap();
        }
        router
    }

    /// Routes the standard request embedding and returns the chosen agent's name.
    fn routed_name(
        router: &Router,
        constraints: &RoutingConstraints,
        target: OptimizationTarget,
    ) -> String {
        let request_emb = vec![0.1; 384];
        router
            .route(&request_emb, constraints, target)
            .unwrap()
            .agent_name
    }

    #[test]
    fn test_optimization_target_parsing() {
        let cases = [
            ("cost", OptimizationTarget::Cost),
            ("LATENCY", OptimizationTarget::Latency),
            ("quality", OptimizationTarget::Quality),
            ("balanced", OptimizationTarget::Balanced),
            ("unknown", OptimizationTarget::Balanced),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(OptimizationTarget::from_str(input), expected);
        }
    }

    #[test]
    fn test_routing_constraints_builder() {
        let constraints = RoutingConstraints::new()
            .with_max_cost(0.1)
            .with_max_latency(500.0)
            .with_min_quality(0.8)
            .with_capability("test".to_string());

        assert_eq!(constraints.max_cost, Some(0.1));
        assert_eq!(constraints.max_latency_ms, Some(500.0));
        assert_eq!(constraints.min_quality, Some(0.8));
        assert_eq!(constraints.required_capabilities.len(), 1);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical unit vectors.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        // Orthogonal vectors.
        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&c, &d).abs() < 1e-6);

        // Identical non-unit vectors.
        let e = vec![1.0, 1.0, 0.0];
        let f = vec![1.0, 1.0, 0.0];
        assert!((cosine_similarity(&e, &f) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_router_creation() {
        let router = Router::new();
        assert!(router.grnn.is_none());
        assert_eq!(router.registry().count(), 0);
    }

    #[test]
    fn test_router_init_grnn() {
        let mut router = Router::new();
        router.init_grnn(64);
        assert!(router.grnn.is_some());
    }

    #[test]
    fn test_route_cost_optimization() {
        // Register agents with different costs
        let router = router_with(vec![
            create_test_agent("cheap", 0.01, 100.0, 0.7),
            create_test_agent("expensive", 0.10, 100.0, 0.9),
        ]);

        let chosen = routed_name(&router, &RoutingConstraints::new(), OptimizationTarget::Cost);
        assert_eq!(chosen, "cheap");
    }

    #[test]
    fn test_route_latency_optimization() {
        let router = router_with(vec![
            create_test_agent("fast", 0.05, 50.0, 0.7),
            create_test_agent("slow", 0.05, 500.0, 0.9),
        ]);

        let chosen = routed_name(
            &router,
            &RoutingConstraints::new(),
            OptimizationTarget::Latency,
        );
        assert_eq!(chosen, "fast");
    }

    #[test]
    fn test_route_quality_optimization() {
        let router = router_with(vec![
            create_test_agent("low_quality", 0.05, 100.0, 0.5),
            create_test_agent("high_quality", 0.05, 100.0, 0.95),
        ]);

        let chosen = routed_name(
            &router,
            &RoutingConstraints::new(),
            OptimizationTarget::Quality,
        );
        assert_eq!(chosen, "high_quality");
    }

    #[test]
    fn test_route_with_constraints() {
        let router = router_with(vec![
            create_test_agent("expensive", 1.0, 100.0, 0.9),
            create_test_agent("cheap", 0.01, 100.0, 0.7),
        ]);

        let constraints = RoutingConstraints::new().with_max_cost(0.5);
        let chosen = routed_name(&router, &constraints, OptimizationTarget::Quality);

        // Should select cheap even though expensive has higher quality
        assert_eq!(chosen, "cheap");
    }

    #[test]
    fn test_route_no_candidates() {
        let router = Router::new();
        let request_emb = vec![0.1; 384];
        let constraints = RoutingConstraints::new();

        let result = router.route(&request_emb, &constraints, OptimizationTarget::Balanced);
        assert!(result.is_err());
    }

    #[test]
    fn test_route_capability_filter() {
        let mut coder = create_test_agent("coder", 0.05, 100.0, 0.8);
        coder.capabilities = vec!["coding".to_string()];
        let mut translator = create_test_agent("translator", 0.05, 100.0, 0.8);
        translator.capabilities = vec!["translation".to_string()];
        let router = router_with(vec![coder, translator]);

        let constraints = RoutingConstraints::new().with_capability("coding".to_string());
        let chosen = routed_name(&router, &constraints, OptimizationTarget::Balanced);
        assert_eq!(chosen, "coder");
    }
}