Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
402
vendor/ruvector/crates/ruvector-postgres/src/routing/README.md
vendored
Normal file
402
vendor/ruvector/crates/ruvector-postgres/src/routing/README.md
vendored
Normal file
@@ -0,0 +1,402 @@
|
||||
# Tiny Dancer Routing Module
|
||||
|
||||
Neural-powered dynamic agent routing with FastGRNN for intelligent AI agent selection.
|
||||
|
||||
## Overview
|
||||
|
||||
The Tiny Dancer routing module provides intelligent routing of requests to AI agents based on multiple optimization criteria including cost, latency, quality, and balanced performance. It uses a FastGRNN (Fast Gated Recurrent Neural Network) for adaptive decision-making.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Components
|
||||
|
||||
1. **FastGRNN** (`fastgrnn.rs`)
|
||||
- Lightweight gated recurrent neural network
|
||||
- Real-time routing decisions with minimal compute
|
||||
- Adaptive learning from routing patterns
|
||||
|
||||
2. **Agent Registry** (`agents.rs`)
|
||||
- Thread-safe agent storage with DashMap
|
||||
- Capability-based agent discovery
|
||||
- Performance metrics tracking
|
||||
|
||||
3. **Router** (`router.rs`)
|
||||
- Multi-objective optimization
|
||||
- Constraint-based filtering
|
||||
- Neural-enhanced confidence scoring
|
||||
|
||||
4. **PostgreSQL Operators** (`operators.rs`)
|
||||
- SQL functions for agent management
|
||||
- Routing query interface
|
||||
- Statistics and monitoring
|
||||
|
||||
## PostgreSQL Functions
|
||||
|
||||
### Agent Registration
|
||||
|
||||
```sql
|
||||
-- Register a simple agent
|
||||
SELECT ruvector_register_agent(
|
||||
'gpt-4', -- Agent name
|
||||
'llm', -- Agent type
|
||||
ARRAY['code_generation', 'reasoning'], -- Capabilities
|
||||
0.03, -- Cost per request ($)
|
||||
500.0, -- Average latency (ms)
|
||||
0.95 -- Quality score (0-1)
|
||||
);
|
||||
|
||||
-- Register with full configuration
|
||||
SELECT ruvector_register_agent_full('{
|
||||
"name": "claude-3-opus",
|
||||
"agent_type": "llm",
|
||||
"capabilities": ["coding", "reasoning", "writing"],
|
||||
"cost_model": {
|
||||
"per_request": 0.025,
|
||||
"per_token": 0.00005
|
||||
},
|
||||
"performance": {
|
||||
"avg_latency_ms": 400.0,
|
||||
"quality_score": 0.93,
|
||||
"success_rate": 0.99,
|
||||
"p95_latency_ms": 600.0,
|
||||
"p99_latency_ms": 1000.0
|
||||
},
|
||||
"is_active": true
|
||||
}'::jsonb);
|
||||
```
|
||||
|
||||
### Routing Requests
|
||||
|
||||
```sql
|
||||
-- Basic routing (optimize for balanced performance)
|
||||
SELECT ruvector_route(
|
||||
embedding_vector, -- Request embedding (384-dim)
|
||||
'balanced', -- Optimization target
|
||||
NULL -- No constraints
|
||||
)
|
||||
FROM requests
|
||||
WHERE id = 123;
|
||||
|
||||
-- Cost-optimized routing with constraints
|
||||
SELECT ruvector_route(
|
||||
embedding_vector,
|
||||
'cost',
|
||||
'{"max_latency_ms": 1000.0, "min_quality": 0.8}'::jsonb
|
||||
)
|
||||
FROM requests
|
||||
WHERE id = 456;
|
||||
|
||||
-- Quality-optimized with capability requirements
|
||||
SELECT ruvector_route(
|
||||
embedding_vector,
|
||||
'quality',
|
||||
'{
|
||||
"max_cost": 0.1,
|
||||
"required_capabilities": ["code_generation", "debugging"],
|
||||
"excluded_agents": ["slow-agent"]
|
||||
}'::jsonb
|
||||
);
|
||||
|
||||
-- Latency-optimized routing
|
||||
SELECT ruvector_route(
|
||||
embedding_vector,
|
||||
'latency',
|
||||
'{"max_latency_ms": 500.0}'::jsonb
|
||||
);
|
||||
```
|
||||
|
||||
### Agent Management
|
||||
|
||||
```sql
|
||||
-- List all agents
|
||||
SELECT * FROM ruvector_list_agents();
|
||||
|
||||
-- Get specific agent details
|
||||
SELECT ruvector_get_agent('gpt-4');
|
||||
|
||||
-- Find agents by capability
|
||||
SELECT * FROM ruvector_find_agents_by_capability('code_generation', 5);
|
||||
|
||||
-- Update agent performance metrics
|
||||
SELECT ruvector_update_agent_metrics(
|
||||
'gpt-4', -- Agent name
|
||||
450.0, -- Observed latency (ms)
|
||||
true, -- Success
|
||||
0.92 -- Quality score (optional)
|
||||
);
|
||||
|
||||
-- Deactivate an agent
|
||||
SELECT ruvector_set_agent_active('gpt-4', false);
|
||||
|
||||
-- Remove an agent
|
||||
SELECT ruvector_remove_agent('old-agent');
|
||||
|
||||
-- Get routing statistics
|
||||
SELECT ruvector_routing_stats();
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Example 1: Multi-Model Routing System
|
||||
|
||||
```sql
|
||||
-- Register various AI models
|
||||
SELECT ruvector_register_agent('gpt-4', 'llm',
|
||||
ARRAY['coding', 'reasoning', 'math'], 0.03, 500.0, 0.95);
|
||||
SELECT ruvector_register_agent('gpt-3.5-turbo', 'llm',
|
||||
ARRAY['general', 'fast'], 0.002, 200.0, 0.75);
|
||||
SELECT ruvector_register_agent('claude-3-opus', 'llm',
|
||||
ARRAY['coding', 'writing', 'analysis'], 0.025, 400.0, 0.93);
|
||||
SELECT ruvector_register_agent('llama-2-70b', 'llm',
|
||||
ARRAY['local', 'private'], 0.0, 800.0, 0.72);
|
||||
|
||||
-- Create routing view
|
||||
CREATE VIEW intelligent_routing AS
|
||||
SELECT
|
||||
r.id,
|
||||
r.query_text,
|
||||
r.embedding,
|
||||
route.agent_name,
|
||||
route.confidence,
|
||||
route.estimated_cost,
|
||||
route.estimated_latency_ms,
|
||||
route.expected_quality,
|
||||
route.reasoning
|
||||
FROM requests r,
|
||||
LATERAL (
|
||||
SELECT (ruvector_route(
|
||||
r.embedding,
|
||||
'balanced',
|
||||
NULL
|
||||
))::jsonb AS route_data
|
||||
) route_query,
|
||||
LATERAL jsonb_to_record(route_query.route_data) AS route(
|
||||
agent_name text,
|
||||
confidence float4,
|
||||
estimated_cost float4,
|
||||
estimated_latency_ms float4,
|
||||
expected_quality float4,
|
||||
similarity_score float4,
|
||||
reasoning text
|
||||
);
|
||||
|
||||
-- Query with automatic routing
|
||||
SELECT * FROM intelligent_routing WHERE id = 123;
|
||||
```
|
||||
|
||||
### Example 2: Cost-Aware Batch Processing
|
||||
|
||||
```sql
|
||||
-- Process batch with cost constraints
|
||||
CREATE TEMP TABLE batch_results AS
|
||||
SELECT
|
||||
r.id,
|
||||
r.query_text,
|
||||
routing.agent_name,
|
||||
routing.estimated_cost,
|
||||
routing.expected_quality
|
||||
FROM requests r
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT (ruvector_route(
|
||||
r.embedding,
|
||||
'cost',
|
||||
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
|
||||
))::jsonb->'agent_name' AS agent_name,
|
||||
(ruvector_route(
|
||||
r.embedding,
|
||||
'cost',
|
||||
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
|
||||
))::jsonb->'estimated_cost' AS estimated_cost,
|
||||
(ruvector_route(
|
||||
r.embedding,
|
||||
'cost',
|
||||
'{"max_cost": 0.01, "min_quality": 0.7}'::jsonb
|
||||
))::jsonb->'expected_quality' AS expected_quality
|
||||
) routing
|
||||
WHERE r.processed = false
|
||||
LIMIT 1000;
|
||||
|
||||
-- Calculate total estimated cost
|
||||
SELECT
|
||||
SUM((estimated_cost)::float) AS total_cost,
|
||||
AVG((expected_quality)::float) AS avg_quality,
|
||||
COUNT(*) AS total_requests
|
||||
FROM batch_results;
|
||||
```
|
||||
|
||||
### Example 3: Quality-First Routing
|
||||
|
||||
```sql
|
||||
-- Route critical requests to highest quality agents
|
||||
CREATE FUNCTION route_critical_request(
|
||||
request_embedding float4[],
|
||||
min_quality float4 DEFAULT 0.9
|
||||
) RETURNS jsonb AS $$
|
||||
SELECT ruvector_route(
|
||||
request_embedding,
|
||||
'quality',
|
||||
jsonb_build_object(
|
||||
'min_quality', min_quality,
|
||||
'max_latency_ms', 2000.0,
|
||||
'required_capabilities', ARRAY['reasoning', 'analysis']
|
||||
)
|
||||
);
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
-- Use the function
|
||||
SELECT route_critical_request(embedding_vector, 0.95)
|
||||
FROM critical_requests
|
||||
WHERE priority = 'high';
|
||||
```
|
||||
|
||||
### Example 4: Real-time Performance Tracking
|
||||
|
||||
```sql
|
||||
-- Update metrics after each request
|
||||
CREATE FUNCTION record_agent_performance(
|
||||
agent_name text,
|
||||
actual_latency_ms float4,
|
||||
success boolean,
|
||||
quality_score float4
|
||||
) RETURNS void AS $$
|
||||
BEGIN
|
||||
PERFORM ruvector_update_agent_metrics(
|
||||
agent_name,
|
||||
actual_latency_ms,
|
||||
success,
|
||||
quality_score
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Trigger to auto-update metrics.
-- A PostgreSQL trigger function must be declared without arguments and return
-- `trigger`, and CREATE TRIGGER arguments must be string literals (NEW.* columns
-- are not allowed there), so a thin wrapper forwards the inserted row to
-- record_agent_performance().
CREATE FUNCTION record_agent_performance_trigger() RETURNS trigger AS $$
BEGIN
    PERFORM record_agent_performance(
        NEW.agent_name,
        NEW.latency_ms,
        NEW.success,
        NEW.quality_score
    );
    RETURN NULL; -- the return value is ignored for AFTER ... FOR EACH ROW triggers
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER update_agent_metrics_trigger
AFTER INSERT ON request_completions
FOR EACH ROW
EXECUTE FUNCTION record_agent_performance_trigger();
|
||||
```
|
||||
|
||||
### Example 5: Capability-Based Routing
|
||||
|
||||
```sql
|
||||
-- Create specialized routing functions
|
||||
CREATE FUNCTION route_code_request(emb float4[]) RETURNS text AS $$
|
||||
SELECT (ruvector_route(
|
||||
emb,
|
||||
'quality',
|
||||
'{"required_capabilities": ["coding", "debugging"]}'::jsonb
|
||||
))::jsonb->>'agent_name';
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
CREATE FUNCTION route_writing_request(emb float4[]) RETURNS text AS $$
|
||||
SELECT (ruvector_route(
|
||||
emb,
|
||||
'quality',
|
||||
'{"required_capabilities": ["writing", "editing"]}'::jsonb
|
||||
))::jsonb->>'agent_name';
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
-- Use in application logic
|
||||
SELECT
|
||||
CASE
|
||||
WHEN task_type = 'code' THEN route_code_request(embedding)
|
||||
WHEN task_type = 'write' THEN route_writing_request(embedding)
|
||||
ELSE (ruvector_route(embedding, 'balanced', NULL))::jsonb->>'agent_name'
|
||||
END AS selected_agent
|
||||
FROM tasks;
|
||||
```
|
||||
|
||||
## Optimization Targets
|
||||
|
||||
### Cost
|
||||
- Minimizes cost per request
|
||||
- Considers both per-request and per-token costs
|
||||
- Ideal for high-volume, cost-sensitive workloads
|
||||
|
||||
### Latency
|
||||
- Minimizes response time
|
||||
- Uses average latency metrics
|
||||
- Best for real-time applications
|
||||
|
||||
### Quality
|
||||
- Maximizes quality score
|
||||
- Based on historical performance
|
||||
- Recommended for critical tasks
|
||||
|
||||
### Balanced
|
||||
- Multi-objective optimization
|
||||
- Balances cost, latency, quality, and similarity
|
||||
- Default for general-purpose routing
|
||||
|
||||
## Constraints
|
||||
|
||||
### max_cost
|
||||
Maximum acceptable cost per request (in dollars)
|
||||
|
||||
### max_latency_ms
|
||||
Maximum acceptable latency in milliseconds
|
||||
|
||||
### min_quality
|
||||
Minimum required quality score (0-1 scale)
|
||||
|
||||
### required_capabilities
|
||||
Array of required agent capabilities
|
||||
|
||||
### excluded_agents
|
||||
Array of agent names to exclude from selection
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Agent Registry**: Thread-safe with DashMap for concurrent access
|
||||
2. **Embedding Similarity**: Uses fast cosine similarity for request matching
|
||||
3. **FastGRNN**: Lightweight neural network for real-time inference
|
||||
4. **Caching**: Consider caching routing decisions for identical requests
|
||||
|
||||
## Monitoring
|
||||
|
||||
```sql
|
||||
-- View agent statistics
|
||||
SELECT name, total_requests, avg_latency_ms, quality_score, success_rate
|
||||
FROM ruvector_list_agents()
|
||||
ORDER BY total_requests DESC;
|
||||
|
||||
-- Get overall routing statistics
|
||||
SELECT ruvector_routing_stats();
|
||||
|
||||
-- Find underperforming agents
|
||||
SELECT name, success_rate, quality_score
|
||||
FROM ruvector_list_agents()
|
||||
WHERE success_rate < 0.95
|
||||
OR quality_score < 0.7;
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Register Accurate Metrics**: Keep agent performance metrics up-to-date
|
||||
2. **Use Constraints**: Always set appropriate constraints for production
|
||||
3. **Monitor Performance**: Track actual vs. estimated metrics
|
||||
4. **Update Regularly**: Use `ruvector_update_agent_metrics` after each request
|
||||
5. **Capability Matching**: Ensure agents have accurate capability tags
|
||||
6. **Cost Tracking**: Monitor total routing costs with statistics queries
|
||||
|
||||
## Integration with Other Modules
|
||||
|
||||
The routing module integrates seamlessly with:
|
||||
- **Vector Search**: Use query embeddings for semantic routing
|
||||
- **GNN**: Enhance routing with graph neural networks
|
||||
- **Quantization**: Reduce embedding storage costs
|
||||
- **HNSW Index**: Fast similarity search for agent selection
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] A/B testing framework for agent comparison
|
||||
- [ ] Multi-armed bandit algorithms for exploration
|
||||
- [ ] Reinforcement learning for adaptive routing
|
||||
- [ ] Cost prediction models
|
||||
- [ ] Load balancing across agent instances
|
||||
- [ ] Geo-distributed agent routing
|
||||
500
vendor/ruvector/crates/ruvector-postgres/src/routing/agents.rs
vendored
Normal file
500
vendor/ruvector/crates/ruvector-postgres/src/routing/agents.rs
vendored
Normal file
@@ -0,0 +1,500 @@
|
||||
// Agent Registry and Management
|
||||
//
|
||||
// Thread-safe registry for managing AI agents with capabilities and performance metrics.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Type of AI agent
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum AgentType {
|
||||
/// Language model (GPT, Claude, etc.)
|
||||
LLM,
|
||||
/// Embedding model
|
||||
Embedding,
|
||||
/// Specialized task agent
|
||||
Specialized,
|
||||
/// Vision model
|
||||
Vision,
|
||||
/// Audio model
|
||||
Audio,
|
||||
/// Multimodal agent
|
||||
Multimodal,
|
||||
/// Custom agent type
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl AgentType {
|
||||
/// Parse agent type from string
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"llm" => AgentType::LLM,
|
||||
"embedding" => AgentType::Embedding,
|
||||
"specialized" => AgentType::Specialized,
|
||||
"vision" => AgentType::Vision,
|
||||
"audio" => AgentType::Audio,
|
||||
"multimodal" => AgentType::Multimodal,
|
||||
_ => AgentType::Custom(s.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to string
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
AgentType::LLM => "llm",
|
||||
AgentType::Embedding => "embedding",
|
||||
AgentType::Specialized => "specialized",
|
||||
AgentType::Vision => "vision",
|
||||
AgentType::Audio => "audio",
|
||||
AgentType::Multimodal => "multimodal",
|
||||
AgentType::Custom(s) => s,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cost model for agent usage
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CostModel {
|
||||
/// Cost per request
|
||||
pub per_request: f32,
|
||||
/// Cost per token (if applicable)
|
||||
pub per_token: Option<f32>,
|
||||
/// Fixed monthly cost
|
||||
pub monthly_fixed: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for CostModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
per_request: 0.0,
|
||||
per_token: None,
|
||||
monthly_fixed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Performance metrics for an agent
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PerformanceMetrics {
|
||||
/// Average latency in milliseconds
|
||||
pub avg_latency_ms: f32,
|
||||
/// 95th percentile latency
|
||||
pub p95_latency_ms: f32,
|
||||
/// 99th percentile latency
|
||||
pub p99_latency_ms: f32,
|
||||
/// Quality score (0-1)
|
||||
pub quality_score: f32,
|
||||
/// Success rate (0-1)
|
||||
pub success_rate: f32,
|
||||
/// Total requests processed
|
||||
pub total_requests: u64,
|
||||
}
|
||||
|
||||
impl Default for PerformanceMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
avg_latency_ms: 100.0,
|
||||
p95_latency_ms: 200.0,
|
||||
p99_latency_ms: 500.0,
|
||||
quality_score: 0.8,
|
||||
success_rate: 0.99,
|
||||
total_requests: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// AI Agent definition with capabilities and metrics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
/// Unique agent name
|
||||
pub name: String,
|
||||
/// Agent type
|
||||
pub agent_type: AgentType,
|
||||
/// Capabilities (e.g., ["code_generation", "translation"])
|
||||
pub capabilities: Vec<String>,
|
||||
/// Cost model
|
||||
pub cost_model: CostModel,
|
||||
/// Performance metrics
|
||||
pub performance: PerformanceMetrics,
|
||||
/// Agent embedding for similarity matching (384-dim)
|
||||
pub embedding: Option<Vec<f32>>,
|
||||
/// Whether agent is currently active
|
||||
pub is_active: bool,
|
||||
/// Additional metadata
|
||||
pub metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Create a new agent
|
||||
pub fn new(name: String, agent_type: AgentType, capabilities: Vec<String>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
agent_type,
|
||||
capabilities,
|
||||
cost_model: CostModel::default(),
|
||||
performance: PerformanceMetrics::default(),
|
||||
embedding: None,
|
||||
is_active: true,
|
||||
metadata: serde_json::Value::Null,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if agent has a specific capability
|
||||
pub fn has_capability(&self, capability: &str) -> bool {
|
||||
self.capabilities
|
||||
.iter()
|
||||
.any(|c| c.eq_ignore_ascii_case(capability))
|
||||
}
|
||||
|
||||
/// Calculate total cost for a request
|
||||
pub fn calculate_cost(&self, token_count: Option<u32>) -> f32 {
|
||||
let mut cost = self.cost_model.per_request;
|
||||
|
||||
if let (Some(tokens), Some(per_token)) = (token_count, self.cost_model.per_token) {
|
||||
cost += tokens as f32 * per_token;
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
|
||||
/// Update performance metrics with new observation
|
||||
pub fn update_metrics(&mut self, latency_ms: f32, success: bool, quality: Option<f32>) {
|
||||
let n = self.performance.total_requests as f32;
|
||||
let new_n = n + 1.0;
|
||||
|
||||
// Update average latency with exponential moving average
|
||||
self.performance.avg_latency_ms =
|
||||
(self.performance.avg_latency_ms * n + latency_ms) / new_n;
|
||||
|
||||
// Update success rate
|
||||
let prev_successes = (self.performance.success_rate * n) as u64;
|
||||
let new_successes = prev_successes + if success { 1 } else { 0 };
|
||||
self.performance.success_rate = new_successes as f32 / new_n;
|
||||
|
||||
// Update quality score if provided
|
||||
if let Some(q) = quality {
|
||||
self.performance.quality_score = (self.performance.quality_score * n + q) / new_n;
|
||||
}
|
||||
|
||||
self.performance.total_requests += 1;
|
||||
|
||||
// Update percentiles (simplified approach)
|
||||
if latency_ms > self.performance.avg_latency_ms * 1.5 {
|
||||
self.performance.p95_latency_ms =
|
||||
(self.performance.p95_latency_ms * 0.95 + latency_ms * 0.05).max(latency_ms);
|
||||
}
|
||||
if latency_ms > self.performance.avg_latency_ms * 2.0 {
|
||||
self.performance.p99_latency_ms =
|
||||
(self.performance.p99_latency_ms * 0.99 + latency_ms * 0.01).max(latency_ms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe agent registry
|
||||
pub struct AgentRegistry {
|
||||
/// Agents stored by name
|
||||
agents: Arc<DashMap<String, Agent>>,
|
||||
}
|
||||
|
||||
impl AgentRegistry {
|
||||
/// Create a new agent registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
agents: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new agent
|
||||
pub fn register(&self, agent: Agent) -> Result<(), String> {
|
||||
if self.agents.contains_key(&agent.name) {
|
||||
return Err(format!("Agent '{}' already exists", agent.name));
|
||||
}
|
||||
|
||||
self.agents.insert(agent.name.clone(), agent);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an existing agent
|
||||
pub fn update(&self, agent: Agent) -> Result<(), String> {
|
||||
if !self.agents.contains_key(&agent.name) {
|
||||
return Err(format!("Agent '{}' not found", agent.name));
|
||||
}
|
||||
|
||||
self.agents.insert(agent.name.clone(), agent);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get an agent by name
|
||||
pub fn get(&self, name: &str) -> Option<Agent> {
|
||||
self.agents.get(name).map(|entry| entry.clone())
|
||||
}
|
||||
|
||||
/// Remove an agent
|
||||
pub fn remove(&self, name: &str) -> Option<Agent> {
|
||||
self.agents.remove(name).map(|(_, agent)| agent)
|
||||
}
|
||||
|
||||
/// List all active agents
|
||||
pub fn list_active(&self) -> Vec<Agent> {
|
||||
self.agents
|
||||
.iter()
|
||||
.filter(|entry| entry.is_active)
|
||||
.map(|entry| entry.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// List all agents
|
||||
pub fn list_all(&self) -> Vec<Agent> {
|
||||
self.agents.iter().map(|entry| entry.clone()).collect()
|
||||
}
|
||||
|
||||
/// Find agents by capability
|
||||
pub fn find_by_capability(&self, capability: &str, k: usize) -> Vec<Agent> {
|
||||
let mut agents: Vec<Agent> = self
|
||||
.agents
|
||||
.iter()
|
||||
.filter(|entry| entry.is_active && entry.has_capability(capability))
|
||||
.map(|entry| entry.clone())
|
||||
.collect();
|
||||
|
||||
// Sort by quality score (descending)
|
||||
agents.sort_by(|a, b| {
|
||||
b.performance
|
||||
.quality_score
|
||||
.partial_cmp(&a.performance.quality_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
agents.into_iter().take(k).collect()
|
||||
}
|
||||
|
||||
/// Find agents by type
|
||||
pub fn find_by_type(&self, agent_type: &AgentType) -> Vec<Agent> {
|
||||
self.agents
|
||||
.iter()
|
||||
.filter(|entry| entry.is_active && &entry.agent_type == agent_type)
|
||||
.map(|entry| entry.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get agent count
|
||||
pub fn count(&self) -> usize {
|
||||
self.agents.len()
|
||||
}
|
||||
|
||||
/// Get active agent count
|
||||
pub fn count_active(&self) -> usize {
|
||||
self.agents.iter().filter(|entry| entry.is_active).count()
|
||||
}
|
||||
|
||||
/// Clear all agents
|
||||
pub fn clear(&self) {
|
||||
self.agents.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_type_parsing() {
        assert_eq!(AgentType::from_str("llm"), AgentType::LLM);
        assert_eq!(AgentType::from_str("LLM"), AgentType::LLM);
        assert_eq!(AgentType::from_str("embedding"), AgentType::Embedding);
        assert_eq!(
            AgentType::from_str("custom"),
            AgentType::Custom("custom".to_string())
        );
    }

    #[test]
    fn test_agent_creation() {
        let agent = Agent::new(
            "gpt-4".to_string(),
            AgentType::LLM,
            vec!["code_generation".to_string(), "translation".to_string()],
        );

        assert_eq!(agent.name, "gpt-4");
        assert_eq!(agent.agent_type, AgentType::LLM);
        assert_eq!(agent.capabilities.len(), 2);
        assert!(agent.is_active);
    }

    #[test]
    fn test_agent_has_capability() {
        let agent = Agent::new(
            "test".to_string(),
            AgentType::LLM,
            vec!["code_generation".to_string()],
        );

        assert!(agent.has_capability("code_generation"));
        assert!(agent.has_capability("CODE_GENERATION"));
        assert!(!agent.has_capability("translation"));
    }

    #[test]
    fn test_agent_cost_calculation() {
        let mut agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);
        agent.cost_model.per_request = 0.01;
        agent.cost_model.per_token = Some(0.0001);

        assert_eq!(agent.calculate_cost(None), 0.01);

        // 0.01 + 1000 * 0.0001. Compared with a tolerance: in f32 the computed
        // sum is not bit-identical to the literal 0.11, so exact equality fails.
        let with_tokens = agent.calculate_cost(Some(1000));
        assert!((with_tokens - 0.11).abs() < 1e-6);
    }

    #[test]
    fn test_agent_update_metrics() {
        let mut agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);

        // Initial state
        assert_eq!(agent.performance.total_requests, 0);

        // Add first observation
        agent.update_metrics(100.0, true, Some(0.9));
        assert_eq!(agent.performance.total_requests, 1);
        assert_eq!(agent.performance.avg_latency_ms, 100.0);
        assert_eq!(agent.performance.success_rate, 1.0);
        assert_eq!(agent.performance.quality_score, 0.9);

        // Add second observation
        agent.update_metrics(200.0, true, Some(0.8));
        assert_eq!(agent.performance.total_requests, 2);
        assert_eq!(agent.performance.avg_latency_ms, 150.0);
        assert_eq!(agent.performance.success_rate, 1.0);
        assert!((agent.performance.quality_score - 0.85).abs() < 0.01);
    }

    #[test]
    fn test_registry_register() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);

        assert!(registry.register(agent.clone()).is_ok());
        assert_eq!(registry.count(), 1);

        // Duplicate registration should fail
        assert!(registry.register(agent).is_err());
    }

    #[test]
    fn test_registry_get() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);

        registry.register(agent.clone()).unwrap();

        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.name, "test");

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_remove() {
        let registry = AgentRegistry::new();
        let agent = Agent::new("test".to_string(), AgentType::LLM, vec![]);

        registry.register(agent).unwrap();
        assert_eq!(registry.count(), 1);

        let removed = registry.remove("test").unwrap();
        assert_eq!(removed.name, "test");
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_registry_list_active() {
        let registry = AgentRegistry::new();

        let mut agent1 = Agent::new("active".to_string(), AgentType::LLM, vec![]);
        agent1.is_active = true;

        let mut agent2 = Agent::new("inactive".to_string(), AgentType::LLM, vec![]);
        agent2.is_active = false;

        registry.register(agent1).unwrap();
        registry.register(agent2).unwrap();

        let active = registry.list_active();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].name, "active");
    }

    #[test]
    fn test_registry_find_by_capability() {
        let registry = AgentRegistry::new();

        let agent1 = Agent::new(
            "agent1".to_string(),
            AgentType::LLM,
            vec!["coding".to_string()],
        );
        let agent2 = Agent::new(
            "agent2".to_string(),
            AgentType::LLM,
            vec!["translation".to_string()],
        );
        let agent3 = Agent::new(
            "agent3".to_string(),
            AgentType::LLM,
            vec!["coding".to_string(), "translation".to_string()],
        );

        registry.register(agent1).unwrap();
        registry.register(agent2).unwrap();
        registry.register(agent3).unwrap();

        let coders = registry.find_by_capability("coding", 10);
        assert_eq!(coders.len(), 2);

        let translators = registry.find_by_capability("translation", 10);
        assert_eq!(translators.len(), 2);
    }

    #[test]
    fn test_registry_find_by_type() {
        let registry = AgentRegistry::new();

        registry
            .register(Agent::new("llm1".to_string(), AgentType::LLM, vec![]))
            .unwrap();
        registry
            .register(Agent::new("llm2".to_string(), AgentType::LLM, vec![]))
            .unwrap();
        registry
            .register(Agent::new(
                "embed1".to_string(),
                AgentType::Embedding,
                vec![],
            ))
            .unwrap();

        let llms = registry.find_by_type(&AgentType::LLM);
        assert_eq!(llms.len(), 2);

        let embeddings = registry.find_by_type(&AgentType::Embedding);
        assert_eq!(embeddings.len(), 1);
    }

    #[test]
    fn test_registry_clear() {
        let registry = AgentRegistry::new();
        registry
            .register(Agent::new("test".to_string(), AgentType::LLM, vec![]))
            .unwrap();

        assert_eq!(registry.count(), 1);
        registry.clear();
        assert_eq!(registry.count(), 0);
    }
}
|
||||
253
vendor/ruvector/crates/ruvector-postgres/src/routing/fastgrnn.rs
vendored
Normal file
253
vendor/ruvector/crates/ruvector-postgres/src/routing/fastgrnn.rs
vendored
Normal file
@@ -0,0 +1,253 @@
|
||||
// FastGRNN - Fast Gated Recurrent Neural Network
|
||||
//
|
||||
// Lightweight RNN for real-time routing decisions with minimal compute overhead.
|
||||
// Based on "FastGRNN: A Fast, Accurate, Stable and Tiny Kilobyte Sized Gated Recurrent Neural Network"
|
||||
|
||||
use std::f32;
|
||||
|
||||
/// Parameters of a single FastGRNN recurrent cell.
///
/// Weight matrices are stored flattened in row-major order with one row per
/// hidden unit, which is how `matmul_add` indexes them.
#[derive(Clone)]
pub struct FastGRNN {
    /// Length of each input vector.
    input_dim: usize,
    /// Length of the hidden state.
    hidden_dim: usize,
    /// Gate weights applied to the input (`hidden_dim` rows of `input_dim`).
    w_gate: Vec<f32>,
    /// Gate weights applied to the previous hidden state (`hidden_dim` rows of `hidden_dim`).
    u_gate: Vec<f32>,
    /// Update weights applied to the input (`hidden_dim` rows of `input_dim`).
    w_update: Vec<f32>,
    /// Update weights applied to the previous hidden state (`hidden_dim` rows of `hidden_dim`).
    u_update: Vec<f32>,
    /// Gate bias, one value per hidden unit.
    bias_gate: Vec<f32>,
    /// Update bias, one value per hidden unit.
    bias_update: Vec<f32>,
    /// Multiplier applied to the gate value in `step`.
    zeta: f32,
    /// Offset added to the scaled gate value in `step`.
    nu: f32,
}
|
||||
|
||||
impl FastGRNN {
|
||||
/// Create a new FastGRNN cell with specified dimensions
|
||||
pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
|
||||
// Initialize with small random weights (Xavier initialization)
|
||||
let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
|
||||
|
||||
Self {
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
w_gate: vec![0.1 * scale; input_dim * hidden_dim],
|
||||
u_gate: vec![0.1 * scale; hidden_dim * hidden_dim],
|
||||
w_update: vec![0.1 * scale; input_dim * hidden_dim],
|
||||
u_update: vec![0.1 * scale; hidden_dim * hidden_dim],
|
||||
bias_gate: vec![0.0; hidden_dim],
|
||||
bias_update: vec![0.0; hidden_dim],
|
||||
zeta: 1.0,
|
||||
nu: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create FastGRNN from pre-trained weights
|
||||
pub fn from_weights(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
w_gate: Vec<f32>,
|
||||
u_gate: Vec<f32>,
|
||||
w_update: Vec<f32>,
|
||||
u_update: Vec<f32>,
|
||||
bias_gate: Vec<f32>,
|
||||
bias_update: Vec<f32>,
|
||||
zeta: f32,
|
||||
nu: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
w_gate,
|
||||
u_gate,
|
||||
w_update,
|
||||
u_update,
|
||||
bias_gate,
|
||||
bias_update,
|
||||
zeta,
|
||||
nu,
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform one step of FastGRNN computation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input` - Input vector of size input_dim
|
||||
/// * `hidden` - Previous hidden state of size hidden_dim
|
||||
///
|
||||
/// # Returns
|
||||
/// New hidden state of size hidden_dim
|
||||
pub fn step(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
|
||||
assert_eq!(input.len(), self.input_dim, "Input dimension mismatch");
|
||||
assert_eq!(hidden.len(), self.hidden_dim, "Hidden dimension mismatch");
|
||||
|
||||
let mut new_hidden = vec![0.0; self.hidden_dim];
|
||||
|
||||
// Compute gate: g = sigmoid(W_g * x + U_g * h + b_g)
|
||||
let mut gate = vec![0.0; self.hidden_dim];
|
||||
self.matmul_add(&self.w_gate, input, &mut gate);
|
||||
self.matmul_add(&self.u_gate, hidden, &mut gate);
|
||||
for i in 0..self.hidden_dim {
|
||||
gate[i] = self.sigmoid(gate[i] + self.bias_gate[i]);
|
||||
}
|
||||
|
||||
// Compute update: c = tanh(W_u * x + U_u * h + b_u)
|
||||
let mut update = vec![0.0; self.hidden_dim];
|
||||
self.matmul_add(&self.w_update, input, &mut update);
|
||||
self.matmul_add(&self.u_update, hidden, &mut update);
|
||||
for i in 0..self.hidden_dim {
|
||||
update[i] = self.tanh(update[i] + self.bias_update[i]);
|
||||
}
|
||||
|
||||
// Compute new hidden: h' = (zeta * g + nu) ⊙ h + (1 - zeta * g - nu) ⊙ c
|
||||
for i in 0..self.hidden_dim {
|
||||
let gate_factor = self.zeta * gate[i] + self.nu;
|
||||
let gate_factor = gate_factor.min(1.0).max(0.0); // Clip to [0, 1]
|
||||
new_hidden[i] = gate_factor * hidden[i] + (1.0 - gate_factor) * update[i];
|
||||
}
|
||||
|
||||
new_hidden
|
||||
}
|
||||
|
||||
/// Process a single input and return hidden state (for single-step inference)
|
||||
pub fn forward_single(&self, input: &[f32]) -> Vec<f32> {
|
||||
let hidden = vec![0.0; self.hidden_dim];
|
||||
self.step(input, &hidden)
|
||||
}
|
||||
|
||||
/// Process a sequence of inputs
|
||||
pub fn forward_sequence(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
let mut hidden = vec![0.0; self.hidden_dim];
|
||||
let mut outputs = Vec::with_capacity(inputs.len());
|
||||
|
||||
for input in inputs {
|
||||
hidden = self.step(input, &hidden);
|
||||
outputs.push(hidden.clone());
|
||||
}
|
||||
|
||||
outputs
|
||||
}
|
||||
|
||||
/// Matrix-vector multiplication with accumulation: result += W * input
|
||||
fn matmul_add(&self, weights: &[f32], input: &[f32], result: &mut [f32]) {
|
||||
let rows = result.len();
|
||||
let cols = input.len();
|
||||
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
result[i] += weights[i * cols + j] * input[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid activation function
|
||||
fn sigmoid(&self, x: f32) -> f32 {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
}
|
||||
|
||||
/// Hyperbolic tangent activation function
|
||||
fn tanh(&self, x: f32) -> f32 {
|
||||
x.tanh()
|
||||
}
|
||||
|
||||
/// Get input dimension
|
||||
pub fn input_dim(&self) -> usize {
|
||||
self.input_dim
|
||||
}
|
||||
|
||||
/// Get hidden dimension
|
||||
pub fn hidden_dim(&self) -> usize {
|
||||
self.hidden_dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Input shared by the 4-dimensional test cases.
    const INPUT_4D: [f32; 4] = [1.0, 0.5, -0.5, 0.0];

    #[test]
    fn test_fastgrnn_creation() {
        let cell = FastGRNN::new(10, 5);
        assert_eq!(cell.input_dim(), 10);
        assert_eq!(cell.hidden_dim(), 5);
    }

    #[test]
    fn test_fastgrnn_step() {
        let cell = FastGRNN::new(4, 3);
        let previous = [0.1, 0.2, 0.3];

        let next = cell.step(&INPUT_4D, &previous);
        assert_eq!(next.len(), 3);

        // Check that output is bounded (due to tanh and sigmoid)
        assert!(
            next.iter().all(|h| h.abs() <= 2.0),
            "Hidden state should be bounded"
        );
    }

    #[test]
    fn test_fastgrnn_forward_single() {
        let cell = FastGRNN::new(4, 3);
        assert_eq!(cell.forward_single(&INPUT_4D).len(), 3);
    }

    #[test]
    fn test_fastgrnn_sequence() {
        let cell = FastGRNN::new(4, 3);
        let sequence = vec![
            INPUT_4D.to_vec(),
            vec![0.5, 1.0, 0.0, -0.5],
            vec![-0.5, 0.0, 1.0, 0.5],
        ];

        let states = cell.forward_sequence(&sequence);
        assert_eq!(states.len(), 3);
        assert_eq!(states[0].len(), 3);
    }

    #[test]
    fn test_sigmoid() {
        let cell = FastGRNN::new(1, 1);
        let at_zero = cell.sigmoid(0.0);
        assert!((at_zero - 0.5).abs() < 1e-6);
        assert!(cell.sigmoid(10.0) > 0.99);
        assert!(cell.sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_tanh() {
        let cell = FastGRNN::new(1, 1);
        let at_zero = cell.tanh(0.0);
        assert!(at_zero.abs() < 1e-6);
        assert!(cell.tanh(10.0) > 0.99);
        assert!(cell.tanh(-10.0) < -0.99);
    }

    #[test]
    #[should_panic(expected = "Input dimension mismatch")]
    fn test_wrong_input_dimension() {
        let cell = FastGRNN::new(4, 3);
        let too_short = [1.0, 0.5]; // Wrong size
        cell.step(&too_short, &[0.1, 0.2, 0.3]);
    }

    #[test]
    #[should_panic(expected = "Hidden dimension mismatch")]
    fn test_wrong_hidden_dimension() {
        let cell = FastGRNN::new(4, 3);
        let too_short = [0.1, 0.2]; // Wrong size
        cell.step(&INPUT_4D, &too_short);
    }
}
|
||||
24
vendor/ruvector/crates/ruvector-postgres/src/routing/mod.rs
vendored
Normal file
24
vendor/ruvector/crates/ruvector-postgres/src/routing/mod.rs
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
// Tiny Dancer Routing Module
|
||||
//
|
||||
// Neural-powered dynamic agent routing with FastGRNN for adaptive decision-making.
|
||||
|
||||
pub mod agents;
|
||||
pub mod fastgrnn;
|
||||
pub mod operators;
|
||||
pub mod router;
|
||||
|
||||
pub use agents::{Agent, AgentRegistry, AgentType};
|
||||
pub use fastgrnn::FastGRNN;
|
||||
pub use router::{OptimizationTarget, Router, RoutingConstraints, RoutingDecision};
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if the re-exported types are reachable from the module root.
    #[test]
    fn test_module_exports() {
        let _agent_registry: AgentRegistry = AgentRegistry::new();
        let _agent_router: Router = Router::new();
    }
}
|
||||
610
vendor/ruvector/crates/ruvector-postgres/src/routing/operators.rs
vendored
Normal file
610
vendor/ruvector/crates/ruvector-postgres/src/routing/operators.rs
vendored
Normal file
@@ -0,0 +1,610 @@
|
||||
// PostgreSQL Operators for Tiny Dancer Routing
|
||||
//
|
||||
// SQL functions for agent registration, routing, and management.
|
||||
|
||||
use pgrx::prelude::*;
|
||||
use pgrx::JsonB;
|
||||
use serde_json::json;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use super::agents::{Agent, AgentRegistry, AgentType};
|
||||
use super::router::{OptimizationTarget, Router, RoutingConstraints};
|
||||
|
||||
// Global agent registry and router
|
||||
static AGENT_REGISTRY: OnceLock<AgentRegistry> = OnceLock::new();
|
||||
static ROUTER: OnceLock<Router> = OnceLock::new();
|
||||
|
||||
/// Initialize the global registry and router
|
||||
fn init_router() -> &'static Router {
|
||||
ROUTER.get_or_init(|| {
|
||||
let _registry = AGENT_REGISTRY.get_or_init(AgentRegistry::new);
|
||||
Router::with_registry(std::sync::Arc::new(AgentRegistry::new()))
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the global agent registry
|
||||
fn get_registry() -> &'static AgentRegistry {
|
||||
AGENT_REGISTRY.get_or_init(AgentRegistry::new)
|
||||
}
|
||||
|
||||
/// Register a new AI agent
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - Unique agent identifier
|
||||
/// * `agent_type` - Type of agent (llm, embedding, specialized, etc.)
|
||||
/// * `capabilities` - Array of capability strings
|
||||
/// * `cost_per_request` - Cost per request in dollars
|
||||
/// * `avg_latency_ms` - Average latency in milliseconds
|
||||
/// * `quality_score` - Quality score (0-1)
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_register_agent(
|
||||
/// 'gpt-4',
|
||||
/// 'llm',
|
||||
/// ARRAY['code_generation', 'translation'],
|
||||
/// 0.03,
|
||||
/// 500.0,
|
||||
/// 0.95
|
||||
/// );
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_register_agent(
|
||||
name: String,
|
||||
agent_type: String,
|
||||
capabilities: Vec<String>,
|
||||
cost_per_request: f32,
|
||||
avg_latency_ms: f32,
|
||||
quality_score: f32,
|
||||
) -> Result<bool, String> {
|
||||
let registry = get_registry();
|
||||
|
||||
let mut agent = Agent::new(name.clone(), AgentType::from_str(&agent_type), capabilities);
|
||||
|
||||
agent.cost_model.per_request = cost_per_request;
|
||||
agent.performance.avg_latency_ms = avg_latency_ms;
|
||||
agent.performance.quality_score = quality_score;
|
||||
|
||||
registry.register(agent)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Register an agent with full configuration
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `config` - JSONB configuration with all agent properties
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_register_agent_full('{
|
||||
/// "name": "gpt-4",
|
||||
/// "agent_type": "llm",
|
||||
/// "capabilities": ["code_generation", "translation"],
|
||||
/// "cost_model": {
|
||||
/// "per_request": 0.03,
|
||||
/// "per_token": 0.00006
|
||||
/// },
|
||||
/// "performance": {
|
||||
/// "avg_latency_ms": 500.0,
|
||||
/// "quality_score": 0.95,
|
||||
/// "success_rate": 0.99
|
||||
/// }
|
||||
/// }'::jsonb);
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_register_agent_full(config: JsonB) -> Result<bool, String> {
|
||||
let registry = get_registry();
|
||||
|
||||
let agent: Agent = serde_json::from_value(config.0)
|
||||
.map_err(|e| format!("Invalid agent configuration: {}", e))?;
|
||||
|
||||
registry.register(agent)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Update an existing agent's performance metrics
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - Agent name
|
||||
/// * `latency_ms` - Observed latency
|
||||
/// * `success` - Whether the request succeeded
|
||||
/// * `quality` - Optional quality score for this request
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_update_agent_metrics('gpt-4', 450.0, true, 0.92);
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_update_agent_metrics(
|
||||
name: String,
|
||||
latency_ms: f32,
|
||||
success: bool,
|
||||
quality: Option<f32>,
|
||||
) -> Result<bool, String> {
|
||||
let registry = get_registry();
|
||||
|
||||
let mut agent = registry
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Agent '{}' not found", name))?;
|
||||
|
||||
agent.update_metrics(latency_ms, success, quality);
|
||||
registry.update(agent)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Remove an agent from the registry
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_remove_agent('gpt-4');
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_remove_agent(name: String) -> Result<bool, String> {
|
||||
let registry = get_registry();
|
||||
registry
|
||||
.remove(&name)
|
||||
.ok_or_else(|| format!("Agent '{}' not found", name))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Set an agent's active status
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_set_agent_active('gpt-4', false);
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_set_agent_active(name: String, is_active: bool) -> Result<bool, String> {
|
||||
let registry = get_registry();
|
||||
|
||||
let mut agent = registry
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Agent '{}' not found", name))?;
|
||||
|
||||
agent.is_active = is_active;
|
||||
registry.update(agent)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Route a request to the best agent
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `request_embedding` - Request embedding vector (384-dim)
|
||||
/// * `optimize_for` - Optimization target: 'cost', 'latency', 'quality', 'balanced'
|
||||
/// * `constraints` - Optional JSONB constraints object
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_route(
|
||||
/// embedding,
|
||||
/// 'balanced',
|
||||
/// '{"max_cost": 0.1, "min_quality": 0.8}'::jsonb
|
||||
/// )
|
||||
/// FROM request_embeddings
|
||||
/// WHERE id = 123;
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_route(
|
||||
request_embedding: Vec<f32>,
|
||||
optimize_for: default!(String, "'balanced'"),
|
||||
constraints: default!(Option<JsonB>, "NULL"),
|
||||
) -> Result<JsonB, String> {
|
||||
init_router(); // Ensure router is initialized
|
||||
|
||||
let target = OptimizationTarget::from_str(&optimize_for);
|
||||
|
||||
let routing_constraints = if let Some(JsonB(json_val)) = constraints {
|
||||
serde_json::from_value(json_val).map_err(|e| format!("Invalid constraints: {}", e))?
|
||||
} else {
|
||||
RoutingConstraints::default()
|
||||
};
|
||||
|
||||
// Get router with proper registry
|
||||
let registry = get_registry();
|
||||
let router = Router::with_registry(std::sync::Arc::new(AgentRegistry::new()));
|
||||
|
||||
// Copy agents from global registry to router's registry
|
||||
for agent in registry.list_all() {
|
||||
router.registry().register(agent).ok();
|
||||
}
|
||||
|
||||
let decision = router.route(&request_embedding, &routing_constraints, target)?;
|
||||
|
||||
let result = json!({
|
||||
"agent_name": decision.agent_name,
|
||||
"confidence": decision.confidence,
|
||||
"estimated_cost": decision.estimated_cost,
|
||||
"estimated_latency_ms": decision.estimated_latency_ms,
|
||||
"expected_quality": decision.expected_quality,
|
||||
"similarity_score": decision.similarity_score,
|
||||
"reasoning": decision.reasoning,
|
||||
"alternatives": decision.alternatives,
|
||||
});
|
||||
|
||||
Ok(JsonB(result))
|
||||
}
|
||||
|
||||
/// List all registered agents
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT * FROM ruvector_list_agents();
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_list_agents() -> TableIterator<
|
||||
'static,
|
||||
(
|
||||
name!(name, String),
|
||||
name!(agent_type, String),
|
||||
name!(capabilities, Vec<String>),
|
||||
name!(cost_per_request, f32),
|
||||
name!(avg_latency_ms, f32),
|
||||
name!(quality_score, f32),
|
||||
name!(success_rate, f32),
|
||||
name!(total_requests, i64),
|
||||
name!(is_active, bool),
|
||||
),
|
||||
> {
|
||||
let registry = get_registry();
|
||||
let agents = registry.list_all();
|
||||
|
||||
TableIterator::new(
|
||||
agents
|
||||
.into_iter()
|
||||
.map(|agent| {
|
||||
(
|
||||
agent.name,
|
||||
agent.agent_type.as_str().to_string(),
|
||||
agent.capabilities,
|
||||
agent.cost_model.per_request,
|
||||
agent.performance.avg_latency_ms,
|
||||
agent.performance.quality_score,
|
||||
agent.performance.success_rate,
|
||||
agent.performance.total_requests as i64,
|
||||
agent.is_active,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get detailed information about a specific agent
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_get_agent('gpt-4');
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_get_agent(name: String) -> Result<JsonB, String> {
|
||||
let registry = get_registry();
|
||||
|
||||
let agent = registry
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Agent '{}' not found", name))?;
|
||||
|
||||
let result = serde_json::to_value(&agent).map_err(|e| format!("Serialization error: {}", e))?;
|
||||
|
||||
Ok(JsonB(result))
|
||||
}
|
||||
|
||||
/// Find agents by capability
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT * FROM ruvector_find_agents_by_capability('code_generation', 5);
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_find_agents_by_capability(
|
||||
capability: String,
|
||||
limit: default!(i32, 10),
|
||||
) -> TableIterator<
|
||||
'static,
|
||||
(
|
||||
name!(name, String),
|
||||
name!(quality_score, f32),
|
||||
name!(avg_latency_ms, f32),
|
||||
name!(cost_per_request, f32),
|
||||
),
|
||||
> {
|
||||
let registry = get_registry();
|
||||
let agents = registry.find_by_capability(&capability, limit as usize);
|
||||
|
||||
TableIterator::new(
|
||||
agents
|
||||
.into_iter()
|
||||
.map(|agent| {
|
||||
(
|
||||
agent.name,
|
||||
agent.performance.quality_score,
|
||||
agent.performance.avg_latency_ms,
|
||||
agent.cost_model.per_request,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get routing statistics
|
||||
///
|
||||
/// # Example
|
||||
/// ```sql
|
||||
/// SELECT ruvector_routing_stats();
|
||||
/// ```
|
||||
#[pg_extern]
|
||||
fn ruvector_routing_stats() -> JsonB {
|
||||
let registry = get_registry();
|
||||
|
||||
let total_agents = registry.count();
|
||||
let active_agents = registry.count_active();
|
||||
|
||||
let agents = registry.list_all();
|
||||
|
||||
let total_requests: u64 = agents.iter().map(|a| a.performance.total_requests).sum();
|
||||
let avg_quality: f32 = if !agents.is_empty() {
|
||||
agents
|
||||
.iter()
|
||||
.map(|a| a.performance.quality_score)
|
||||
.sum::<f32>()
|
||||
/ agents.len() as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let result = json!({
|
||||
"total_agents": total_agents,
|
||||
"active_agents": active_agents,
|
||||
"total_requests": total_requests,
|
||||
"average_quality": avg_quality,
|
||||
});
|
||||
|
||||
JsonB(result)
|
||||
}
|
||||
|
||||
/// Clear all agents (for testing)
|
||||
#[pg_extern]
|
||||
fn ruvector_clear_agents() -> bool {
|
||||
let registry = get_registry();
|
||||
registry.clear();
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
    use super::*;

    /// Registers an agent through the SQL-facing function and returns its result.
    fn try_register(
        name: &str,
        agent_type: &str,
        capabilities: &[&str],
        cost: f32,
        latency_ms: f32,
        quality: f32,
    ) -> Result<bool, String> {
        ruvector_register_agent(
            name.to_string(),
            agent_type.to_string(),
            capabilities.iter().map(|cap| cap.to_string()).collect(),
            cost,
            latency_ms,
            quality,
        )
    }

    /// Registers the stock "test-agent" fixture used by most tests.
    fn try_register_test_agent() -> Result<bool, String> {
        try_register("test-agent", "llm", &["coding"], 0.05, 200.0, 0.85)
    }

    #[pg_test]
    fn test_register_agent() {
        ruvector_clear_agents();

        assert_eq!(try_register_test_agent(), Ok(true));

        // Verify agent was registered
        assert!(ruvector_get_agent("test-agent".to_string()).is_ok());
    }

    #[pg_test]
    fn test_register_duplicate_agent() {
        ruvector_clear_agents();
        try_register_test_agent().unwrap();

        // Try to register again
        assert!(try_register_test_agent().is_err());
    }

    #[pg_test]
    fn test_update_agent_metrics() {
        ruvector_clear_agents();
        try_register_test_agent().unwrap();

        let outcome =
            ruvector_update_agent_metrics("test-agent".to_string(), 150.0, true, Some(0.9));
        assert!(outcome.is_ok());
    }

    #[pg_test]
    fn test_remove_agent() {
        ruvector_clear_agents();
        try_register_test_agent().unwrap();

        assert!(ruvector_remove_agent("test-agent".to_string()).is_ok());

        // Verify agent was removed
        assert!(ruvector_get_agent("test-agent".to_string()).is_err());
    }

    #[pg_test]
    fn test_set_agent_active() {
        ruvector_clear_agents();
        try_register_test_agent().unwrap();

        assert!(ruvector_set_agent_active("test-agent".to_string(), false).is_ok());

        let JsonB(raw_agent) = ruvector_get_agent("test-agent".to_string()).unwrap();
        let agent: Agent = serde_json::from_value(raw_agent).unwrap();
        assert_eq!(agent.is_active, false);
    }

    #[pg_test]
    fn test_list_agents() {
        ruvector_clear_agents();
        try_register("agent1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();
        try_register("agent2", "embedding", &["similarity"], 0.01, 50.0, 0.90).unwrap();

        let listed: Vec<_> = ruvector_list_agents().collect();
        assert_eq!(listed.len(), 2);
    }

    #[pg_test]
    fn test_find_agents_by_capability() {
        ruvector_clear_agents();
        try_register("coder1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();
        try_register("coder2", "llm", &["coding", "translation"], 0.08, 250.0, 0.90).unwrap();
        try_register("translator", "llm", &["translation"], 0.03, 150.0, 0.80).unwrap();

        let coders: Vec<_> = ruvector_find_agents_by_capability("coding".to_string(), 10).collect();
        assert_eq!(coders.len(), 2);
    }

    #[pg_test]
    fn test_routing_stats() {
        ruvector_clear_agents();
        try_register("agent1", "llm", &["coding"], 0.05, 200.0, 0.85).unwrap();

        let JsonB(stats) = ruvector_routing_stats();
        assert_eq!(stats["total_agents"], 1);
        assert_eq!(stats["active_agents"], 1);
    }

    #[pg_test]
    fn test_route_basic() {
        ruvector_clear_agents();
        try_register("cheap-agent", "llm", &["coding"], 0.01, 200.0, 0.70).unwrap();
        try_register("expensive-agent", "llm", &["coding"], 0.10, 200.0, 0.95).unwrap();

        // Route optimizing for cost
        let routed = ruvector_route(vec![0.1; 384], "cost".to_string(), None);
        assert!(routed.is_ok());

        let JsonB(decision) = routed.unwrap();
        assert_eq!(decision["agent_name"], "cheap-agent");
    }
}
|
||||
628
vendor/ruvector/crates/ruvector-postgres/src/routing/router.rs
vendored
Normal file
628
vendor/ruvector/crates/ruvector-postgres/src/routing/router.rs
vendored
Normal file
@@ -0,0 +1,628 @@
|
||||
// Neural-Powered Agent Router
|
||||
//
|
||||
// Dynamic routing with FastGRNN and multi-objective optimization.
|
||||
|
||||
use super::agents::{Agent, AgentRegistry};
|
||||
use super::fastgrnn::FastGRNN;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Optimization target for routing decisions
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum OptimizationTarget {
|
||||
/// Minimize cost
|
||||
Cost,
|
||||
/// Minimize latency
|
||||
Latency,
|
||||
/// Maximize quality
|
||||
Quality,
|
||||
/// Balanced optimization
|
||||
Balanced,
|
||||
}
|
||||
|
||||
impl OptimizationTarget {
|
||||
/// Parse from string
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"cost" => OptimizationTarget::Cost,
|
||||
"latency" => OptimizationTarget::Latency,
|
||||
"quality" => OptimizationTarget::Quality,
|
||||
"balanced" => OptimizationTarget::Balanced,
|
||||
_ => OptimizationTarget::Balanced,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to string
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
OptimizationTarget::Cost => "cost",
|
||||
OptimizationTarget::Latency => "latency",
|
||||
OptimizationTarget::Quality => "quality",
|
||||
OptimizationTarget::Balanced => "balanced",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constraints for routing decisions
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingConstraints {
|
||||
/// Maximum acceptable cost
|
||||
pub max_cost: Option<f32>,
|
||||
/// Maximum acceptable latency in ms
|
||||
pub max_latency_ms: Option<f32>,
|
||||
/// Minimum required quality score (0-1)
|
||||
pub min_quality: Option<f32>,
|
||||
/// Required capabilities
|
||||
pub required_capabilities: Vec<String>,
|
||||
/// Excluded agent names
|
||||
pub excluded_agents: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for RoutingConstraints {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_cost: None,
|
||||
max_latency_ms: None,
|
||||
min_quality: None,
|
||||
required_capabilities: Vec::new(),
|
||||
excluded_agents: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RoutingConstraints {
|
||||
/// Create new constraints
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set maximum cost
|
||||
pub fn with_max_cost(mut self, cost: f32) -> Self {
|
||||
self.max_cost = Some(cost);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set maximum latency
|
||||
pub fn with_max_latency(mut self, latency_ms: f32) -> Self {
|
||||
self.max_latency_ms = Some(latency_ms);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum quality
|
||||
pub fn with_min_quality(mut self, quality: f32) -> Self {
|
||||
self.min_quality = Some(quality);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add required capability
|
||||
pub fn with_capability(mut self, capability: String) -> Self {
|
||||
self.required_capabilities.push(capability);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add excluded agent
|
||||
pub fn with_excluded_agent(mut self, agent_name: String) -> Self {
|
||||
self.excluded_agents.push(agent_name);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Routing decision result
///
/// The per-field notes below describe how `Router::route` fills this struct.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingDecision {
|
||||
/// Selected agent name
|
||||
pub agent_name: String,
|
||||
/// Confidence score (0-1)
/// With a FastGRNN model set, this is derived from the hidden-state
/// magnitude and clamped to [0, 1]; without one it is the winning agent's
/// raw score.
|
||||
pub confidence: f32,
|
||||
/// Estimated cost
/// (the selected agent's per-request cost)
|
||||
pub estimated_cost: f32,
|
||||
/// Estimated latency in ms
/// (the selected agent's average latency)
|
||||
pub estimated_latency_ms: f32,
|
||||
/// Expected quality
/// (the selected agent's quality score)
|
||||
pub expected_quality: f32,
|
||||
/// Similarity score to request
/// Cosine similarity between the request embedding and the agent embedding,
/// or 0.5 when the agent has no embedding.
|
||||
pub similarity_score: f32,
|
||||
/// Reasoning for the decision
|
||||
pub reasoning: String,
|
||||
/// Alternative agents considered
/// (at most the three next-best candidates, best first)
|
||||
pub alternatives: Vec<AlternativeAgent>,
|
||||
}
|
||||
|
||||
/// Alternative agent option
///
/// `Router::route` emits one of these for each runner-up it reports.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AlternativeAgent {
|
||||
/// Agent name
|
||||
pub name: String,
|
||||
/// Score
/// (the score this agent received under the requested optimization target)
|
||||
pub score: f32,
|
||||
/// Why it wasn't selected
/// (a short comparison against the selected agent)
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
/// Neural-powered agent router
///
/// Selects an agent from `registry` by scoring candidates; the FastGRNN model
/// is optional and, when present, is used only to compute the confidence value.
|
||||
pub struct Router {
|
||||
/// Agent registry
|
||||
registry: Arc<AgentRegistry>,
|
||||
/// FastGRNN model for neural routing
/// `None` until `init_grnn` or `set_grnn` is called.
|
||||
grnn: Option<FastGRNN>,
|
||||
/// Embedding dimension
/// Set to 384 by both constructors; `init_grnn` uses it as the model's input size.
|
||||
embedding_dim: usize,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Create a new router
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registry: Arc::new(AgentRegistry::new()),
|
||||
grnn: None,
|
||||
embedding_dim: 384, // Default embedding size
|
||||
}
|
||||
}
|
||||
|
||||
/// Create router with custom registry
|
||||
pub fn with_registry(registry: Arc<AgentRegistry>) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
grnn: None,
|
||||
embedding_dim: 384,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize FastGRNN model
|
||||
pub fn init_grnn(&mut self, hidden_dim: usize) {
|
||||
self.grnn = Some(FastGRNN::new(self.embedding_dim, hidden_dim));
|
||||
}
|
||||
|
||||
/// Set FastGRNN model from weights
|
||||
pub fn set_grnn(&mut self, grnn: FastGRNN) {
|
||||
self.grnn = Some(grnn);
|
||||
}
|
||||
|
||||
/// Route a request to the best agent
|
||||
pub fn route(
|
||||
&self,
|
||||
request_embedding: &[f32],
|
||||
constraints: &RoutingConstraints,
|
||||
target: OptimizationTarget,
|
||||
) -> Result<RoutingDecision, String> {
|
||||
// Get candidate agents
|
||||
let candidates = self.get_candidates(constraints)?;
|
||||
|
||||
if candidates.is_empty() {
|
||||
return Err("No agents match the constraints".to_string());
|
||||
}
|
||||
|
||||
// Score all candidates
|
||||
let mut scored_candidates: Vec<(Agent, f32, f32)> = candidates
|
||||
.iter()
|
||||
.filter_map(|agent| {
|
||||
// Calculate similarity
|
||||
let similarity = if let Some(agent_emb) = &agent.embedding {
|
||||
cosine_similarity(request_embedding, agent_emb)
|
||||
} else {
|
||||
0.5 // Default similarity if no embedding
|
||||
};
|
||||
|
||||
// Calculate score based on target
|
||||
let score = self.score_agent(agent, request_embedding, target, similarity);
|
||||
|
||||
// Apply constraints
|
||||
if self.meets_constraints(agent, constraints) {
|
||||
Some((agent.clone(), score, similarity))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if scored_candidates.is_empty() {
|
||||
return Err("No agents meet the specified constraints".to_string());
|
||||
}
|
||||
|
||||
// Sort by score (descending)
|
||||
scored_candidates
|
||||
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Select best agent
|
||||
let (best_agent, best_score, similarity) = &scored_candidates[0];
|
||||
|
||||
// Calculate confidence using FastGRNN if available
|
||||
let confidence = if let Some(ref grnn) = self.grnn {
|
||||
let hidden = grnn.forward_single(request_embedding);
|
||||
// Use hidden state magnitude as confidence
|
||||
let magnitude: f32 = hidden.iter().map(|&h| h * h).sum::<f32>().sqrt();
|
||||
(magnitude / hidden.len() as f32).min(1.0).max(0.0)
|
||||
} else {
|
||||
*best_score
|
||||
};
|
||||
|
||||
// Build alternatives list
|
||||
let alternatives: Vec<AlternativeAgent> = scored_candidates
|
||||
.iter()
|
||||
.skip(1)
|
||||
.take(3)
|
||||
.map(|(agent, score, _)| AlternativeAgent {
|
||||
name: agent.name.clone(),
|
||||
score: *score,
|
||||
reason: self.compare_to_best(agent, best_agent, target),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Generate reasoning
|
||||
let reasoning = self.generate_reasoning(best_agent, target, *similarity);
|
||||
|
||||
Ok(RoutingDecision {
|
||||
agent_name: best_agent.name.clone(),
|
||||
confidence,
|
||||
estimated_cost: best_agent.cost_model.per_request,
|
||||
estimated_latency_ms: best_agent.performance.avg_latency_ms,
|
||||
expected_quality: best_agent.performance.quality_score,
|
||||
similarity_score: *similarity,
|
||||
reasoning,
|
||||
alternatives,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get candidate agents based on constraints
|
||||
fn get_candidates(&self, constraints: &RoutingConstraints) -> Result<Vec<Agent>, String> {
|
||||
let mut agents = self.registry.list_active();
|
||||
|
||||
// Filter by required capabilities
|
||||
if !constraints.required_capabilities.is_empty() {
|
||||
agents.retain(|agent| {
|
||||
constraints
|
||||
.required_capabilities
|
||||
.iter()
|
||||
.all(|cap| agent.has_capability(cap))
|
||||
});
|
||||
}
|
||||
|
||||
// Filter excluded agents
|
||||
if !constraints.excluded_agents.is_empty() {
|
||||
agents.retain(|agent| !constraints.excluded_agents.contains(&agent.name));
|
||||
}
|
||||
|
||||
Ok(agents)
|
||||
}
|
||||
|
||||
/// Check if agent meets constraints
|
||||
fn meets_constraints(&self, agent: &Agent, constraints: &RoutingConstraints) -> bool {
|
||||
// Check cost constraint
|
||||
if let Some(max_cost) = constraints.max_cost {
|
||||
if agent.cost_model.per_request > max_cost {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check latency constraint
|
||||
if let Some(max_latency) = constraints.max_latency_ms {
|
||||
if agent.performance.avg_latency_ms > max_latency {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check quality constraint
|
||||
if let Some(min_quality) = constraints.min_quality {
|
||||
if agent.performance.quality_score < min_quality {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Score an agent for a given target
|
||||
fn score_agent(
|
||||
&self,
|
||||
agent: &Agent,
|
||||
_request_embedding: &[f32],
|
||||
target: OptimizationTarget,
|
||||
similarity: f32,
|
||||
) -> f32 {
|
||||
match target {
|
||||
OptimizationTarget::Cost => {
|
||||
// Lower cost = higher score
|
||||
let cost_score = 1.0 / (1.0 + agent.cost_model.per_request);
|
||||
cost_score * 0.7 + similarity * 0.3
|
||||
}
|
||||
OptimizationTarget::Latency => {
|
||||
// Lower latency = higher score
|
||||
let latency_score = 1.0 / (1.0 + agent.performance.avg_latency_ms / 1000.0);
|
||||
latency_score * 0.7 + similarity * 0.3
|
||||
}
|
||||
OptimizationTarget::Quality => {
|
||||
// Higher quality = higher score
|
||||
agent.performance.quality_score * 0.7 + similarity * 0.3
|
||||
}
|
||||
OptimizationTarget::Balanced => {
|
||||
// Balanced scoring
|
||||
let cost_score = 1.0 / (1.0 + agent.cost_model.per_request);
|
||||
let latency_score = 1.0 / (1.0 + agent.performance.avg_latency_ms / 1000.0);
|
||||
let quality_score = agent.performance.quality_score;
|
||||
|
||||
cost_score * 0.25 + latency_score * 0.25 + quality_score * 0.25 + similarity * 0.25
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare agent to best agent
|
||||
fn compare_to_best(&self, agent: &Agent, best: &Agent, target: OptimizationTarget) -> String {
|
||||
match target {
|
||||
OptimizationTarget::Cost => {
|
||||
let diff = agent.cost_model.per_request - best.cost_model.per_request;
|
||||
format!("${:.4} more expensive", diff)
|
||||
}
|
||||
OptimizationTarget::Latency => {
|
||||
let diff = agent.performance.avg_latency_ms - best.performance.avg_latency_ms;
|
||||
format!("{:.1}ms slower", diff)
|
||||
}
|
||||
OptimizationTarget::Quality => {
|
||||
let diff = best.performance.quality_score - agent.performance.quality_score;
|
||||
format!("{:.2} lower quality", diff)
|
||||
}
|
||||
OptimizationTarget::Balanced => "Lower overall score".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate reasoning for decision
|
||||
fn generate_reasoning(
|
||||
&self,
|
||||
agent: &Agent,
|
||||
target: OptimizationTarget,
|
||||
similarity: f32,
|
||||
) -> String {
|
||||
let target_reason = match target {
|
||||
OptimizationTarget::Cost => {
|
||||
format!("lowest cost (${:.4}/request)", agent.cost_model.per_request)
|
||||
}
|
||||
OptimizationTarget::Latency => format!(
|
||||
"fastest response ({:.1}ms avg)",
|
||||
agent.performance.avg_latency_ms
|
||||
),
|
||||
OptimizationTarget::Quality => format!(
|
||||
"highest quality (score: {:.2})",
|
||||
agent.performance.quality_score
|
||||
),
|
||||
OptimizationTarget::Balanced => "best overall balance".to_string(),
|
||||
};
|
||||
|
||||
format!(
|
||||
"Selected {} for {} with {:.1}% similarity to request",
|
||||
agent.name,
|
||||
target_reason,
|
||||
similarity * 100.0
|
||||
)
|
||||
}
|
||||
|
||||
/// Get registry reference
|
||||
pub fn registry(&self) -> &Arc<AgentRegistry> {
|
||||
&self.registry
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Router {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Degenerate inputs yield `0.0`:
/// slices of different lengths, a vector with zero norm (including empty
/// slices), or inputs containing NaN / infinite components.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64: squaring a large (or tiny) f32 overflows to infinity
    // (or underflows to zero) in f32, which previously made identical large
    // vectors report -1.0. Every finite f32 squares safely in f64.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    if !cosine.is_finite() {
        // NaN or infinite components: no meaningful similarity to report.
        return 0.0;
    }

    // Clamp away rounding error, then narrow back to the f32 interface.
    cosine.clamp(-1.0, 1.0) as f32
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::agents::AgentType;

    /// Embedding dimension shared by the test agents and the test requests.
    const DIM: usize = 384;

    /// Build an LLM agent with the "test" capability, a uniform 0.1 embedding
    /// and the given cost / latency / quality figures.
    fn create_test_agent(name: &str, cost: f32, latency: f32, quality: f32) -> Agent {
        let mut agent = Agent::new(name.to_string(), AgentType::LLM, vec!["test".to_string()]);
        agent.embedding = Some(vec![0.1; DIM]);
        agent.cost_model.per_request = cost;
        agent.performance.avg_latency_ms = latency;
        agent.performance.quality_score = quality;
        agent
    }

    /// Create a router and register `agents` in the given order.
    fn router_with(agents: Vec<Agent>) -> Router {
        let router = Router::new();
        for agent in agents {
            router.registry().register(agent).unwrap();
        }
        router
    }

    /// Route a uniform 0.1 request and return the name of the chosen agent.
    fn routed_agent(
        router: &Router,
        constraints: &RoutingConstraints,
        target: OptimizationTarget,
    ) -> String {
        let request_emb = vec![0.1; DIM];
        router
            .route(&request_emb, constraints, target)
            .unwrap()
            .agent_name
    }

    #[test]
    fn test_optimization_target_parsing() {
        // Unrecognized names fall back to `Balanced`.
        let cases = [
            ("cost", OptimizationTarget::Cost),
            ("LATENCY", OptimizationTarget::Latency),
            ("quality", OptimizationTarget::Quality),
            ("balanced", OptimizationTarget::Balanced),
            ("unknown", OptimizationTarget::Balanced),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(OptimizationTarget::from_str(input), *expected);
        }
    }

    #[test]
    fn test_routing_constraints_builder() {
        let built = RoutingConstraints::new()
            .with_max_cost(0.1)
            .with_max_latency(500.0)
            .with_min_quality(0.8)
            .with_capability("test".to_string());

        assert_eq!(built.max_cost, Some(0.1));
        assert_eq!(built.max_latency_ms, Some(500.0));
        assert_eq!(built.min_quality, Some(0.8));
        assert_eq!(built.required_capabilities.len(), 1);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        let diagonal = [1.0_f32, 1.0, 0.0];

        // Identical unit vectors.
        assert!((cosine_similarity(&x_axis, &x_axis) - 1.0).abs() < 1e-6);
        // Orthogonal vectors.
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 1e-6);
        // Identical non-unit vectors.
        assert!((cosine_similarity(&diagonal, &diagonal) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_router_creation() {
        let router = Router::new();
        assert!(router.grnn.is_none());
        assert_eq!(router.registry().count(), 0);
    }

    #[test]
    fn test_router_init_grnn() {
        let mut router = Router::new();
        router.init_grnn(64);
        assert!(router.grnn.is_some());
    }

    #[test]
    fn test_route_cost_optimization() {
        // Agents differ only in cost (and quality); cost should decide.
        let router = router_with(vec![
            create_test_agent("cheap", 0.01, 100.0, 0.7),
            create_test_agent("expensive", 0.10, 100.0, 0.9),
        ]);

        let chosen = routed_agent(&router, &RoutingConstraints::new(), OptimizationTarget::Cost);
        assert_eq!(chosen, "cheap");
    }

    #[test]
    fn test_route_latency_optimization() {
        let router = router_with(vec![
            create_test_agent("fast", 0.05, 50.0, 0.7),
            create_test_agent("slow", 0.05, 500.0, 0.9),
        ]);

        let chosen = routed_agent(
            &router,
            &RoutingConstraints::new(),
            OptimizationTarget::Latency,
        );
        assert_eq!(chosen, "fast");
    }

    #[test]
    fn test_route_quality_optimization() {
        let router = router_with(vec![
            create_test_agent("low_quality", 0.05, 100.0, 0.5),
            create_test_agent("high_quality", 0.05, 100.0, 0.95),
        ]);

        let chosen = routed_agent(
            &router,
            &RoutingConstraints::new(),
            OptimizationTarget::Quality,
        );
        assert_eq!(chosen, "high_quality");
    }

    #[test]
    fn test_route_with_constraints() {
        let router = router_with(vec![
            create_test_agent("expensive", 1.0, 100.0, 0.9),
            create_test_agent("cheap", 0.01, 100.0, 0.7),
        ]);
        let constraints = RoutingConstraints::new().with_max_cost(0.5);

        // Should select cheap even though expensive has higher quality
        let chosen = routed_agent(&router, &constraints, OptimizationTarget::Quality);
        assert_eq!(chosen, "cheap");
    }

    #[test]
    fn test_route_no_candidates() {
        let router = Router::new();
        let request_emb = vec![0.1; DIM];
        let constraints = RoutingConstraints::new();

        let outcome = router.route(&request_emb, &constraints, OptimizationTarget::Balanced);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_route_capability_filter() {
        let mut coder = create_test_agent("coder", 0.05, 100.0, 0.8);
        coder.capabilities = vec!["coding".to_string()];

        let mut translator = create_test_agent("translator", 0.05, 100.0, 0.8);
        translator.capabilities = vec!["translation".to_string()];

        let router = router_with(vec![coder, translator]);
        let constraints = RoutingConstraints::new().with_capability("coding".to_string());

        let chosen = routed_agent(&router, &constraints, OptimizationTarget::Balanced);
        assert_eq!(chosen, "coder");
    }
}
|
||||
Reference in New Issue
Block a user