Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,125 @@
//! Bridge configuration: distance metrics and tuning knobs.
use serde::{Deserialize, Serialize};
/// Distance metric used for spatial search operations.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance. This is the default metric.
    #[default]
    Euclidean,
    /// Angular dissimilarity between vectors (1 - cosine similarity).
    Cosine,
    /// Sum of per-axis absolute differences (L1 distance).
    Manhattan,
}
/// Top-level configuration for the robotics bridge.
///
/// Construct with [`BridgeConfig::default`] for sensible 3-D defaults, or
/// with [`BridgeConfig::new`] to set every field explicitly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BridgeConfig {
    /// Dimensionality of the vector space (typically 3 for XYZ).
    pub dimensions: usize,
    /// Metric used when computing distances.
    pub distance_metric: DistanceMetric,
    /// Maximum number of points the spatial index will accept.
    pub max_points: usize,
    /// Default *k* value for nearest-neighbour queries.
    pub search_k: usize,
}
impl Default for BridgeConfig {
    /// Defaults: 3-D Euclidean space, capacity of one million points,
    /// and k = 10 for nearest-neighbour queries.
    fn default() -> Self {
        Self {
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            max_points: 1_000_000,
            search_k: 10,
        }
    }
}
impl BridgeConfig {
    /// Create a new configuration with explicit values.
    ///
    /// * `dimensions` – dimensionality of the vector space.
    /// * `distance_metric` – metric used when computing distances.
    /// * `max_points` – capacity limit for the spatial index.
    /// * `search_k` – default *k* for nearest-neighbour queries.
    pub fn new(
        dimensions: usize,
        distance_metric: DistanceMetric,
        max_points: usize,
        search_k: usize,
    ) -> Self {
        Self {
            dimensions,
            distance_metric,
            max_points,
            search_k,
        }
    }
}
#[cfg(test)]
mod tests {
    // Coverage: defaults, explicit construction, and serde round-trips for
    // both `DistanceMetric` and `BridgeConfig`.
    use super::*;
    #[test]
    fn test_default_distance_metric() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Euclidean);
    }
    #[test]
    fn test_default_bridge_config() {
        let cfg = BridgeConfig::default();
        assert_eq!(cfg.dimensions, 3);
        assert_eq!(cfg.distance_metric, DistanceMetric::Euclidean);
        assert_eq!(cfg.max_points, 1_000_000);
        assert_eq!(cfg.search_k, 10);
    }
    // Serde round-trips: a config must survive JSON encode/decode unchanged.
    #[test]
    fn test_config_serde_roundtrip_json() {
        let cfg = BridgeConfig::new(128, DistanceMetric::Cosine, 500_000, 20);
        let json = serde_json::to_string(&cfg).unwrap();
        let restored: BridgeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, restored);
    }
    #[test]
    fn test_config_serde_roundtrip_default() {
        let cfg = BridgeConfig::default();
        let json = serde_json::to_string_pretty(&cfg).unwrap();
        let restored: BridgeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, restored);
    }
    #[test]
    fn test_distance_metric_serde_variants() {
        for metric in [
            DistanceMetric::Euclidean,
            DistanceMetric::Cosine,
            DistanceMetric::Manhattan,
        ] {
            let json = serde_json::to_string(&metric).unwrap();
            let restored: DistanceMetric = serde_json::from_str(&json).unwrap();
            assert_eq!(metric, restored);
        }
    }
    #[test]
    fn test_config_new() {
        let cfg = BridgeConfig::new(64, DistanceMetric::Manhattan, 250_000, 5);
        assert_eq!(cfg.dimensions, 64);
        assert_eq!(cfg.distance_metric, DistanceMetric::Manhattan);
        assert_eq!(cfg.max_points, 250_000);
        assert_eq!(cfg.search_k, 5);
    }
    #[test]
    fn test_config_clone_eq() {
        let a = BridgeConfig::default();
        let b = a.clone();
        assert_eq!(a, b);
    }
    #[test]
    fn test_config_debug_format() {
        let cfg = BridgeConfig::default();
        let dbg = format!("{:?}", cfg);
        assert!(dbg.contains("BridgeConfig"));
        assert!(dbg.contains("Euclidean"));
    }
}

View File

@@ -0,0 +1,390 @@
//! Type conversion functions between robotics domain types and flat vector
//! representations suitable for indexing, serialization, or ML inference.
use crate::bridge::{OccupancyGrid, Point3D, PointCloud, Pose, RobotState, SceneGraph};
// Quaternion is used in tests for constructing Pose values.
#[cfg(test)]
use crate::bridge::Quaternion;
use std::fmt;
/// Errors that can occur during type conversions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConversionError {
    /// The input vector length does not match the expected dimensionality.
    LengthMismatch { expected: usize, got: usize },
    /// The input collection was empty when a non-empty one was required.
    EmptyInput,
}
impl fmt::Display for ConversionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ConversionError::LengthMismatch { expected, got } => {
                write!(f, "length mismatch: expected {expected}, got {got}")
            }
            ConversionError::EmptyInput => f.write_str("empty input"),
        }
    }
}
impl std::error::Error for ConversionError {}
/// Convert a [`PointCloud`] to a `Vec` of `[x, y, z]` vectors.
pub fn point_cloud_to_vectors(cloud: &PointCloud) -> Vec<Vec<f32>> {
    let mut out = Vec::with_capacity(cloud.points.len());
    for p in &cloud.points {
        out.push(vec![p.x, p.y, p.z]);
    }
    out
}
/// Convert a [`PointCloud`] to `[x, y, z, intensity]` vectors.
///
/// Returns [`ConversionError::LengthMismatch`] when the intensity array length
/// does not match the number of points.
pub fn point_cloud_to_vectors_with_intensity(
    cloud: &PointCloud,
) -> Result<Vec<Vec<f32>>, ConversionError> {
    let n_points = cloud.points.len();
    let n_intensities = cloud.intensities.len();
    if n_points != n_intensities {
        return Err(ConversionError::LengthMismatch {
            expected: n_points,
            got: n_intensities,
        });
    }
    let mut out = Vec::with_capacity(n_points);
    for (p, &intensity) in cloud.points.iter().zip(cloud.intensities.iter()) {
        out.push(vec![p.x, p.y, p.z, intensity]);
    }
    Ok(out)
}
/// Reconstruct a [`PointCloud`] from `[x, y, z]` vectors.
///
/// Each inner vector **must** have exactly 3 elements; the first malformed
/// vector aborts the conversion with a [`ConversionError::LengthMismatch`].
pub fn vectors_to_point_cloud(
    vectors: &[Vec<f32>],
    timestamp: i64,
) -> Result<PointCloud, ConversionError> {
    if vectors.is_empty() {
        return Err(ConversionError::EmptyInput);
    }
    // Short-circuits on the first vector that is not exactly 3 elements long.
    let points = vectors
        .iter()
        .map(|v| match v.as_slice() {
            &[x, y, z] => Ok(Point3D::new(x, y, z)),
            other => Err(ConversionError::LengthMismatch {
                expected: 3,
                got: other.len(),
            }),
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(PointCloud::new(points, timestamp))
}
/// Flatten a [`RobotState`] into `[px, py, pz, vx, vy, vz, ax, ay, az]`.
pub fn robot_state_to_vector(state: &RobotState) -> Vec<f64> {
    state
        .position
        .iter()
        .chain(state.velocity.iter())
        .chain(state.acceleration.iter())
        .copied()
        .collect()
}
/// Reconstruct a [`RobotState`] from a 9-element vector and a timestamp.
///
/// Returns [`ConversionError::LengthMismatch`] unless `v` has exactly
/// 9 elements (position, velocity, acceleration — 3 each).
pub fn vector_to_robot_state(
    v: &[f64],
    timestamp: i64,
) -> Result<RobotState, ConversionError> {
    // A fixed-size conversion both validates the length and lets us
    // destructure the nine components by name.
    let [px, py, pz, vx, vy, vz, ax, ay, az] =
        <[f64; 9]>::try_from(v).map_err(|_| ConversionError::LengthMismatch {
            expected: 9,
            got: v.len(),
        })?;
    Ok(RobotState {
        position: [px, py, pz],
        velocity: [vx, vy, vz],
        acceleration: [ax, ay, az],
        timestamp_us: timestamp,
    })
}
/// Flatten a [`Pose`] into `[px, py, pz, qx, qy, qz, qw]`.
pub fn pose_to_vector(pose: &Pose) -> Vec<f64> {
    let q = &pose.orientation;
    let mut out = Vec::with_capacity(7);
    out.extend_from_slice(&pose.position);
    out.extend_from_slice(&[q.x, q.y, q.z, q.w]);
    out
}
/// Extract occupied cells (value > 0.5) as `[world_x, world_y, value]` vectors.
pub fn occupancy_grid_to_vectors(grid: &OccupancyGrid) -> Vec<Vec<f32>> {
    // Grid-to-world conversion factors are invariant; hoist them out of the loop.
    let res = grid.resolution as f32;
    let origin_x = grid.origin[0] as f32;
    let origin_y = grid.origin[1] as f32;
    let mut occupied = Vec::new();
    for y in 0..grid.height {
        for x in 0..grid.width {
            let value = grid.get(x, y).unwrap_or(0.0);
            if value > 0.5 {
                occupied.push(vec![
                    origin_x + x as f32 * res,
                    origin_y + y as f32 * res,
                    value,
                ]);
            }
        }
    }
    occupied
}
type NodeFeatures = Vec<Vec<f64>>;
type EdgeList = Vec<(usize, usize, f64)>;
/// Convert a [`SceneGraph`] into node feature vectors and an edge list.
///
/// Each node vector is `[cx, cy, cz, ex, ey, ez, confidence]`.
/// Each edge tuple is `(from_index, to_index, distance)`.
pub fn scene_graph_to_adjacency(scene: &SceneGraph) -> (NodeFeatures, EdgeList) {
    let mut nodes: NodeFeatures = Vec::with_capacity(scene.objects.len());
    for obj in &scene.objects {
        nodes.push(vec![
            obj.center[0],
            obj.center[1],
            obj.center[2],
            obj.extent[0],
            obj.extent[1],
            obj.extent[2],
            obj.confidence as f64,
        ]);
    }
    let mut edges: EdgeList = Vec::with_capacity(scene.edges.len());
    for e in &scene.edges {
        edges.push((e.from, e.to, e.distance));
    }
    (nodes, edges)
}
#[cfg(test)]
mod tests {
    // Coverage: happy-path conversions, error paths (length mismatch, empty
    // input), round-trips, and a 10k-point stress case.
    use super::*;
    use crate::bridge::{OccupancyGrid, SceneEdge, SceneObject};
    #[test]
    fn test_point_cloud_to_vectors_basic() {
        let cloud = PointCloud::new(
            vec![Point3D::new(1.0, 2.0, 3.0), Point3D::new(4.0, 5.0, 6.0)],
            100,
        );
        let vecs = point_cloud_to_vectors(&cloud);
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(vecs[1], vec![4.0, 5.0, 6.0]);
    }
    #[test]
    fn test_point_cloud_to_vectors_empty() {
        let cloud = PointCloud::default();
        let vecs = point_cloud_to_vectors(&cloud);
        assert!(vecs.is_empty());
    }
    #[test]
    fn test_point_cloud_with_intensity_ok() {
        let cloud = PointCloud::new(vec![Point3D::new(1.0, 2.0, 3.0)], 0);
        let vecs = point_cloud_to_vectors_with_intensity(&cloud).unwrap();
        assert_eq!(vecs[0], vec![1.0, 2.0, 3.0, 1.0]);
    }
    #[test]
    fn test_point_cloud_with_intensity_mismatch() {
        let mut cloud = PointCloud::new(vec![Point3D::new(1.0, 2.0, 3.0)], 0);
        cloud.intensities = vec![];
        let err = point_cloud_to_vectors_with_intensity(&cloud).unwrap_err();
        assert_eq!(err, ConversionError::LengthMismatch { expected: 1, got: 0 });
    }
    #[test]
    fn test_vectors_to_point_cloud_ok() {
        let vecs = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let cloud = vectors_to_point_cloud(&vecs, 42).unwrap();
        assert_eq!(cloud.len(), 2);
        assert_eq!(cloud.timestamp(), 42);
        assert_eq!(cloud.points[0].x, 1.0);
    }
    #[test]
    fn test_vectors_to_point_cloud_empty() {
        let vecs: Vec<Vec<f32>> = vec![];
        let err = vectors_to_point_cloud(&vecs, 0).unwrap_err();
        assert_eq!(err, ConversionError::EmptyInput);
    }
    #[test]
    fn test_vectors_to_point_cloud_wrong_dim() {
        let vecs = vec![vec![1.0, 2.0]];
        let err = vectors_to_point_cloud(&vecs, 0).unwrap_err();
        assert_eq!(err, ConversionError::LengthMismatch { expected: 3, got: 2 });
    }
    #[test]
    fn test_point_cloud_roundtrip() {
        let pts = vec![Point3D::new(1.0, 2.0, 3.0), Point3D::new(-1.0, 0.0, 5.5)];
        let original = PointCloud::new(pts, 999);
        let vecs = point_cloud_to_vectors(&original);
        let restored = vectors_to_point_cloud(&vecs, 999).unwrap();
        assert_eq!(restored.len(), original.len());
    }
    #[test]
    fn test_robot_state_to_vector() {
        let state = RobotState {
            position: [1.0, 2.0, 3.0],
            velocity: [4.0, 5.0, 6.0],
            acceleration: [7.0, 8.0, 9.0],
            timestamp_us: 0,
        };
        let v = robot_state_to_vector(&state);
        assert_eq!(v, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }
    #[test]
    fn test_vector_to_robot_state_ok() {
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let state = vector_to_robot_state(&v, 123).unwrap();
        assert_eq!(state.position, [1.0, 2.0, 3.0]);
        assert_eq!(state.velocity, [4.0, 5.0, 6.0]);
        assert_eq!(state.acceleration, [7.0, 8.0, 9.0]);
        assert_eq!(state.timestamp_us, 123);
    }
    #[test]
    fn test_vector_to_robot_state_wrong_len() {
        let v = vec![1.0, 2.0, 3.0];
        let err = vector_to_robot_state(&v, 0).unwrap_err();
        assert_eq!(err, ConversionError::LengthMismatch { expected: 9, got: 3 });
    }
    #[test]
    fn test_robot_state_roundtrip() {
        let original = RobotState {
            position: [10.0, 20.0, 30.0],
            velocity: [-1.0, -2.0, -3.0],
            acceleration: [0.1, 0.2, 0.3],
            timestamp_us: 555,
        };
        let v = robot_state_to_vector(&original);
        let restored = vector_to_robot_state(&v, 555).unwrap();
        assert_eq!(original, restored);
    }
    #[test]
    fn test_pose_to_vector() {
        let pose = Pose {
            position: [1.0, 2.0, 3.0],
            orientation: Quaternion::new(0.1, 0.2, 0.3, 0.9),
            frame_id: "map".into(),
        };
        let v = pose_to_vector(&pose);
        assert_eq!(v.len(), 7);
        assert!((v[0] - 1.0).abs() < f64::EPSILON);
        assert!((v[3] - 0.1).abs() < f64::EPSILON);
        assert!((v[6] - 0.9).abs() < f64::EPSILON);
    }
    #[test]
    fn test_pose_to_vector_identity() {
        let pose = Pose::default();
        let v = pose_to_vector(&pose);
        assert_eq!(v, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]);
    }
    #[test]
    fn test_occupancy_grid_to_vectors_no_occupied() {
        let grid = OccupancyGrid::new(5, 5, 0.1);
        let vecs = occupancy_grid_to_vectors(&grid);
        assert!(vecs.is_empty());
    }
    #[test]
    fn test_occupancy_grid_to_vectors_some_occupied() {
        let mut grid = OccupancyGrid::new(3, 3, 0.5);
        grid.set(1, 2, 0.9);
        grid.set(0, 0, 0.3); // below threshold
        let vecs = occupancy_grid_to_vectors(&grid);
        assert_eq!(vecs.len(), 1);
        let v = &vecs[0];
        assert!((v[2] - 0.9).abs() < f32::EPSILON);
    }
    #[test]
    fn test_occupancy_grid_with_origin() {
        let mut grid = OccupancyGrid::new(2, 2, 1.0);
        grid.origin = [10.0, 20.0, 0.0];
        grid.set(1, 0, 0.8);
        let vecs = occupancy_grid_to_vectors(&grid);
        assert_eq!(vecs.len(), 1);
        assert!((vecs[0][0] - 11.0).abs() < f32::EPSILON); // wx = 10 + 1*1
        assert!((vecs[0][1] - 20.0).abs() < f32::EPSILON); // wy = 20 + 0*1
    }
    #[test]
    fn test_scene_graph_to_adjacency_empty() {
        let scene = SceneGraph::default();
        let (nodes, edges) = scene_graph_to_adjacency(&scene);
        assert!(nodes.is_empty());
        assert!(edges.is_empty());
    }
    #[test]
    fn test_scene_graph_to_adjacency() {
        let o1 = SceneObject::new(0, [1.0, 2.0, 3.0], [0.5, 0.5, 0.5]);
        let o2 = SceneObject::new(1, [4.0, 5.0, 6.0], [1.0, 1.0, 1.0]);
        let edge = SceneEdge {
            from: 0,
            to: 1,
            distance: 5.196,
            relation: "near".into(),
        };
        let scene = SceneGraph::new(vec![o1, o2], vec![edge], 0);
        let (nodes, edges) = scene_graph_to_adjacency(&scene);
        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].len(), 7);
        assert!((nodes[0][0] - 1.0).abs() < f64::EPSILON);
        assert!((nodes[0][6] - 1.0).abs() < f64::EPSILON); // confidence
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].0, 0);
        assert_eq!(edges[0].1, 1);
    }
    // Larger input exercises the conversion at scale; still fast.
    #[test]
    fn test_point_cloud_10k_roundtrip() {
        let points: Vec<Point3D> = (0..10_000)
            .map(|i| {
                let f = i as f32;
                Point3D::new(f * 0.1, f * 0.2, f * 0.3)
            })
            .collect();
        let cloud = PointCloud::new(points, 1_000_000);
        let vecs = point_cloud_to_vectors(&cloud);
        assert_eq!(vecs.len(), 10_000);
        let restored = vectors_to_point_cloud(&vecs, 1_000_000).unwrap();
        assert_eq!(restored.len(), 10_000);
    }
    #[test]
    fn test_conversion_error_display() {
        let e1 = ConversionError::LengthMismatch { expected: 3, got: 5 };
        assert!(format!("{e1}").contains("3") && format!("{e1}").contains("5"));
        assert!(format!("{}", ConversionError::EmptyInput).contains("empty"));
    }
}

View File

@@ -0,0 +1,272 @@
//! Gaussian splatting types and point-cloud-to-Gaussian conversion.
//!
//! Provides a [`GaussianSplat`] representation that maps each point cloud
//! cluster to a 3D Gaussian with position, colour, opacity, scale, and
//! optional temporal trajectory. The serialised format is compatible with
//! the `vwm-viewer` Canvas2D renderer.
use crate::bridge::{Point3D, PointCloud};
use crate::perception::clustering;
use serde::{Deserialize, Serialize};
/// A single 3-D Gaussian suitable for splatting-based rendering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GaussianSplat {
    /// Centre of the Gaussian in world coordinates.
    pub center: [f64; 3],
    /// RGB colour in \[0, 1\].
    pub color: [f32; 3],
    /// Opacity in \[0, 1\].
    pub opacity: f32,
    /// Anisotropic scale along each axis.
    pub scale: [f32; 3],
    /// Number of raw points that contributed to this Gaussian.
    pub point_count: usize,
    /// Semantic label (e.g. `"obstacle"`, `"ground"`). Empty when unlabelled.
    pub label: String,
    /// Temporal trajectory: each entry is a position at a successive timestep.
    /// Empty for static Gaussians.
    pub trajectory: Vec<[f64; 3]>,
}
/// A collection of Gaussians derived from one or more point cloud frames.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianSplatCloud {
    /// The Gaussians making up this cloud.
    pub gaussians: Vec<GaussianSplat>,
    /// Timestamp carried over from the source point cloud (microseconds,
    /// per the `_us` suffix).
    pub timestamp_us: i64,
    /// Coordinate frame identifier copied from the source cloud.
    pub frame_id: String,
}
impl GaussianSplatCloud {
    /// Number of Gaussians.
    pub fn len(&self) -> usize {
        self.gaussians.len()
    }
    /// Returns `true` when the cloud holds no Gaussians.
    pub fn is_empty(&self) -> bool {
        self.gaussians.is_empty()
    }
}
/// Configuration for point-cloud → Gaussian conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianConfig {
    /// Clustering cell size in metres. Smaller = more Gaussians.
    /// Non-positive values disable conversion (an empty cloud is returned).
    pub cell_size: f64,
    /// Minimum number of points to form a Gaussian; smaller clusters are dropped.
    pub min_cluster_size: usize,
    /// Default colour for unlabelled Gaussians `[R, G, B]`.
    pub default_color: [f32; 3],
    /// Base opacity for generated Gaussians.
    pub base_opacity: f32,
}
impl Default for GaussianConfig {
    /// Defaults: half-metre cells, pairs-or-larger clusters, a muted blue
    /// colour, and 70% base opacity.
    fn default() -> Self {
        Self {
            cell_size: 0.5,
            min_cluster_size: 2,
            default_color: [0.3, 0.5, 0.8],
            base_opacity: 0.7,
        }
    }
}
/// Convert a [`PointCloud`] into a [`GaussianSplatCloud`] by clustering nearby
/// points and computing per-cluster statistics.
///
/// An empty cloud or a non-positive `cell_size` yields an empty splat cloud
/// (timestamp and frame id are still carried over).
pub fn gaussians_from_cloud(
    cloud: &PointCloud,
    config: &GaussianConfig,
) -> GaussianSplatCloud {
    let gaussians = if cloud.is_empty() || config.cell_size <= 0.0 {
        Vec::new()
    } else {
        let mut out = Vec::new();
        for cluster in clustering::cluster_point_cloud(cloud, config.cell_size) {
            // Undersized clusters are noise; skip them.
            if cluster.len() >= config.min_cluster_size {
                out.push(cluster_to_gaussian(&cluster, config));
            }
        }
        out
    };
    GaussianSplatCloud {
        gaussians,
        timestamp_us: cloud.timestamp_us,
        frame_id: cloud.frame_id.clone(),
    }
}
/// Collapse one cluster of points into a single Gaussian: mean position as
/// the centre, per-axis standard deviation as the scale, density-scaled opacity.
fn cluster_to_gaussian(points: &[Point3D], config: &GaussianConfig) -> GaussianSplat {
    let count = points.len();
    let n = count as f64;
    // Mean position (the Gaussian centre), accumulated per axis in f64.
    let mut sum = [0.0_f64; 3];
    for p in points {
        sum[0] += p.x as f64;
        sum[1] += p.y as f64;
        sum[2] += p.z as f64;
    }
    let center = [sum[0] / n, sum[1] / n, sum[2] / n];
    // Per-axis variance around the centre; its square root becomes the scale.
    let mut var = [0.0_f64; 3];
    for p in points {
        let dx = p.x as f64 - center[0];
        let dy = p.y as f64 - center[1];
        let dz = p.z as f64 - center[2];
        var[0] += dx * dx;
        var[1] += dy * dy;
        var[2] += dz * dz;
    }
    // Clamp to a small minimum so degenerate (flat) clusters stay renderable.
    let scale = [
        (var[0] / n).sqrt().max(0.01) as f32,
        (var[1] / n).sqrt().max(0.01) as f32,
        (var[2] / n).sqrt().max(0.01) as f32,
    ];
    // Opacity grows with cluster density (saturates at 50 points), floored at 0.1.
    let opacity = (config.base_opacity * (count as f32 / 50.0).min(1.0)).max(0.1);
    GaussianSplat {
        center,
        color: config.default_color,
        opacity,
        scale,
        point_count: count,
        label: String::new(),
        trajectory: Vec::new(),
    }
}
/// Serialise a [`GaussianSplatCloud`] to the JSON format expected by the
/// `vwm-viewer` Canvas2D renderer.
pub fn to_viewer_json(cloud: &GaussianSplatCloud) -> serde_json::Value {
    let mut gs = Vec::with_capacity(cloud.gaussians.len());
    for g in &cloud.gaussians {
        // Static Gaussians are encoded as a single-entry position list.
        let positions: Vec<Vec<f64>> = if g.trajectory.is_empty() {
            vec![g.center.to_vec()]
        } else {
            g.trajectory.iter().map(|p| p.to_vec()).collect()
        };
        gs.push(serde_json::json!({
            "positions": positions,
            "color": g.color,
            "opacity": g.opacity,
            "scale": g.scale,
            "label": g.label,
            "point_count": g.point_count,
        }));
    }
    serde_json::json!({
        "gaussians": gs,
        "timestamp_us": cloud.timestamp_us,
        "frame_id": cloud.frame_id,
        "count": cloud.len(),
    })
}
#[cfg(test)]
mod tests {
    // Coverage: empty/degenerate inputs, clustering behaviour, statistics,
    // viewer JSON shape, and serde round-trips.
    use super::*;
    // Helper: build a PointCloud from raw `[x, y, z]` triples.
    fn make_cloud(pts: &[[f32; 3]], ts: i64) -> PointCloud {
        let points: Vec<Point3D> = pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect();
        PointCloud::new(points, ts)
    }
    #[test]
    fn test_empty_cloud() {
        let cloud = PointCloud::default();
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert!(gs.is_empty());
    }
    #[test]
    fn test_single_cluster() {
        let cloud = make_cloud(
            &[[1.0, 0.0, 0.0], [1.1, 0.0, 0.0], [1.0, 0.1, 0.0]],
            1000,
        );
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 1);
        let g = &gs.gaussians[0];
        assert_eq!(g.point_count, 3);
        assert!(g.center[0] > 0.9 && g.center[0] < 1.2);
    }
    #[test]
    fn test_two_clusters() {
        let cloud = make_cloud(
            &[
                [0.0, 0.0, 0.0], [0.1, 0.0, 0.0],
                [10.0, 10.0, 0.0], [10.1, 10.0, 0.0],
            ],
            2000,
        );
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 2);
    }
    #[test]
    fn test_min_cluster_size_filtering() {
        let cloud = make_cloud(
            &[[0.0, 0.0, 0.0], [10.0, 10.0, 0.0]],
            0,
        );
        let config = GaussianConfig { min_cluster_size: 3, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert!(gs.is_empty());
    }
    #[test]
    fn test_scale_reflects_spread() {
        // Use a larger cell size so all three points end up in one cluster.
        let cloud = make_cloud(
            &[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.15, 0.0, 0.0]],
            0,
        );
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 1);
        let g = &gs.gaussians[0];
        // X-axis spread > Y/Z spread (Y/Z should be clamped minimum 0.01).
        assert!(g.scale[0] > g.scale[1]);
    }
    #[test]
    fn test_viewer_json_format() {
        let cloud = make_cloud(&[[1.0, 2.0, 3.0], [1.1, 2.0, 3.0]], 5000);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        let json = to_viewer_json(&gs);
        assert_eq!(json["count"], 1);
        assert_eq!(json["timestamp_us"], 5000);
        let arr = json["gaussians"].as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert!(arr[0]["positions"].is_array());
        assert!(arr[0]["color"].is_array());
    }
    #[test]
    fn test_serde_roundtrip() {
        let cloud = make_cloud(&[[0.0, 0.0, 0.0], [0.1, 0.1, 0.0]], 0);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        let json = serde_json::to_string(&gs).unwrap();
        let restored: GaussianSplatCloud = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.len(), gs.len());
    }
    // A non-positive cell size must short-circuit to an empty result.
    #[test]
    fn test_zero_cell_size() {
        let cloud = make_cloud(&[[1.0, 0.0, 0.0]], 0);
        let config = GaussianConfig { cell_size: 0.0, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert!(gs.is_empty());
    }
}

View File

@@ -0,0 +1,506 @@
//! Flat spatial index with brute-force nearest-neighbour and radius search.
//!
//! Supports Euclidean, Manhattan, and Cosine distance metrics. Designed as a
//! lightweight, dependency-free baseline that can be swapped for an HNSW
//! implementation when the dataset outgrows brute-force search.
//!
//! ## Optimizations
//!
//! - Points stored in a flat `Vec<f32>` buffer (stride = dimensions) for cache
//! locality and zero per-point heap allocation.
//! - Euclidean kNN uses **squared** distances for comparison, deferring the
//! `sqrt` to only the final `k` results.
//! - Cosine distance computed in a single fused loop (dot, norm_a, norm_b).
use crate::bridge::config::DistanceMetric;
use crate::bridge::PointCloud;
#[cfg(test)]
use crate::bridge::Point3D;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fmt;
/// Entry for the max-heap used by kNN search.
#[derive(PartialEq)]
struct MaxDistEntry {
    index: usize,
    distance: f32,
}
impl Eq for MaxDistEntry {}
impl PartialOrd for MaxDistEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxDistEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Larger distance = higher priority in the max-heap, so the worst
        // candidate sits on top and is the first evicted.
        // `partial_cmp` only fails for NaN; treat NaN as maximally distant.
        match self.distance.partial_cmp(&other.distance) {
            Some(ordering) => ordering,
            None => match (self.distance.is_nan(), other.distance.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, _) => Ordering::Less,
            },
        }
    }
}
/// Errors returned by [`SpatialIndex`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexError {
    /// Query or insertion vector has a different dimensionality than the index.
    DimensionMismatch { expected: usize, got: usize },
    /// The index contains no points.
    EmptyIndex,
}
impl fmt::Display for IndexError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            IndexError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
            IndexError::EmptyIndex => f.write_str("index is empty"),
        }
    }
}
impl std::error::Error for IndexError {}
/// Squared Euclidean distance (avoids sqrt for comparison-only use).
/// Extra elements of the longer slice are ignored.
#[inline]
fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
/// Euclidean (L2) distance; the square root of [`euclidean_distance_sq`].
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    euclidean_distance_sq(a, b).sqrt()
}
/// Manhattan (L1) distance: sum of per-axis absolute differences.
#[inline]
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        total += (x - y).abs();
    }
    total
}
/// Cosine distance in a single fused loop (dot + both norms together).
#[inline]
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < f32::EPSILON {
        return 1.0; // zero-vectors are maximally dissimilar
    }
    // Clamp guards against floating-point drift pushing cos outside [-1, 1].
    (1.0 - (dot / denom).clamp(-1.0, 1.0)).max(0.0)
}
/// A flat spatial index that stores points in a contiguous `f32` buffer.
///
/// Points are stored in a flat `Vec<f32>` with stride equal to `dimensions`,
/// eliminating per-point heap allocations and improving cache locality.
#[derive(Debug, Clone)]
pub struct SpatialIndex {
    /// Stride of the flat buffer; fixed at construction time.
    dimensions: usize,
    /// Distance metric used by all searches on this index.
    metric: DistanceMetric,
    /// Flat buffer: point `i` occupies `data[i*dimensions .. (i+1)*dimensions]`.
    data: Vec<f32>,
}
impl SpatialIndex {
    /// Create a new index with the given dimensionality and Euclidean metric.
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            metric: DistanceMetric::Euclidean,
            data: Vec::new(),
        }
    }

    /// Create a new index with an explicit distance metric.
    pub fn with_metric(dimensions: usize, metric: DistanceMetric) -> Self {
        Self {
            dimensions,
            metric,
            data: Vec::new(),
        }
    }

    /// Insert all points from a [`PointCloud`] into the index.
    ///
    /// Point clouds are inherently 3-D, so the index must have been created
    /// with `dimensions == 3`: appending 3-wide records to a buffer with any
    /// other stride would silently corrupt every later lookup. The invariant
    /// is checked with `debug_assert_eq!` so misuse is caught in development
    /// builds at zero cost to release builds.
    pub fn insert_point_cloud(&mut self, cloud: &PointCloud) {
        debug_assert_eq!(
            self.dimensions, 3,
            "insert_point_cloud requires a 3-dimensional index"
        );
        self.data.reserve(cloud.points.len() * 3);
        for p in &cloud.points {
            self.data.extend_from_slice(&[p.x, p.y, p.z]);
        }
    }

    /// Insert pre-built vectors. Vectors whose length does not match the
    /// index dimensionality are silently skipped.
    pub fn insert_vectors(&mut self, vectors: &[Vec<f32>]) {
        for v in vectors {
            if v.len() == self.dimensions {
                self.data.extend_from_slice(v);
            }
        }
    }

    /// Find the `k` nearest neighbours to `query`.
    ///
    /// Returns `(index, distance)` pairs sorted by ascending distance (fewer
    /// than `k` when the index holds fewer points; empty when `k == 0`).
    /// Uses a bounded max-heap of size `k` for O(n log k) instead of O(n log n).
    ///
    /// For the Euclidean metric, squared distances are used internally and
    /// only the final `k` results are square-rooted.
    ///
    /// # Errors
    ///
    /// * [`IndexError::EmptyIndex`] when the index holds no points.
    /// * [`IndexError::DimensionMismatch`] when `query.len()` differs from
    ///   the index dimensionality.
    pub fn search_nearest(
        &self,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<(usize, f32)>, IndexError> {
        let n = self.len();
        if n == 0 {
            return Err(IndexError::EmptyIndex);
        }
        if query.len() != self.dimensions {
            return Err(IndexError::DimensionMismatch {
                expected: self.dimensions,
                got: query.len(),
            });
        }
        // sqrt is monotonic, so ranking by squared Euclidean distance is
        // equivalent and defers the sqrt to the final k results.
        let use_sq = matches!(self.metric, DistanceMetric::Euclidean);
        let mut heap: BinaryHeap<MaxDistEntry> = BinaryHeap::with_capacity(k + 1);
        let dim = self.dimensions;
        for i in 0..n {
            let p = &self.data[i * dim..(i + 1) * dim];
            let d = if use_sq {
                euclidean_distance_sq(query, p)
            } else {
                self.compute_distance(query, p)
            };
            if heap.len() < k {
                heap.push(MaxDistEntry {
                    index: i,
                    distance: d,
                });
            } else if let Some(top) = heap.peek() {
                // Evict the current worst candidate only when strictly closer.
                if d < top.distance {
                    heap.pop();
                    heap.push(MaxDistEntry {
                        index: i,
                        distance: d,
                    });
                }
            }
        }
        let mut result: Vec<(usize, f32)> = heap
            .into_iter()
            .map(|e| {
                let dist = if use_sq { e.distance.sqrt() } else { e.distance };
                (e.index, dist)
            })
            .collect();
        result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
        Ok(result)
    }

    /// Find all points within `radius` of `center`.
    ///
    /// Returns `(index, distance)` pairs in arbitrary order; an empty index
    /// yields an empty result rather than an error.
    /// For the Euclidean metric, squared distances are compared against
    /// `radius²` internally.
    ///
    /// # Errors
    ///
    /// [`IndexError::DimensionMismatch`] when `center.len()` differs from
    /// the index dimensionality.
    pub fn search_radius(
        &self,
        center: &[f32],
        radius: f32,
    ) -> Result<Vec<(usize, f32)>, IndexError> {
        let n = self.len();
        if n == 0 {
            return Ok(Vec::new());
        }
        if center.len() != self.dimensions {
            return Err(IndexError::DimensionMismatch {
                expected: self.dimensions,
                got: center.len(),
            });
        }
        let use_sq = matches!(self.metric, DistanceMetric::Euclidean);
        let threshold = if use_sq { radius * radius } else { radius };
        let dim = self.dimensions;
        let mut results = Vec::new();
        for i in 0..n {
            let p = &self.data[i * dim..(i + 1) * dim];
            let d = if use_sq {
                euclidean_distance_sq(center, p)
            } else {
                self.compute_distance(center, p)
            };
            if d <= threshold {
                let dist = if use_sq { d.sqrt() } else { d };
                results.push((i, dist));
            }
        }
        Ok(results)
    }

    /// Number of points stored in the index.
    pub fn len(&self) -> usize {
        if self.dimensions == 0 {
            // Guard against division by zero for a degenerate 0-d index.
            0
        } else {
            self.data.len() / self.dimensions
        }
    }

    /// Returns `true` if the index contains no points.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Remove all points from the index (capacity is retained).
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// The dimensionality of this index.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// The distance metric in use.
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Dispatch to the metric-specific distance function.
    #[inline]
    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.metric {
            DistanceMetric::Euclidean => euclidean_distance(a, b),
            DistanceMetric::Manhattan => manhattan_distance(a, b),
            DistanceMetric::Cosine => cosine_distance(a, b),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn make_cloud(pts: &[[f32; 3]]) -> PointCloud {
let points: Vec<Point3D> = pts.iter().map(|p| Point3D::new(p[0], p[1], p[2])).collect();
PointCloud::new(points, 0)
}
#[test]
fn test_new_index_is_empty() {
let idx = SpatialIndex::new(3);
assert!(idx.is_empty());
assert_eq!(idx.len(), 0);
assert_eq!(idx.dimensions(), 3);
}
#[test]
fn test_with_metric() {
let idx = SpatialIndex::with_metric(4, DistanceMetric::Cosine);
assert_eq!(idx.metric(), DistanceMetric::Cosine);
assert_eq!(idx.dimensions(), 4);
}
#[test]
fn test_insert_point_cloud() {
let mut idx = SpatialIndex::new(3);
let cloud = make_cloud(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
idx.insert_point_cloud(&cloud);
assert_eq!(idx.len(), 2);
}
#[test]
fn test_insert_vectors() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
assert_eq!(idx.len(), 2);
}
#[test]
fn test_insert_vectors_skips_wrong_dim() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![1.0, 2.0], vec![4.0, 5.0, 6.0]]);
assert_eq!(idx.len(), 1); // only the 3-d vector was inserted
}
#[test]
fn test_search_nearest_basic() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[
vec![0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0],
vec![10.0, 0.0, 0.0],
]);
let results = idx.search_nearest(&[0.5, 0.0, 0.0], 2).unwrap();
assert_eq!(results.len(), 2);
// closest should be index 0 or 1
assert!(results[0].0 == 0 || results[0].0 == 1);
}
#[test]
fn test_search_nearest_returns_sorted() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[
vec![10.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0],
vec![5.0, 0.0, 0.0],
]);
let results = idx.search_nearest(&[0.0, 0.0, 0.0], 3).unwrap();
assert!(results[0].1 <= results[1].1);
assert!(results[1].1 <= results[2].1);
}
#[test]
fn test_search_nearest_k_larger_than_len() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![1.0, 2.0, 3.0]]);
let results = idx.search_nearest(&[0.0, 0.0, 0.0], 10).unwrap();
assert_eq!(results.len(), 1);
}
#[test]
fn test_search_nearest_empty_index() {
let idx = SpatialIndex::new(3);
let err = idx.search_nearest(&[0.0, 0.0, 0.0], 1).unwrap_err();
assert_eq!(err, IndexError::EmptyIndex);
}
#[test]
fn test_search_nearest_dim_mismatch() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![1.0, 2.0, 3.0]]);
let err = idx.search_nearest(&[0.0, 0.0], 1).unwrap_err();
assert_eq!(err, IndexError::DimensionMismatch { expected: 3, got: 2 });
}
#[test]
fn test_search_radius_basic() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[
vec![0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0],
vec![10.0, 0.0, 0.0],
]);
let results = idx.search_radius(&[0.0, 0.0, 0.0], 1.5).unwrap();
assert_eq!(results.len(), 2); // indices 0 and 1
}
#[test]
fn test_search_radius_empty_index() {
let idx = SpatialIndex::new(3);
let results = idx.search_radius(&[0.0, 0.0, 0.0], 1.0).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_radius_no_results() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![100.0, 100.0, 100.0]]);
let results = idx.search_radius(&[0.0, 0.0, 0.0], 1.0).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_clear() {
let mut idx = SpatialIndex::new(3);
idx.insert_vectors(&[vec![1.0, 2.0, 3.0]]);
assert!(!idx.is_empty());
idx.clear();
assert!(idx.is_empty());
assert_eq!(idx.len(), 0);
}
#[test]
fn test_euclidean_distance() {
let d = euclidean_distance(&[0.0, 0.0, 0.0], &[3.0, 4.0, 0.0]);
assert!((d - 5.0).abs() < 1e-5);
}
#[test]
fn test_manhattan_distance() {
let d = manhattan_distance(&[0.0, 0.0, 0.0], &[3.0, 4.0, 1.0]);
assert!((d - 8.0).abs() < 1e-5);
}
#[test]
fn test_cosine_distance_identical() {
    // Identical directions have zero cosine distance.
    let dist = cosine_distance(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
    assert!(dist.abs() < 1e-5);
}
#[test]
fn test_cosine_distance_orthogonal() {
    // Perpendicular directions have cosine distance 1.
    let dist = cosine_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
    assert!((dist - 1.0).abs() < 1e-5);
}
#[test]
fn test_cosine_distance_zero_vector() {
    // A zero vector has no direction; the implementation reports it as
    // maximally dissimilar (distance 1) instead of dividing by zero.
    let dist = cosine_distance(&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]);
    assert!((dist - 1.0).abs() < 1e-5);
}
#[test]
fn test_manhattan_metric_search() {
    // Nearest-neighbour search must honour the configured L1 metric.
    let mut index = SpatialIndex::with_metric(3, DistanceMetric::Manhattan);
    index.insert_vectors(&[
        vec![0.0, 0.0, 0.0],
        vec![1.0, 1.0, 1.0],
        vec![10.0, 10.0, 10.0],
    ]);
    let hits = index.search_nearest(&[0.0, 0.0, 0.0], 1).unwrap();
    let (best_idx, best_dist) = hits[0];
    assert_eq!(best_idx, 0);
    assert!(best_dist.abs() < 1e-5);
}
#[test]
fn test_stress_10k_points() {
    // 10k inserts: len() matches, and the origin's nearest point is index 0.
    let mut index = SpatialIndex::new(3);
    let points: Vec<Vec<f32>> = (0..10_000)
        .map(|i| {
            let f = i as f32;
            vec![f * 0.01, f * 0.02, f * 0.03]
        })
        .collect();
    index.insert_vectors(&points);
    assert_eq!(index.len(), 10_000);
    let hits = index.search_nearest(&[0.0, 0.0, 0.0], 5).unwrap();
    assert_eq!(hits.len(), 5);
    assert_eq!(hits[0].0, 0);
}
#[test]
fn test_index_error_display() {
    // Display output must carry the relevant details for diagnostics.
    let mismatch = IndexError::DimensionMismatch { expected: 3, got: 5 };
    assert!(mismatch.to_string().contains("3"));
    assert!(IndexError::EmptyIndex.to_string().contains("empty"));
}
}

View File

@@ -0,0 +1,383 @@
//! Core robotics types, converters, spatial indexing, and perception pipeline.
//!
//! This module provides the foundational types that all other robotics modules
//! build upon: point clouds, robot state, scene graphs, poses, and trajectories.
pub mod config;
pub mod converters;
pub mod gaussian;
pub mod indexing;
pub mod pipeline;
pub mod search;
use serde::{Deserialize, Serialize};
// Re-exports
pub use config::{BridgeConfig, DistanceMetric};
pub use converters::ConversionError;
pub use gaussian::{GaussianConfig, GaussianSplat, GaussianSplatCloud};
pub use indexing::{IndexError, SpatialIndex};
pub use pipeline::{PerceptionResult, PipelineConfig, PipelineStats};
pub use search::{AlertSeverity, Neighbor, ObstacleAlert, SearchResult};
// ---------------------------------------------------------------------------
// Core types
// ---------------------------------------------------------------------------
/// 3D point used in point clouds and spatial operations.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Point3D {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

impl Point3D {
    /// Construct a point from its three coordinates.
    pub fn new(x: f32, y: f32, z: f32) -> Self {
        Self { x, y, z }
    }

    /// Euclidean (L2) distance between `self` and `other`.
    pub fn distance_to(&self, other: &Point3D) -> f32 {
        let diff = [self.x - other.x, self.y - other.y, self.z - other.z];
        diff.iter().map(|d| d * d).sum::<f32>().sqrt()
    }

    /// Widen the coordinates into an `[f64; 3]` array.
    pub fn as_f64_array(&self) -> [f64; 3] {
        [f64::from(self.x), f64::from(self.y), f64::from(self.z)]
    }

    /// Narrow an `[f64; 3]` array back into a point (may lose precision).
    pub fn from_f64_array(arr: &[f64; 3]) -> Self {
        Self::new(arr[0] as f32, arr[1] as f32, arr[2] as f32)
    }

    /// Copy the coordinates into a freshly allocated `Vec<f32>`.
    pub fn to_vec(&self) -> Vec<f32> {
        vec![self.x, self.y, self.z]
    }
}
/// A unit quaternion representing a 3-D rotation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Quaternion {
    pub x: f64,
    pub y: f64,
    pub z: f64,
    pub w: f64,
}

impl Default for Quaternion {
    /// The identity rotation (w = 1, vector part zero).
    fn default() -> Self {
        Self::new(0.0, 0.0, 0.0, 1.0)
    }
}

impl Quaternion {
    /// Identity rotation (no-op orientation).
    pub fn identity() -> Self {
        Self::new(0.0, 0.0, 0.0, 1.0)
    }

    /// Construct from raw components in `(x, y, z, w)` order.
    pub fn new(x: f64, y: f64, z: f64, w: f64) -> Self {
        Self { x, y, z, w }
    }
}
/// A 6-DOF pose: position + orientation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Pose {
    pub position: [f64; 3],
    pub orientation: Quaternion,
    pub frame_id: String,
}

impl Default for Pose {
    /// Origin position, identity orientation, and an unset frame id.
    fn default() -> Self {
        Self {
            position: [0.0, 0.0, 0.0],
            orientation: Quaternion::default(),
            frame_id: String::default(),
        }
    }
}
/// A collection of 3D points from a sensor (LiDAR, depth camera, etc.).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PointCloud {
    pub points: Vec<Point3D>,
    pub intensities: Vec<f32>,
    pub normals: Option<Vec<Point3D>>,
    pub timestamp_us: i64,
    pub frame_id: String,
}

impl PointCloud {
    /// Build a cloud from raw points; every intensity defaults to 1.0,
    /// normals are absent, and the frame id is empty.
    pub fn new(points: Vec<Point3D>, timestamp: i64) -> Self {
        let count = points.len();
        Self {
            intensities: vec![1.0; count],
            points,
            normals: None,
            timestamp_us: timestamp,
            frame_id: String::new(),
        }
    }

    /// Number of points in the cloud.
    pub fn len(&self) -> usize {
        self.points.len()
    }

    /// `true` when the cloud holds no points.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Capture timestamp in microseconds.
    pub fn timestamp(&self) -> i64 {
        self.timestamp_us
    }
}
/// Robot state: position, velocity, acceleration, timestamp.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RobotState {
    pub position: [f64; 3],
    pub velocity: [f64; 3],
    pub acceleration: [f64; 3],
    pub timestamp_us: i64,
}

impl Default for RobotState {
    /// All kinematic quantities zeroed, timestamp at epoch.
    fn default() -> Self {
        let zero = [0.0f64; 3];
        Self {
            position: zero,
            velocity: zero,
            acceleration: zero,
            timestamp_us: 0,
        }
    }
}
/// A synchronised bundle of sensor observations captured at one instant.
///
/// All fields are optional except the timestamp; the perception pipeline
/// degrades gracefully when any of them is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SensorFrame {
    /// Point cloud for this tick, if a range sensor produced one.
    pub cloud: Option<PointCloud>,
    /// Proprioceptive robot state, if available.
    pub state: Option<RobotState>,
    /// Localised pose, if available.
    pub pose: Option<Pose>,
    /// Capture time of the bundle, in microseconds.
    pub timestamp_us: i64,
}
/// A 2-D occupancy grid map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OccupancyGrid {
    pub width: usize,
    pub height: usize,
    pub resolution: f64,
    pub data: Vec<f32>,
    pub origin: [f64; 3],
}

impl OccupancyGrid {
    /// Allocate a zeroed grid; panics if `width * height` overflows `usize`.
    pub fn new(width: usize, height: usize, resolution: f64) -> Self {
        let cells = width.checked_mul(height).expect("grid size overflow");
        Self {
            width,
            height,
            resolution,
            data: vec![0.0; cells],
            origin: [0.0; 3],
        }
    }

    /// Row-major flat index for `(x, y)`, or `None` when out of bounds.
    fn offset(&self, x: usize, y: usize) -> Option<usize> {
        (x < self.width && y < self.height).then(|| y * self.width + x)
    }

    /// Get the occupancy value at `(x, y)`, or `None` if out of bounds.
    pub fn get(&self, x: usize, y: usize) -> Option<f32> {
        self.offset(x, y).map(|i| self.data[i])
    }

    /// Set the occupancy value at `(x, y)`. Out-of-bounds writes are ignored.
    pub fn set(&mut self, x: usize, y: usize, value: f32) {
        if let Some(i) = self.offset(x, y) {
            self.data[i] = value;
        }
    }
}
/// An object detected in a scene with bounding information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SceneObject {
    pub id: usize,
    pub center: [f64; 3],
    pub extent: [f64; 3],
    pub confidence: f32,
    pub label: String,
    pub velocity: Option<[f64; 3]>,
}

impl SceneObject {
    /// New object with full confidence, no label, and no velocity estimate.
    pub fn new(id: usize, center: [f64; 3], extent: [f64; 3]) -> Self {
        Self {
            id,
            center,
            extent,
            confidence: 1.0,
            label: String::default(),
            velocity: None,
        }
    }
}
/// An edge in a scene graph connecting two objects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SceneEdge {
    /// Id of the source object.
    pub from: usize,
    /// Id of the destination object.
    pub to: usize,
    /// Distance between the two objects — presumably centre-to-centre;
    /// confirm against whichever component builds the graph.
    pub distance: f64,
    /// Free-form relation label.
    pub relation: String,
}
/// A scene graph representing spatial relationships between objects.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SceneGraph {
    pub objects: Vec<SceneObject>,
    pub edges: Vec<SceneEdge>,
    pub timestamp: i64,
}

impl SceneGraph {
    /// Assemble a graph from pre-built objects and edges.
    pub fn new(objects: Vec<SceneObject>, edges: Vec<SceneEdge>, timestamp: i64) -> Self {
        Self {
            objects,
            edges,
            timestamp,
        }
    }
}
/// A predicted trajectory consisting of waypoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trajectory {
    pub waypoints: Vec<[f64; 3]>,
    pub timestamps: Vec<i64>,
    pub confidence: f64,
}

impl Trajectory {
    /// Assemble a trajectory from parallel waypoint/timestamp arrays.
    pub fn new(waypoints: Vec<[f64; 3]>, timestamps: Vec<i64>, confidence: f64) -> Self {
        Self {
            waypoints,
            timestamps,
            confidence,
        }
    }

    /// Number of waypoints.
    pub fn len(&self) -> usize {
        self.waypoints.len()
    }

    /// `true` when the trajectory has no waypoints.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// An obstacle detected by the perception pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Obstacle {
    /// Monotonically assigned identifier, unique within a pipeline run.
    pub id: u64,
    /// World position of the obstacle.
    pub position: [f64; 3],
    /// Distance from the robot at detection time.
    pub distance: f64,
    /// Approximate radius of the obstacle.
    pub radius: f64,
    /// Classification label; empty when unknown.
    pub label: String,
    /// Detection confidence.
    pub confidence: f32,
}
/// Bridge error type.
#[derive(Debug, thiserror::Error)]
pub enum BridgeError {
    /// Input data failed validation.
    #[error("Invalid data: {0}")]
    InvalidData(String),
    /// A type/format conversion failed.
    #[error("Conversion error: {0}")]
    ConversionError(String),
    /// JSON (de)serialization failed; converted automatically via `From`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
/// Convenience alias used throughout the bridge.
pub type Result<T> = std::result::Result<T, BridgeError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_point3d_distance() {
        // 3-4-5 right triangle.
        let origin = Point3D::new(0.0, 0.0, 0.0);
        let corner = Point3D::new(3.0, 4.0, 0.0);
        assert!((origin.distance_to(&corner) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_point_cloud() {
        let cloud = PointCloud::new(vec![Point3D::new(1.0, 2.0, 3.0)], 100);
        assert_eq!(cloud.len(), 1);
        assert_eq!(cloud.timestamp(), 100);
    }

    #[test]
    fn test_robot_state_default() {
        let state = RobotState::default();
        assert_eq!(state.position, [0.0, 0.0, 0.0]);
        assert_eq!(state.velocity, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scene_graph() {
        let objects = vec![SceneObject::new(0, [1.0, 2.0, 3.0], [0.5, 0.5, 0.5])];
        let graph = SceneGraph::new(objects, Vec::new(), 0);
        assert_eq!(graph.objects.len(), 1);
    }

    #[test]
    fn test_quaternion_identity() {
        let rot = Quaternion::identity();
        assert_eq!(rot.w, 1.0);
        assert_eq!(rot.x, 0.0);
    }

    #[test]
    fn test_pose_default() {
        let pose = Pose::default();
        assert_eq!(pose.position, [0.0, 0.0, 0.0]);
        assert_eq!(pose.orientation.w, 1.0);
    }

    #[test]
    fn test_sensor_frame_default() {
        let frame = SensorFrame::default();
        assert!(frame.cloud.is_none());
        assert!(frame.state.is_none());
        assert!(frame.pose.is_none());
    }

    #[test]
    fn test_occupancy_grid() {
        let mut grid = OccupancyGrid::new(10, 10, 0.05);
        grid.set(3, 4, 0.8);
        assert!((grid.get(3, 4).unwrap() - 0.8).abs() < f32::EPSILON);
        // Both coordinates are one past the last valid cell.
        assert!(grid.get(10, 10).is_none());
    }

    #[test]
    fn test_trajectory() {
        let traj = Trajectory::new(
            vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            vec![100, 200],
            0.95,
        );
        assert_eq!(traj.len(), 2);
        assert!(!traj.is_empty());
    }

    #[test]
    fn test_serde_roundtrip() {
        let before = SceneObject::new(0, [1.0, 2.0, 3.0], [0.5, 0.5, 0.5]);
        let encoded = serde_json::to_string(&before).unwrap();
        let after: SceneObject = serde_json::from_str(&encoded).unwrap();
        assert_eq!(before.id, after.id);
    }
}

View File

@@ -0,0 +1,392 @@
//! Lightweight perception pipeline that ingests [`SensorFrame`]s, maintains a
//! [`SpatialIndex`], detects obstacles, and predicts linear trajectories.
use crate::bridge::indexing::SpatialIndex;
use crate::bridge::{Obstacle, PointCloud, SensorFrame, Trajectory};
use serde::{Deserialize, Serialize};
use std::time::Instant;
/// Configuration for the perception pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
    /// Points within this radius (metres) of the robot are classified as obstacles.
    pub obstacle_radius: f64,
    /// Whether to predict linear trajectories from consecutive frames.
    pub track_trajectories: bool,
}

impl Default for PipelineConfig {
    /// 2 m obstacle radius with trajectory tracking enabled.
    fn default() -> Self {
        Self {
            track_trajectories: true,
            obstacle_radius: 2.0,
        }
    }
}
/// Output of a single pipeline frame.
///
/// Produced by [`PerceptionPipeline::process_frame`].
#[derive(Debug, Clone)]
pub struct PerceptionResult {
    /// Detected obstacles in the current frame.
    pub obstacles: Vec<Obstacle>,
    /// Linear trajectory prediction (if enabled and enough history exists).
    pub trajectory_prediction: Option<Trajectory>,
    /// Wall-clock latency of this frame in microseconds.
    pub frame_latency_us: u64,
}
/// Cumulative pipeline statistics.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Number of frames processed so far.
    pub frames_processed: u64,
    /// Running average latency in microseconds.
    pub avg_latency_us: f64,
    /// Cumulative count of obstacle detections across all frames. Note this
    /// is NOT a distinct-object count: the same physical object detected in
    /// N frames contributes N.
    pub objects_tracked: u64,
}
/// A stateful perception pipeline.
///
/// Call [`process_frame`](Self::process_frame) once per sensor tick.
pub struct PerceptionPipeline {
    config: PipelineConfig,
    index: SpatialIndex,
    stats: PipelineStats,
    /// Positions from previous frames for trajectory prediction (capped at 1000).
    position_history: Vec<[f64; 3]>,
    last_timestamp: i64,
    obstacle_counter: u64,
}
impl PerceptionPipeline {
    /// Create a pipeline with the given configuration.
    pub fn new(config: PipelineConfig) -> Self {
        Self {
            config,
            index: SpatialIndex::new(3),
            stats: PipelineStats::default(),
            position_history: Vec::new(),
            last_timestamp: 0,
            obstacle_counter: 0,
        }
    }
    /// Process a single [`SensorFrame`] and return perception results.
    pub fn process_frame(&mut self, frame: &SensorFrame) -> PerceptionResult {
        let start = Instant::now();
        // -- 1. Rebuild spatial index from the point cloud -------------------
        // The index mirrors only the latest cloud; stale points from earlier
        // frames must not linger.
        self.index.clear();
        if let Some(ref cloud) = frame.cloud {
            self.index.insert_point_cloud(cloud);
        }
        // -- 2. Determine robot position (prefer state, fallback to pose) ----
        let robot_pos: Option<[f64; 3]> = frame
            .state
            .as_ref()
            .map(|s| s.position)
            .or_else(|| frame.pose.as_ref().map(|p| p.position));
        // -- 3. Detect obstacles ---------------------------------------------
        let obstacles = self.detect_obstacles(robot_pos, frame.cloud.as_ref());
        // -- 4. Trajectory prediction ----------------------------------------
        let trajectory_prediction = if self.config.track_trajectories {
            if let Some(pos) = robot_pos {
                // Cap history to prevent unbounded memory growth.
                if self.position_history.len() >= 1000 {
                    self.position_history.drain(..500);
                }
                self.position_history.push(pos);
                self.predict_trajectory(frame.timestamp_us)
            } else {
                None
            }
        } else {
            None
        };
        // -- 5. Update stats -------------------------------------------------
        let elapsed_us = start.elapsed().as_micros() as u64;
        self.stats.frames_processed += 1;
        self.stats.objects_tracked += obstacles.len() as u64;
        // Incremental running mean: avg_n = avg_{n-1} * (n-1)/n + x_n / n.
        let n = self.stats.frames_processed as f64;
        self.stats.avg_latency_us =
            self.stats.avg_latency_us * ((n - 1.0) / n) + elapsed_us as f64 / n;
        self.last_timestamp = frame.timestamp_us;
        PerceptionResult {
            obstacles,
            trajectory_prediction,
            frame_latency_us: elapsed_us,
        }
    }
    /// Return cumulative statistics.
    pub fn stats(&self) -> &PipelineStats {
        &self.stats
    }
    /// Return a reference to the current spatial index.
    pub fn index(&self) -> &SpatialIndex {
        &self.index
    }
    /// Classify every cloud point within `obstacle_radius` of the robot as an
    /// obstacle. Returns an empty list when the position or cloud is missing,
    /// or when the radius search fails.
    fn detect_obstacles(
        &mut self,
        robot_pos: Option<[f64; 3]>,
        cloud: Option<&PointCloud>,
    ) -> Vec<Obstacle> {
        let robot_pos = match robot_pos {
            Some(p) => p,
            None => return Vec::new(),
        };
        let cloud = match cloud {
            Some(c) if !c.is_empty() => c,
            _ => return Vec::new(),
        };
        let query = [
            robot_pos[0] as f32,
            robot_pos[1] as f32,
            robot_pos[2] as f32,
        ];
        let radius = self.config.obstacle_radius as f32;
        let neighbours = match self.index.search_radius(&query, radius) {
            Ok(n) => n,
            Err(_) => return Vec::new(),
        };
        neighbours
            .into_iter()
            // Checked lookup instead of `cloud.points[idx]`: the index is
            // rebuilt from this very cloud each frame so `idx` should always
            // be in bounds, but an inconsistency must degrade gracefully
            // rather than panic the whole pipeline.
            .filter_map(|(idx, dist)| {
                let pt = cloud.points.get(idx)?;
                self.obstacle_counter += 1;
                Some(Obstacle {
                    id: self.obstacle_counter,
                    position: [pt.x as f64, pt.y as f64, pt.z as f64],
                    distance: dist as f64,
                    radius: 0.1, // point-level radius
                    label: String::new(),
                    confidence: 1.0,
                })
            })
            .collect()
    }
    /// Constant-velocity extrapolation from the last two recorded positions.
    ///
    /// NOTE(review): the "velocity" is displacement per *frame*, while the
    /// emitted timestamps assume a fixed 100 ms step — if frames do not
    /// arrive at 10 Hz the waypoint spacing and timestamps disagree. Confirm
    /// the intended frame rate before relying on the timestamps.
    fn predict_trajectory(&self, current_ts: i64) -> Option<Trajectory> {
        if self.position_history.len() < 2 {
            return None;
        }
        let n = self.position_history.len();
        let prev = &self.position_history[n - 2];
        let curr = &self.position_history[n - 1];
        let vel = [
            curr[0] - prev[0],
            curr[1] - prev[1],
            curr[2] - prev[2],
        ];
        // Predict 5 steps into the future with constant velocity.
        let steps = 5;
        let dt_us: i64 = 100_000; // 100 ms per step
        let mut waypoints = Vec::with_capacity(steps);
        let mut timestamps = Vec::with_capacity(steps);
        for i in 1..=steps {
            let t = i as f64;
            waypoints.push([
                curr[0] + vel[0] * t,
                curr[1] + vel[1] * t,
                curr[2] + vel[2] * t,
            ]);
            timestamps.push(current_ts + dt_us * i as i64);
        }
        Some(Trajectory::new(waypoints, timestamps, 0.8))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::{Point3D, PointCloud, RobotState, SensorFrame};

    /// Build a frame carrying `points` plus a robot state at `position`.
    fn make_frame(
        points: Vec<Point3D>,
        position: [f64; 3],
        ts: i64,
    ) -> SensorFrame {
        let state = RobotState {
            position,
            velocity: [0.0; 3],
            acceleration: [0.0; 3],
            timestamp_us: ts,
        };
        SensorFrame {
            cloud: Some(PointCloud::new(points, ts)),
            state: Some(state),
            pose: None,
            timestamp_us: ts,
        }
    }

    #[test]
    fn test_empty_frame() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig::default());
        let outcome = pipeline.process_frame(&SensorFrame::default());
        assert!(outcome.obstacles.is_empty());
        assert!(outcome.trajectory_prediction.is_none());
        assert_eq!(pipeline.stats().frames_processed, 1);
    }

    #[test]
    fn test_obstacle_detection() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            obstacle_radius: 5.0,
            ..Default::default()
        });
        // Robot at origin; one point in range, one far outside it.
        let frame = make_frame(
            vec![
                Point3D::new(1.0, 0.0, 0.0),
                Point3D::new(100.0, 0.0, 0.0), // too far
            ],
            [0.0, 0.0, 0.0],
            1000,
        );
        let outcome = pipeline.process_frame(&frame);
        assert_eq!(outcome.obstacles.len(), 1);
        assert!((outcome.obstacles[0].distance - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_no_obstacles_when_far() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            obstacle_radius: 0.5,
            ..Default::default()
        });
        let frame = make_frame(
            vec![Point3D::new(10.0, 10.0, 10.0)],
            [0.0, 0.0, 0.0],
            1000,
        );
        assert!(pipeline.process_frame(&frame).obstacles.is_empty());
    }

    #[test]
    fn test_trajectory_prediction() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            track_trajectories: true,
            obstacle_radius: 0.1,
            ..Default::default()
        });
        // One frame of history is not enough for a velocity estimate.
        let first = pipeline.process_frame(&make_frame(vec![], [0.0, 0.0, 0.0], 1000));
        assert!(first.trajectory_prediction.is_none());
        // Second frame at (1, 0, 0) implies velocity (1, 0, 0) per frame.
        let second = pipeline.process_frame(&make_frame(vec![], [1.0, 0.0, 0.0], 2000));
        let traj = second.trajectory_prediction.unwrap();
        assert_eq!(traj.len(), 5);
        // First extrapolated waypoint lands at ~(2, 0, 0).
        assert!((traj.waypoints[0][0] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_disabled() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            track_trajectories: false,
            ..Default::default()
        });
        pipeline.process_frame(&make_frame(vec![], [0.0, 0.0, 0.0], 0));
        let outcome = pipeline.process_frame(&make_frame(vec![], [1.0, 0.0, 0.0], 1000));
        assert!(outcome.trajectory_prediction.is_none());
    }

    #[test]
    fn test_stats_accumulate() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig::default());
        for _ in 0..5 {
            pipeline.process_frame(&SensorFrame::default());
        }
        assert_eq!(pipeline.stats().frames_processed, 5);
    }

    #[test]
    fn test_obstacle_ids_increment() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            obstacle_radius: 100.0,
            ..Default::default()
        });
        let frame = make_frame(
            vec![Point3D::new(1.0, 0.0, 0.0), Point3D::new(2.0, 0.0, 0.0)],
            [0.0, 0.0, 0.0],
            0,
        );
        let outcome = pipeline.process_frame(&frame);
        assert_eq!(outcome.obstacles.len(), 2);
        // IDs should be monotonically increasing.
        assert!(outcome.obstacles[0].id < outcome.obstacles[1].id);
    }

    #[test]
    fn test_pipeline_config_serde() {
        let cfg = PipelineConfig {
            obstacle_radius: 3.5,
            track_trajectories: false,
        };
        let encoded = serde_json::to_string(&cfg).unwrap();
        let decoded: PipelineConfig = serde_json::from_str(&encoded).unwrap();
        assert!((decoded.obstacle_radius - 3.5).abs() < f64::EPSILON);
        assert!(!decoded.track_trajectories);
    }

    #[test]
    fn test_index_is_rebuilt_per_frame() {
        let mut pipeline = PerceptionPipeline::new(PipelineConfig {
            obstacle_radius: 100.0,
            ..Default::default()
        });
        pipeline.process_frame(&make_frame(
            vec![Point3D::new(1.0, 0.0, 0.0)],
            [0.0, 0.0, 0.0],
            0,
        ));
        assert_eq!(pipeline.index().len(), 1);
        pipeline.process_frame(&make_frame(
            vec![
                Point3D::new(1.0, 0.0, 0.0),
                Point3D::new(2.0, 0.0, 0.0),
                Point3D::new(3.0, 0.0, 0.0),
            ],
            [0.0, 0.0, 0.0],
            1000,
        ));
        // Index should reflect only the latest frame's cloud.
        assert_eq!(pipeline.index().len(), 3);
    }
}

View File

@@ -0,0 +1,174 @@
//! Search result types used by the bridge and perception layers.
use serde::{Deserialize, Serialize};
/// Severity level for obstacle proximity alerts.
///
/// Serialized by variant name via the serde derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AlertSeverity {
    /// Immediate collision risk.
    Critical,
    /// Object is approaching but not imminent.
    Warning,
    /// Informational -- object detected at moderate range.
    Info,
}
/// A single nearest-neighbour result.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Neighbor {
    /// Unique identifier of the indexed point.
    pub id: u64,
    /// Distance from the query point to this neighbour, in whatever metric
    /// the producing index was configured with.
    pub distance: f32,
    /// 3-D position of the neighbour.
    pub position: [f32; 3],
}
/// Result of a spatial search query.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the query that produced this result.
    pub query_id: u64,
    /// Nearest neighbours, sorted by ascending distance.
    pub neighbors: Vec<Neighbor>,
    /// Wall-clock latency of the search in microseconds, as measured by the
    /// component that ran the query.
    pub latency_us: u64,
}
/// An alert generated when an obstacle is within a safety threshold.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ObstacleAlert {
    /// Identifier of the obstacle that triggered the alert.
    pub obstacle_id: u64,
    /// Distance to the obstacle in metres.
    pub distance: f32,
    /// Unit direction vector pointing from the robot towards the obstacle.
    /// Normalisation is the producer's responsibility; it is not enforced here.
    pub direction: [f32; 3],
    /// Severity of this alert.
    pub severity: AlertSeverity,
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neighbor_serde_roundtrip() {
        let original = Neighbor {
            id: 42,
            distance: 1.5,
            position: [1.0, 2.0, 3.0],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Neighbor = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_search_result_serde_roundtrip() {
        let near = Neighbor {
            id: 1,
            distance: 0.5,
            position: [0.0, 0.0, 0.0],
        };
        let far = Neighbor {
            id: 2,
            distance: 1.2,
            position: [1.0, 1.0, 1.0],
        };
        let original = SearchResult {
            query_id: 7,
            neighbors: vec![near, far],
            latency_us: 150,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SearchResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_obstacle_alert_serde_roundtrip() {
        let original = ObstacleAlert {
            obstacle_id: 99,
            distance: 0.3,
            direction: [1.0, 0.0, 0.0],
            severity: AlertSeverity::Critical,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ObstacleAlert = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_alert_severity_all_variants() {
        // Every variant must survive a JSON round-trip unchanged.
        let levels = [
            AlertSeverity::Critical,
            AlertSeverity::Warning,
            AlertSeverity::Info,
        ];
        for level in levels {
            let encoded = serde_json::to_string(&level).unwrap();
            let decoded: AlertSeverity = serde_json::from_str(&encoded).unwrap();
            assert_eq!(level, decoded);
        }
    }

    #[test]
    fn test_search_result_empty_neighbors() {
        let original = SearchResult {
            query_id: 0,
            neighbors: Vec::new(),
            latency_us: 0,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SearchResult = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.neighbors.is_empty());
    }

    #[test]
    fn test_neighbor_debug_format() {
        let sample = Neighbor {
            id: 1,
            distance: 0.0,
            position: [0.0, 0.0, 0.0],
        };
        assert!(format!("{:?}", sample).contains("Neighbor"));
    }

    #[test]
    fn test_obstacle_alert_direction_preserved() {
        let original = ObstacleAlert {
            obstacle_id: 1,
            distance: 5.0,
            direction: [0.577, 0.577, 0.577],
            severity: AlertSeverity::Warning,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ObstacleAlert = serde_json::from_str(&encoded).unwrap();
        for (a, b) in original.direction.iter().zip(decoded.direction.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_search_result_json_structure() {
        let result = SearchResult {
            query_id: 10,
            neighbors: vec![Neighbor {
                id: 5,
                distance: 2.5,
                position: [3.0, 4.0, 5.0],
            }],
            latency_us: 42,
        };
        let pretty = serde_json::to_string_pretty(&result).unwrap();
        for key in ["query_id", "neighbors", "latency_us"] {
            assert!(pretty.contains(key));
        }
    }
}

View File

@@ -0,0 +1,433 @@
//! Composable behavior trees for robot task execution.
//!
//! Provides a declarative way to build complex robot behaviors from simple
//! building blocks: actions, conditions, sequences, selectors, decorators,
//! and parallel nodes.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Status & decorator types
// ---------------------------------------------------------------------------
/// Result of ticking a behavior tree node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BehaviorStatus {
    /// The node completed successfully.
    Success,
    /// The node failed or its condition did not hold.
    Failure,
    /// The node has not finished and should be ticked again.
    Running,
}
/// Decorator modifiers that wrap a single child node.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DecoratorType {
    /// Inverts Success <-> Failure; Running stays Running.
    Inverter,
    /// Repeats the child a fixed number of times, stopping early (and
    /// propagating the status) on Failure or Running.
    Repeat(usize),
    /// Keeps ticking the child until it returns Failure.
    UntilFail,
    /// Fails the child if it does not finish within `ms` milliseconds (tick count proxy).
    /// Note: the evaluator compares the context's tick count against this
    /// value, not wall-clock time.
    Timeout(u64),
}
// ---------------------------------------------------------------------------
// Node enum
// ---------------------------------------------------------------------------
/// A single node in the behavior tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BehaviorNode {
    /// Leaf that executes a named action. Result is looked up in the context.
    Action(String),
    /// Leaf that checks a named boolean condition in the context.
    Condition(String),
    /// Runs children left-to-right; stops on first non-Success.
    Sequence(Vec<BehaviorNode>),
    /// Runs children left-to-right; stops on first Success.
    Selector(Vec<BehaviorNode>),
    /// Applies a [`DecoratorType`] to a single child.
    Decorator(DecoratorType, Box<BehaviorNode>),
    /// Ticks every child (sequentially, in order — no true concurrency);
    /// succeeds when at least `threshold` children succeed in the tick.
    Parallel(usize, Vec<BehaviorNode>),
}
// ---------------------------------------------------------------------------
// Context (blackboard)
// ---------------------------------------------------------------------------
/// Shared context passed through the tree during evaluation.
///
/// Acts as the tree's "blackboard": Action and Condition leaves read their
/// outcomes from here rather than performing real work.
#[derive(Debug, Clone, Default)]
pub struct BehaviorContext {
    /// General-purpose string key-value store.
    pub blackboard: HashMap<String, String>,
    /// Monotonically increasing tick counter; incremented once per
    /// [`BehaviorTree::tick`] and consulted by the `Timeout` decorator.
    pub tick_count: u64,
    /// Boolean conditions used by `Condition` nodes.
    pub conditions: HashMap<String, bool>,
    /// Pre-set results for `Action` nodes.
    pub action_results: HashMap<String, BehaviorStatus>,
}
// ---------------------------------------------------------------------------
// Tree
// ---------------------------------------------------------------------------
/// A behavior tree with a root node and shared context.
#[derive(Debug, Clone)]
pub struct BehaviorTree {
    root: BehaviorNode,
    context: BehaviorContext,
}

impl BehaviorTree {
    /// Create a new tree with the given root node and a fresh context.
    pub fn new(root: BehaviorNode) -> Self {
        Self {
            root,
            context: BehaviorContext::default(),
        }
    }

    /// Tick the tree once, returning the root status.
    pub fn tick(&mut self) -> BehaviorStatus {
        self.context.tick_count += 1;
        // Disjoint field borrows: the root is read while the context is
        // mutated, which the borrow checker permits on separate fields.
        eval(&self.root, &mut self.context)
    }

    /// Reset the context (tick count, blackboard, etc.).
    pub fn reset(&mut self) {
        self.context = BehaviorContext::default();
    }

    /// Set a named boolean condition.
    pub fn set_condition(&mut self, name: &str, value: bool) {
        self.context.conditions.insert(name.to_owned(), value);
    }

    /// Set the result that a named action should return.
    pub fn set_action_result(&mut self, name: &str, status: BehaviorStatus) {
        self.context.action_results.insert(name.to_owned(), status);
    }

    /// Read-only access to the context.
    pub fn context(&self) -> &BehaviorContext {
        &self.context
    }

    /// Read-only reference to the root node.
    pub fn root(&self) -> &BehaviorNode {
        &self.root
    }
}
/// Maximum iterations for `UntilFail` before returning `Running` as a safety
/// guard against infinite loops when the child always succeeds.
const UNTIL_FAIL_MAX_ITERATIONS: usize = 10_000;
// Free functions that borrow the tree and context independently, avoiding the
// need to clone the tree on every tick. See `eval` and `eval_decorator` below.
fn eval(node: &BehaviorNode, ctx: &mut BehaviorContext) -> BehaviorStatus {
match node {
BehaviorNode::Action(name) => ctx
.action_results
.get(name)
.copied()
.unwrap_or(BehaviorStatus::Failure),
BehaviorNode::Condition(name) => {
if ctx.conditions.get(name).copied().unwrap_or(false) {
BehaviorStatus::Success
} else {
BehaviorStatus::Failure
}
}
BehaviorNode::Sequence(children) => {
for child in children {
match eval(child, ctx) {
BehaviorStatus::Success => continue,
other => return other,
}
}
BehaviorStatus::Success
}
BehaviorNode::Selector(children) => {
for child in children {
match eval(child, ctx) {
BehaviorStatus::Failure => continue,
other => return other,
}
}
BehaviorStatus::Failure
}
BehaviorNode::Decorator(dtype, child) => eval_decorator(dtype, child, ctx),
BehaviorNode::Parallel(threshold, children) => {
let mut success_count = 0usize;
let mut any_running = false;
for child in children {
match eval(child, ctx) {
BehaviorStatus::Success => success_count += 1,
BehaviorStatus::Running => any_running = true,
BehaviorStatus::Failure => {}
}
}
if success_count >= *threshold {
BehaviorStatus::Success
} else if any_running {
BehaviorStatus::Running
} else {
BehaviorStatus::Failure
}
}
}
}
fn eval_decorator(
dtype: &DecoratorType,
child: &BehaviorNode,
ctx: &mut BehaviorContext,
) -> BehaviorStatus {
match dtype {
DecoratorType::Inverter => match eval(child, ctx) {
BehaviorStatus::Success => BehaviorStatus::Failure,
BehaviorStatus::Failure => BehaviorStatus::Success,
BehaviorStatus::Running => BehaviorStatus::Running,
},
DecoratorType::Repeat(n) => {
for _ in 0..*n {
match eval(child, ctx) {
BehaviorStatus::Failure => return BehaviorStatus::Failure,
BehaviorStatus::Running => return BehaviorStatus::Running,
BehaviorStatus::Success => {}
}
}
BehaviorStatus::Success
}
DecoratorType::UntilFail => {
for _ in 0..UNTIL_FAIL_MAX_ITERATIONS {
match eval(child, ctx) {
BehaviorStatus::Failure => return BehaviorStatus::Success,
BehaviorStatus::Running => return BehaviorStatus::Running,
BehaviorStatus::Success => continue,
}
}
// Safety: child never failed within the iteration budget.
BehaviorStatus::Running
}
DecoratorType::Timeout(max_ticks) => {
if ctx.tick_count > *max_ticks {
return BehaviorStatus::Failure;
}
eval(child, ctx)
}
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_action_success() {
    // A pre-registered action result is surfaced verbatim by tick().
    let action = BehaviorNode::Action(String::from("move"));
    let mut tree = BehaviorTree::new(action);
    tree.set_action_result("move", BehaviorStatus::Success);
    assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_action_default_failure() {
    // Actions with no registered result default to Failure.
    let mut tree = BehaviorTree::new(BehaviorNode::Action(String::from("unknown")));
    assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_condition_true() {
    // A condition set to true makes its node succeed.
    let condition = BehaviorNode::Condition(String::from("has_target"));
    let mut tree = BehaviorTree::new(condition);
    tree.set_condition("has_target", true);
    assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_condition_false() {
    // An unset condition reads as false, so the node fails.
    let mut tree = BehaviorTree::new(BehaviorNode::Condition(String::from("has_target")));
    assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_sequence_all_success() {
    // A sequence succeeds only when every child succeeds.
    let children = vec![
        BehaviorNode::Action(String::from("a")),
        BehaviorNode::Action(String::from("b")),
    ];
    let mut tree = BehaviorTree::new(BehaviorNode::Sequence(children));
    tree.set_action_result("a", BehaviorStatus::Success);
    tree.set_action_result("b", BehaviorStatus::Success);
    assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_sequence_early_failure() {
    // The first failing child short-circuits the whole sequence.
    let children = vec![
        BehaviorNode::Action(String::from("a")),
        BehaviorNode::Action(String::from("b")),
    ];
    let mut tree = BehaviorTree::new(BehaviorNode::Sequence(children));
    tree.set_action_result("a", BehaviorStatus::Failure);
    tree.set_action_result("b", BehaviorStatus::Success);
    assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_selector_first_success() {
    // The selector stops at the first succeeding child.
    let options = vec![
        BehaviorNode::Action(String::from("a")),
        BehaviorNode::Action(String::from("b")),
    ];
    let mut tree = BehaviorTree::new(BehaviorNode::Selector(options));
    tree.set_action_result("a", BehaviorStatus::Success);
    assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_selector_fallback() {
let sel = BehaviorNode::Selector(vec![
BehaviorNode::Action("a".into()),
BehaviorNode::Action("b".into()),
]);
let mut tree = BehaviorTree::new(sel);
tree.set_action_result("a", BehaviorStatus::Failure);
tree.set_action_result("b", BehaviorStatus::Success);
assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_selector_all_fail() {
let sel = BehaviorNode::Selector(vec![
BehaviorNode::Action("a".into()),
BehaviorNode::Action("b".into()),
]);
let mut tree = BehaviorTree::new(sel);
tree.set_action_result("a", BehaviorStatus::Failure);
tree.set_action_result("b", BehaviorStatus::Failure);
assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_inverter_decorator() {
let node = BehaviorNode::Decorator(
DecoratorType::Inverter,
Box::new(BehaviorNode::Action("a".into())),
);
let mut tree = BehaviorTree::new(node);
tree.set_action_result("a", BehaviorStatus::Success);
assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_repeat_decorator() {
let node = BehaviorNode::Decorator(
DecoratorType::Repeat(3),
Box::new(BehaviorNode::Action("a".into())),
);
let mut tree = BehaviorTree::new(node);
tree.set_action_result("a", BehaviorStatus::Success);
assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_repeat_decorator_failure() {
let node = BehaviorNode::Decorator(
DecoratorType::Repeat(3),
Box::new(BehaviorNode::Action("a".into())),
);
let mut tree = BehaviorTree::new(node);
tree.set_action_result("a", BehaviorStatus::Failure);
assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_parallel_threshold() {
let par = BehaviorNode::Parallel(
2,
vec![
BehaviorNode::Action("a".into()),
BehaviorNode::Action("b".into()),
BehaviorNode::Action("c".into()),
],
);
let mut tree = BehaviorTree::new(par);
tree.set_action_result("a", BehaviorStatus::Success);
tree.set_action_result("b", BehaviorStatus::Success);
tree.set_action_result("c", BehaviorStatus::Failure);
assert_eq!(tree.tick(), BehaviorStatus::Success);
}
#[test]
fn test_parallel_running() {
let par = BehaviorNode::Parallel(
2,
vec![
BehaviorNode::Action("a".into()),
BehaviorNode::Action("b".into()),
],
);
let mut tree = BehaviorTree::new(par);
tree.set_action_result("a", BehaviorStatus::Success);
tree.set_action_result("b", BehaviorStatus::Running);
assert_eq!(tree.tick(), BehaviorStatus::Running);
}
#[test]
fn test_timeout_decorator() {
let node = BehaviorNode::Decorator(
DecoratorType::Timeout(2),
Box::new(BehaviorNode::Action("a".into())),
);
let mut tree = BehaviorTree::new(node);
tree.set_action_result("a", BehaviorStatus::Running);
// tick 1 => within timeout
assert_eq!(tree.tick(), BehaviorStatus::Running);
// tick 2 => within timeout
assert_eq!(tree.tick(), BehaviorStatus::Running);
// tick 3 => exceeds timeout
assert_eq!(tree.tick(), BehaviorStatus::Failure);
}
#[test]
fn test_reset() {
let mut tree = BehaviorTree::new(BehaviorNode::Action("a".into()));
tree.set_action_result("a", BehaviorStatus::Success);
tree.set_condition("flag", true);
tree.tick();
assert_eq!(tree.context().tick_count, 1);
tree.reset();
assert_eq!(tree.context().tick_count, 0);
assert!(tree.context().conditions.is_empty());
}
#[test]
fn test_blackboard() {
let mut tree = BehaviorTree::new(BehaviorNode::Action("a".into()));
tree.set_action_result("a", BehaviorStatus::Success);
tree.context
.blackboard
.insert("target".into(), "object_1".into());
assert_eq!(tree.context().blackboard.get("target").unwrap(), "object_1");
}
}

View File

@@ -0,0 +1,356 @@
//! Central cognitive loop: perceive -> think -> act -> learn.
//!
//! The [`CognitiveCore`] drives the robot's high-level autonomy by filtering
//! percepts, selecting actions through utility-based reasoning, and
//! incorporating feedback to improve future decisions.
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
// ---------------------------------------------------------------------------
// Enums
// ---------------------------------------------------------------------------
/// Operating mode of the cognitive system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CognitiveMode {
    /// Fast stimulus-response behaviour.
    Reactive,
    /// Goal-directed planning and reasoning.
    Deliberative,
    /// Override mode for safety-critical situations.
    Emergency,
}
/// Current phase of the cognitive loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CognitiveState {
    /// No processing in progress; between cycles.
    Idle,
    /// Ingesting and filtering percepts.
    Perceiving,
    /// Deliberating over buffered percepts.
    Thinking,
    /// Executing a selected action.
    Acting,
    /// Incorporating feedback from the last action.
    Learning,
}
/// The kind of action the robot can execute.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ActionType {
    /// Translate to the given XYZ coordinates (units not specified here).
    Move([f64; 3]),
    /// Rotate by the given angle (presumably radians — confirm with callers).
    Rotate(f64),
    /// Grasp the named object.
    Grasp(String),
    /// Release whatever is currently held.
    Release,
    /// Utter the given phrase.
    Speak(String),
    /// Pause for the given duration (presumably milliseconds — confirm).
    Wait(u64),
}
// ---------------------------------------------------------------------------
// Data structs
// ---------------------------------------------------------------------------
/// A command to execute an action with priority and confidence metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ActionCommand {
    /// The concrete action to perform.
    pub action: ActionType,
    /// Scheduling priority; `CognitiveCore::think` assigns 64/128/255 for
    /// Reactive/Deliberative/Emergency modes respectively.
    pub priority: u8,
    /// Confidence inherited from the percept that motivated the action.
    pub confidence: f64,
}
/// A single percept received from a sensor or subsystem.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Percept {
    /// Identifier of the producing sensor/subsystem (e.g. "lidar").
    pub source: String,
    /// Raw payload values; the first three are treated as XYZ by `think`.
    pub data: Vec<f64>,
    /// Reliability score; compared against the attention threshold.
    pub confidence: f64,
    /// Capture time (epoch units unspecified here — confirm with producers).
    pub timestamp: i64,
}
/// A decision produced by the think phase.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Decision {
    /// The action selected for execution.
    pub action: ActionCommand,
    /// Human-readable explanation of why the action was chosen.
    pub reasoning: String,
    /// Estimated utility (currently the winning percept's confidence).
    pub utility: f64,
}
/// Feedback from the environment after executing an action.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Outcome {
    /// Whether the action achieved its goal.
    pub success: bool,
    /// Scalar reinforcement signal (may be negative).
    pub reward: f64,
    /// Free-form description of what happened.
    pub description: String,
}
/// Configuration for the cognitive core.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CognitiveConfig {
    /// Operating mode; drives action priority selection in `think`.
    pub mode: CognitiveMode,
    /// Percepts with confidence below this value are dropped on arrival.
    pub attention_threshold: f64,
    /// Scale factor applied to rewards when accumulating in `learn`.
    pub learning_rate: f64,
    /// Capacity of the percept buffer; oldest entries are evicted when full.
    pub max_percepts: usize,
}
impl Default for CognitiveConfig {
    fn default() -> Self {
        Self {
            mode: CognitiveMode::Reactive,
            attention_threshold: 0.5,
            learning_rate: 0.01,
            max_percepts: 100,
        }
    }
}
// ---------------------------------------------------------------------------
// Core
// ---------------------------------------------------------------------------
/// Central cognitive controller implementing perceive-think-act-learn.
#[derive(Debug, Clone)]
pub struct CognitiveCore {
    // Current phase of the loop (Idle between cycles).
    state: CognitiveState,
    // Tunable parameters; the attention threshold is mutated by `learn`.
    config: CognitiveConfig,
    // FIFO of accepted percepts, bounded by `config.max_percepts`.
    percept_buffer: VecDeque<Percept>,
    // Every decision ever produced by `think`, in order.
    decision_history: Vec<Decision>,
    // Running total of rewards, each scaled by the learning rate.
    cumulative_reward: f64,
}
impl CognitiveCore {
    /// Build a cognitive core that starts in the `Idle` state with empty
    /// buffers and zero accumulated reward.
    pub fn new(config: CognitiveConfig) -> Self {
        Self {
            config,
            state: CognitiveState::Idle,
            percept_buffer: VecDeque::new(),
            decision_history: Vec::new(),
            cumulative_reward: 0.0,
        }
    }
    /// Ingest a percept and transition to the Perceiving state.
    ///
    /// Percepts whose confidence falls below the attention threshold are
    /// silently discarded. When the buffer is at capacity, the oldest entry
    /// is dropped first (O(1) on a `VecDeque`) to make room.
    pub fn perceive(&mut self, percept: Percept) -> CognitiveState {
        self.state = CognitiveState::Perceiving;
        if percept.confidence >= self.config.attention_threshold {
            if self.percept_buffer.len() >= self.config.max_percepts {
                self.percept_buffer.pop_front();
            }
            self.percept_buffer.push_back(percept);
        }
        self.state
    }
    /// Deliberate over buffered percepts and produce a decision.
    ///
    /// Returns `None` when nothing is buffered. The heuristic selects the
    /// single most confident percept; a percept carrying at least three
    /// values becomes a `Move` to those coordinates, anything shorter
    /// becomes a short `Wait`.
    pub fn think(&mut self) -> Option<Decision> {
        self.state = CognitiveState::Thinking;
        // `max_by` yields `None` on an empty buffer, so `?` covers that case.
        let chosen = self.percept_buffer.iter().max_by(|lhs, rhs| {
            lhs.confidence
                .partial_cmp(&rhs.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;
        let action = match chosen.data.as_slice() {
            [x, y, z, ..] => ActionType::Move([*x, *y, *z]),
            _ => ActionType::Wait(100),
        };
        let priority = match self.config.mode {
            CognitiveMode::Reactive => 64,
            CognitiveMode::Deliberative => 128,
            CognitiveMode::Emergency => 255,
        };
        let decision = Decision {
            reasoning: format!(
                "Best percept from '{}' (conf={:.2})",
                chosen.source, chosen.confidence
            ),
            utility: chosen.confidence,
            action: ActionCommand {
                action,
                priority,
                confidence: chosen.confidence,
            },
        };
        self.decision_history.push(decision.clone());
        Some(decision)
    }
    /// Mark the core as Acting and unwrap the decision into its command.
    pub fn act(&mut self, decision: Decision) -> ActionCommand {
        self.state = CognitiveState::Acting;
        decision.action
    }
    /// Incorporate feedback from the environment to improve future behaviour.
    ///
    /// Accumulates `reward * learning_rate`, then nudges the attention
    /// threshold: down by 0.01 on success (floor 0.1, letting more percepts
    /// through) and up by 0.01 on failure (ceiling 0.9). Finally the percept
    /// buffer is drained and the core returns to `Idle`.
    pub fn learn(&mut self, outcome: Outcome) {
        self.state = CognitiveState::Learning;
        self.cumulative_reward += outcome.reward * self.config.learning_rate;
        self.config.attention_threshold = if outcome.success {
            (self.config.attention_threshold - 0.01).max(0.1)
        } else {
            (self.config.attention_threshold + 0.01).min(0.9)
        };
        self.percept_buffer.clear();
        self.state = CognitiveState::Idle;
    }
    /// Current phase of the cognitive loop.
    pub fn state(&self) -> CognitiveState {
        self.state
    }
    /// Operating mode taken from the configuration.
    pub fn mode(&self) -> CognitiveMode {
        self.config.mode
    }
    /// How many percepts are currently buffered.
    pub fn percept_count(&self) -> usize {
        self.percept_buffer.len()
    }
    /// How many decisions have been produced so far.
    pub fn decision_count(&self) -> usize {
        self.decision_history.len()
    }
    /// Total reward accumulated, already scaled by the learning rate.
    pub fn cumulative_reward(&self) -> f64 {
        self.cumulative_reward
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // Core with the default (Reactive) configuration.
    fn default_core() -> CognitiveCore {
        CognitiveCore::new(CognitiveConfig::default())
    }
    // Percept factory with a fixed timestamp.
    fn make_percept(source: &str, data: Vec<f64>, confidence: f64) -> Percept {
        Percept {
            source: source.into(),
            data,
            confidence,
            timestamp: 1000,
        }
    }
    // A fresh core is Idle, Reactive, and holds no percepts.
    #[test]
    fn test_initial_state() {
        let core = default_core();
        assert_eq!(core.state(), CognitiveState::Idle);
        assert_eq!(core.mode(), CognitiveMode::Reactive);
        assert_eq!(core.percept_count(), 0);
    }
    // Percepts at/above the attention threshold (default 0.5) are buffered.
    #[test]
    fn test_perceive_above_threshold() {
        let mut core = default_core();
        let state = core.perceive(make_percept("lidar", vec![1.0, 2.0, 3.0], 0.8));
        assert_eq!(state, CognitiveState::Perceiving);
        assert_eq!(core.percept_count(), 1);
    }
    // Low-confidence percepts are silently dropped.
    #[test]
    fn test_perceive_below_threshold() {
        let mut core = default_core();
        core.perceive(make_percept("lidar", vec![1.0], 0.1));
        assert_eq!(core.percept_count(), 0);
    }
    // think() derives a decision from the buffered percept; Reactive mode
    // stamps priority 64.
    #[test]
    fn test_think_produces_decision() {
        let mut core = default_core();
        core.perceive(make_percept("cam", vec![1.0, 2.0, 3.0], 0.9));
        let decision = core.think();
        assert!(decision.is_some());
        let d = decision.unwrap();
        assert_eq!(d.action.priority, 64); // Reactive mode
        assert_eq!(core.decision_count(), 1);
    }
    // With nothing buffered, think() yields no decision.
    #[test]
    fn test_think_empty_buffer() {
        let mut core = default_core();
        assert!(core.think().is_none());
    }
    // act() unwraps the decision's command and moves the core to Acting.
    #[test]
    fn test_act_returns_command() {
        let mut core = default_core();
        core.perceive(make_percept("cam", vec![1.0, 2.0, 3.0], 0.9));
        let decision = core.think().unwrap();
        let cmd = core.act(decision);
        assert_eq!(cmd.action, ActionType::Move([1.0, 2.0, 3.0]));
        assert_eq!(core.state(), CognitiveState::Acting);
    }
    // Success lowers the attention threshold (more percepts get through)
    // and the loop returns to Idle.
    #[test]
    fn test_learn_adjusts_threshold() {
        let mut core = default_core();
        let initial = core.config.attention_threshold;
        core.learn(Outcome {
            success: true,
            reward: 1.0,
            description: "ok".into(),
        });
        assert!(core.config.attention_threshold < initial);
        assert_eq!(core.state(), CognitiveState::Idle);
    }
    // Failure raises the attention threshold (filter more aggressively).
    #[test]
    fn test_learn_failure_raises_threshold() {
        let mut core = default_core();
        let initial = core.config.attention_threshold;
        core.learn(Outcome {
            success: false,
            reward: -1.0,
            description: "fail".into(),
        });
        assert!(core.config.attention_threshold > initial);
    }
    // Emergency mode stamps decisions with maximum priority (255).
    #[test]
    fn test_emergency_priority() {
        let mut core = CognitiveCore::new(CognitiveConfig {
            mode: CognitiveMode::Emergency,
            ..CognitiveConfig::default()
        });
        core.perceive(make_percept("ir", vec![5.0, 6.0, 7.0], 0.99));
        let d = core.think().unwrap();
        assert_eq!(d.action.priority, 255);
    }
    // The buffer is capped at max_percepts; the oldest entry is evicted.
    #[test]
    fn test_percept_buffer_overflow() {
        let mut core = CognitiveCore::new(CognitiveConfig {
            max_percepts: 2,
            ..CognitiveConfig::default()
        });
        core.perceive(make_percept("a", vec![1.0], 0.8));
        core.perceive(make_percept("b", vec![2.0], 0.8));
        core.perceive(make_percept("c", vec![3.0], 0.8));
        assert_eq!(core.percept_count(), 2);
    }
}

View File

@@ -0,0 +1,198 @@
//! Multi-criteria utility-based action selection.
//!
//! The [`DecisionEngine`] scores candidate actions using a weighted
//! combination of reward, risk, energy cost, and novelty to select the
//! best option for the current context.
use serde::{Deserialize, Serialize};
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// A candidate action with associated attributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ActionOption {
    /// Human-readable identifier for the option.
    pub name: String,
    /// Expected benefit of taking the action (higher is better).
    pub reward: f64,
    /// Expected downside; penalised by `risk_aversion`.
    pub risk: f64,
    /// Energy expenditure; penalised by `energy_weight`.
    pub energy_cost: f64,
    /// Exploration value; rewarded by `curiosity_weight`.
    pub novelty: f64,
}
/// Weights and parameters controlling the decision engine.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DecisionConfig {
    /// How strongly the engine penalises risky actions (>= 0).
    pub risk_aversion: f64,
    /// How strongly the engine penalises energy expenditure (>= 0).
    pub energy_weight: f64,
    /// How strongly the engine rewards novel/exploratory actions (>= 0).
    pub curiosity_weight: f64,
}
impl Default for DecisionConfig {
    fn default() -> Self {
        Self {
            risk_aversion: 1.0,
            energy_weight: 0.5,
            curiosity_weight: 0.2,
        }
    }
}
// ---------------------------------------------------------------------------
// Engine
// ---------------------------------------------------------------------------
/// Evaluates candidate actions and selects the one with the highest utility.
#[derive(Debug, Clone)]
pub struct DecisionEngine {
    // Weights used by `utility`; not mutated after construction.
    config: DecisionConfig,
}
impl DecisionEngine {
    /// Construct an engine from the supplied weights.
    pub fn new(config: DecisionConfig) -> Self {
        Self { config }
    }
    /// Score one candidate: reward, minus weighted risk and energy cost,
    /// plus a weighted novelty bonus.
    pub fn utility(&self, option: &ActionOption) -> f64 {
        let w = &self.config;
        option.reward - w.risk_aversion * option.risk - w.energy_weight * option.energy_cost
            + w.curiosity_weight * option.novelty
    }
    /// Score every option and return `(index, utility)` of the best one.
    /// Exact ties keep the earliest option. Returns `None` for an empty
    /// slice.
    pub fn evaluate(&self, options: &[ActionOption]) -> Option<(usize, f64)> {
        if options.is_empty() {
            return None;
        }
        let mut best = (0_usize, f64::NEG_INFINITY);
        for (idx, option) in options.iter().enumerate() {
            let score = self.utility(option);
            // Strict comparison keeps the first option on exact ties.
            if score > best.1 {
                best = (idx, score);
            }
        }
        Some(best)
    }
    /// Read-only access to the configuration.
    pub fn config(&self) -> &DecisionConfig {
        &self.config
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // Shorthand constructor for an ActionOption.
    fn make_option(name: &str, reward: f64, risk: f64, energy: f64, novelty: f64) -> ActionOption {
        ActionOption {
            name: name.into(),
            reward,
            risk,
            energy_cost: energy,
            novelty,
        }
    }
    // A lone option is always selected; with zero penalties and bonuses
    // its utility equals its reward.
    #[test]
    fn test_single_option() {
        let engine = DecisionEngine::new(DecisionConfig::default());
        let options = vec![make_option("go", 1.0, 0.0, 0.0, 0.0)];
        let (idx, util) = engine.evaluate(&options).unwrap();
        assert_eq!(idx, 0);
        assert!((util - 1.0).abs() < 1e-9);
    }
    // No candidates means no selection.
    #[test]
    fn test_empty_options() {
        let engine = DecisionEngine::new(DecisionConfig::default());
        assert!(engine.evaluate(&[]).is_none());
    }
    // High risk aversion makes the safer, lower-reward option win.
    #[test]
    fn test_risk_penalty() {
        let engine = DecisionEngine::new(DecisionConfig {
            risk_aversion: 2.0,
            energy_weight: 0.0,
            curiosity_weight: 0.0,
        });
        let options = vec![
            make_option("safe", 1.0, 0.0, 0.0, 0.0),
            make_option("risky", 2.0, 1.0, 0.0, 0.0),
        ];
        let (idx, _) = engine.evaluate(&options).unwrap();
        assert_eq!(idx, 0); // safe: 1.0, risky: 2.0 - 2.0*1.0 = 0.0
    }
    // A strong curiosity weight favours the more novel option.
    #[test]
    fn test_curiosity_bonus() {
        let engine = DecisionEngine::new(DecisionConfig {
            risk_aversion: 0.0,
            energy_weight: 0.0,
            curiosity_weight: 5.0,
        });
        let options = vec![
            make_option("boring", 1.0, 0.0, 0.0, 0.0),
            make_option("novel", 1.0, 0.0, 0.0, 2.0),
        ];
        let (idx, _) = engine.evaluate(&options).unwrap();
        assert_eq!(idx, 1); // novel: 1.0 + 5.0*2.0 = 11.0
    }
    // A heavy energy weight favours the cheaper option.
    #[test]
    fn test_energy_penalty() {
        let engine = DecisionEngine::new(DecisionConfig {
            risk_aversion: 0.0,
            energy_weight: 3.0,
            curiosity_weight: 0.0,
        });
        let options = vec![
            make_option("cheap", 1.0, 0.0, 0.1, 0.0),
            make_option("expensive", 1.0, 0.0, 1.0, 0.0),
        ];
        let (idx, _) = engine.evaluate(&options).unwrap();
        assert_eq!(idx, 0);
    }
    // Pin the full utility formula against hand-computed numbers.
    #[test]
    fn test_utility_formula() {
        let engine = DecisionEngine::new(DecisionConfig {
            risk_aversion: 1.0,
            energy_weight: 0.5,
            curiosity_weight: 0.2,
        });
        let opt = make_option("test", 10.0, 2.0, 4.0, 5.0);
        // utility = 10 - 1*2 - 0.5*4 + 0.2*5 = 10 - 2 - 2 + 1 = 7.0
        let u = engine.utility(&opt);
        assert!((u - 7.0).abs() < 1e-9);
    }
    // Among several options, the highest-utility one is chosen.
    #[test]
    fn test_multiple_options_best() {
        let engine = DecisionEngine::new(DecisionConfig::default());
        let options = vec![
            make_option("a", 0.5, 0.1, 0.1, 0.1),
            make_option("b", 5.0, 0.1, 0.1, 0.1),
            make_option("c", 2.0, 0.1, 0.1, 0.1),
        ];
        let (idx, _) = engine.evaluate(&options).unwrap();
        assert_eq!(idx, 1);
    }
}

View File

@@ -0,0 +1,359 @@
//! Three-tier memory system: working, episodic, and semantic.
//!
//! - **Working memory**: bounded short-term buffer for active items.
//! - **Episodic memory**: stores temporally ordered episodes for experience replay.
//! - **Semantic memory**: long-term concept storage with similarity search.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Working memory
// ---------------------------------------------------------------------------
/// A single item held in working memory.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryItem {
    /// Lookup key; `WorkingMemory::get` matches on this.
    pub key: String,
    /// Arbitrary numeric payload.
    pub data: Vec<f64>,
    /// Eviction priority: the lowest-importance item is evicted first.
    pub importance: f64,
    /// Creation time (epoch units unspecified here — confirm with callers).
    pub timestamp: i64,
    /// Number of times the item has been retrieved via `get`.
    pub access_count: u64,
}
/// Bounded short-term buffer. Evicts the least-important item when full.
#[derive(Debug, Clone)]
pub struct WorkingMemory {
    // Unordered backing store; lookup and eviction are linear scans.
    items: Vec<MemoryItem>,
    // Capacity; `add` evicts before exceeding this.
    max_size: usize,
}
impl WorkingMemory {
    /// Create a working memory holding at most `max_size` items.
    pub fn new(max_size: usize) -> Self {
        Self {
            items: Vec::new(),
            max_size,
        }
    }
    /// Insert an item. When the buffer is already full, the entry with the
    /// lowest importance is evicted first to make room.
    pub fn add(&mut self, item: MemoryItem) {
        if self.items.len() >= self.max_size {
            // Locate the least-important entry; first one wins on ties.
            let victim = self
                .items
                .iter()
                .enumerate()
                .min_by(|(_, lhs), (_, rhs)| {
                    lhs.importance
                        .partial_cmp(&rhs.importance)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(idx, _)| idx);
            if let Some(idx) = victim {
                self.items.remove(idx);
            }
        }
        self.items.push(item);
    }
    /// Retrieve an item by key, bumping its access count on a hit.
    pub fn get(&mut self, key: &str) -> Option<&MemoryItem> {
        match self.items.iter_mut().find(|entry| entry.key == key) {
            Some(entry) => {
                entry.access_count += 1;
                Some(entry)
            }
            None => None,
        }
    }
    /// Drop every item.
    pub fn clear(&mut self) {
        self.items.clear();
    }
    /// Number of items currently held.
    pub fn len(&self) -> usize {
        self.items.len()
    }
    /// True when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }
}
// ---------------------------------------------------------------------------
// Episodic memory
// ---------------------------------------------------------------------------
/// A single episode consisting of percepts, actions, and a scalar reward.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Episode {
    /// Percept vectors observed during the episode, in order.
    pub percepts: Vec<Vec<f64>>,
    /// Names of the actions taken during the episode.
    pub actions: Vec<String>,
    /// Scalar reward received for the episode.
    pub reward: f64,
    /// When the episode occurred (epoch units unspecified here).
    pub timestamp: i64,
}
/// Stores temporally ordered episodes and supports similarity recall.
#[derive(Debug, Clone, Default)]
pub struct EpisodicMemory {
    // Episodes in insertion (temporal) order.
    episodes: Vec<Episode>,
}
impl EpisodicMemory {
    /// Create an empty episode store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Append an episode to the store.
    pub fn store(&mut self, episode: Episode) {
        self.episodes.push(episode);
    }
    /// Return up to `k` episodes ranked by dot-product similarity between
    /// `query` and each episode's flattened percepts (most similar first).
    pub fn recall_similar(&self, query: &[f64], k: usize) -> Vec<&Episode> {
        let mut ranked: Vec<(f64, &Episode)> = self
            .episodes
            .iter()
            .map(|episode| {
                let flattened: Vec<f64> = episode.percepts.iter().flatten().copied().collect();
                (dot_product(query, &flattened), episode)
            })
            .collect();
        // Stable descending sort keeps insertion order among equal scores.
        ranked.sort_by(|lhs, rhs| rhs.0.partial_cmp(&lhs.0).unwrap_or(std::cmp::Ordering::Equal));
        ranked.iter().take(k).map(|&(_, episode)| episode).collect()
    }
    /// Number of stored episodes.
    pub fn len(&self) -> usize {
        self.episodes.len()
    }
    /// True when no episodes are stored.
    pub fn is_empty(&self) -> bool {
        self.episodes.is_empty()
    }
}
// ---------------------------------------------------------------------------
// Semantic memory
// ---------------------------------------------------------------------------
/// Long-term concept storage mapping names to embedding vectors.
#[derive(Debug, Clone, Default)]
pub struct SemanticMemory {
    // Concept name -> embedding vector.
    concepts: HashMap<String, Vec<f64>>,
}
impl SemanticMemory {
    /// Create an empty concept store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Insert or overwrite the embedding stored under `name`.
    pub fn store(&mut self, name: &str, embedding: Vec<f64>) {
        self.concepts.insert(name.to_owned(), embedding);
    }
    /// Fetch the embedding for a concept, if present.
    pub fn retrieve(&self, name: &str) -> Option<&Vec<f64>> {
        self.concepts.get(name)
    }
    /// Rank all concepts by dot-product similarity against `query` and
    /// return the top `k` as `(name, score)` pairs, most similar first.
    pub fn find_similar(&self, query: &[f64], k: usize) -> Vec<(&str, f64)> {
        let mut ranked: Vec<(&str, f64)> = self
            .concepts
            .iter()
            .map(|(name, embedding)| (name.as_str(), dot_product(query, embedding)))
            .collect();
        ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(k);
        ranked
    }
    /// Number of stored concepts.
    pub fn len(&self) -> usize {
        self.concepts.len()
    }
    /// True when no concepts are stored.
    pub fn is_empty(&self) -> bool {
        self.concepts.is_empty()
    }
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Dot product of two slices; pairing stops at the shorter input, so
/// mismatched lengths degrade to silent truncation.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // -- Working memory ---------------------------------------------------
    // get() returns the stored item and bumps its access counter.
    #[test]
    fn test_working_memory_add_get() {
        let mut wm = WorkingMemory::new(5);
        wm.add(MemoryItem {
            key: "obj1".into(),
            data: vec![1.0, 2.0],
            importance: 0.8,
            timestamp: 100,
            access_count: 0,
        });
        let item = wm.get("obj1").unwrap();
        assert_eq!(item.access_count, 1);
    }
    // When full, add() evicts the item with the lowest importance.
    #[test]
    fn test_working_memory_eviction() {
        let mut wm = WorkingMemory::new(2);
        wm.add(MemoryItem {
            key: "a".into(),
            data: vec![],
            importance: 0.1,
            timestamp: 1,
            access_count: 0,
        });
        wm.add(MemoryItem {
            key: "b".into(),
            data: vec![],
            importance: 0.9,
            timestamp: 2,
            access_count: 0,
        });
        wm.add(MemoryItem {
            key: "c".into(),
            data: vec![],
            importance: 0.5,
            timestamp: 3,
            access_count: 0,
        });
        assert_eq!(wm.len(), 2);
        // "a" (importance 0.1) should have been evicted.
        assert!(wm.get("a").is_none());
        assert!(wm.get("b").is_some());
    }
    // clear() empties the buffer.
    #[test]
    fn test_working_memory_clear() {
        let mut wm = WorkingMemory::new(5);
        wm.add(MemoryItem {
            key: "x".into(),
            data: vec![],
            importance: 1.0,
            timestamp: 0,
            access_count: 0,
        });
        assert!(!wm.is_empty());
        wm.clear();
        assert!(wm.is_empty());
    }
    // Lookups of unknown keys return None.
    #[test]
    fn test_working_memory_get_missing() {
        let mut wm = WorkingMemory::new(5);
        assert!(wm.get("nonexistent").is_none());
    }
    // -- Episodic memory --------------------------------------------------
    // recall_similar ranks episodes by dot-product similarity to the query.
    #[test]
    fn test_episodic_store_recall() {
        let mut em = EpisodicMemory::new();
        em.store(Episode {
            percepts: vec![vec![1.0, 0.0, 0.0]],
            actions: vec!["move".into()],
            reward: 1.0,
            timestamp: 100,
        });
        em.store(Episode {
            percepts: vec![vec![0.0, 1.0, 0.0]],
            actions: vec!["turn".into()],
            reward: 0.5,
            timestamp: 200,
        });
        let results = em.recall_similar(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].actions[0], "move");
    }
    // Recall on an empty store returns nothing, regardless of k.
    #[test]
    fn test_episodic_empty_recall() {
        let em = EpisodicMemory::new();
        let results = em.recall_similar(&[1.0], 5);
        assert!(results.is_empty());
    }
    // len()/is_empty() track stored episodes.
    #[test]
    fn test_episodic_len() {
        let mut em = EpisodicMemory::new();
        assert!(em.is_empty());
        em.store(Episode {
            percepts: vec![],
            actions: vec![],
            reward: 0.0,
            timestamp: 0,
        });
        assert_eq!(em.len(), 1);
    }
    // -- Semantic memory --------------------------------------------------
    // Stored embeddings round-trip through retrieve().
    #[test]
    fn test_semantic_store_retrieve() {
        let mut sm = SemanticMemory::new();
        sm.store("cup", vec![1.0, 0.0, 0.0]);
        let emb = sm.retrieve("cup").unwrap();
        assert_eq!(emb, &vec![1.0, 0.0, 0.0]);
    }
    // find_similar ranks concepts by dot product; the exact match is first.
    #[test]
    fn test_semantic_find_similar() {
        let mut sm = SemanticMemory::new();
        sm.store("cup", vec![1.0, 0.0, 0.0]);
        sm.store("plate", vec![0.9, 0.1, 0.0]);
        sm.store("ball", vec![0.0, 0.0, 1.0]);
        let results = sm.find_similar(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "cup");
    }
    // Unknown concepts yield None.
    #[test]
    fn test_semantic_retrieve_missing() {
        let sm = SemanticMemory::new();
        assert!(sm.retrieve("nothing").is_none());
    }
    // len()/is_empty() track stored concepts.
    #[test]
    fn test_semantic_len() {
        let mut sm = SemanticMemory::new();
        assert!(sm.is_empty());
        sm.store("a", vec![]);
        assert_eq!(sm.len(), 1);
    }
    // -- Helpers ----------------------------------------------------------
    // Pin the dot product against a hand-computed value.
    #[test]
    fn test_dot_product() {
        assert!((dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-9);
    }
}

View File

@@ -0,0 +1,34 @@
//! Cognitive architecture for autonomous robot behavior.
//!
//! This module provides a layered cognitive system comprising:
//! - **Behavior trees** for composable task execution
//! - **Cognitive core** implementing a perceive-think-act-learn loop
//! - **Decision engine** for multi-criteria utility-based action selection
//! - **Memory system** with working, episodic, and semantic tiers
//! - **Skill learning** for acquiring and refining motor skills
//! - **Swarm intelligence** for multi-robot coordination
//! - **World model** for internal environment representation
pub mod behavior_tree;
pub mod cognitive_core;
pub mod decision_engine;
pub mod memory_system;
pub mod skill_learning;
pub mod swarm_intelligence;
pub mod world_model;
pub use behavior_tree::{
BehaviorContext, BehaviorNode, BehaviorStatus, BehaviorTree, DecoratorType,
};
pub use cognitive_core::{
ActionCommand, ActionType, CognitiveConfig, CognitiveCore, CognitiveMode, CognitiveState,
Decision, Outcome, Percept,
};
pub use decision_engine::{ActionOption, DecisionConfig, DecisionEngine};
pub use memory_system::{Episode, EpisodicMemory, MemoryItem, SemanticMemory, WorkingMemory};
pub use skill_learning::{Demonstration, Skill, SkillLibrary};
pub use swarm_intelligence::{
ConsensusResult, Formation, FormationType, RobotCapabilities, SwarmConfig, SwarmCoordinator,
SwarmTask, TaskAssignment,
};
pub use world_model::{PredictedState, TrackedObject, WorldModel};

View File

@@ -0,0 +1,225 @@
//! Skill acquisition via learning from demonstration.
//!
//! Robots can observe demonstrations (trajectories with timestamps),
//! generalise a skill by averaging, and progressively improve confidence
//! through execution feedback.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// A single demonstration of a skill (e.g., a recorded trajectory).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Demonstration {
    /// Sequence of XYZ waypoints.
    pub trajectory: Vec<[f64; 3]>,
    /// Per-waypoint capture times (presumably parallel to `trajectory` —
    /// confirm with recorders; not validated here).
    pub timestamps: Vec<i64>,
    /// Free-form annotation supplied by the demonstrator.
    pub metadata: String,
}
/// A learned skill derived from one or more demonstrations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Skill {
    /// Name under which the skill is stored in the library.
    pub name: String,
    /// Averaged XYZ trajectory; empty when learned from zero demos.
    pub trajectory: Vec<[f64; 3]>,
    /// Confidence in [0, 1]; grows with demonstration count and feedback.
    pub confidence: f64,
    /// Number of times `execute_skill` has run this skill.
    pub execution_count: u64,
}
// ---------------------------------------------------------------------------
// Library
// ---------------------------------------------------------------------------
/// A library of learned skills keyed by name.
#[derive(Debug, Clone, Default)]
pub struct SkillLibrary {
    // Skill name -> learned skill.
    skills: HashMap<String, Skill>,
}
impl SkillLibrary {
    /// Create an empty skill library.
    pub fn new() -> Self {
        Self::default()
    }
    /// Learn a skill by averaging demonstrations point-by-point.
    ///
    /// The generalised trajectory is truncated to the length of the
    /// shortest demonstration. An empty `demos` slice does not panic: it
    /// yields an empty, zero-confidence skill. With `n` demonstrations the
    /// confidence is `1 - 1/(n + 1)`, approaching 1.0 as evidence grows.
    /// The learned skill is stored in the library and also returned.
    pub fn learn_from_demonstration(&mut self, name: &str, demos: &[Demonstration]) -> Skill {
        let (trajectory, confidence) = if demos.is_empty() {
            (Vec::new(), 0.0)
        } else {
            let count = demos.len() as f64;
            let shortest = demos.iter().map(|d| d.trajectory.len()).min().unwrap_or(0);
            let averaged = (0..shortest)
                .map(|i| {
                    // Accumulate each axis across all demos, then divide once.
                    let mut point = [0.0_f64; 3];
                    for demo in demos {
                        for (axis, total) in point.iter_mut().enumerate() {
                            *total += demo.trajectory[i][axis];
                        }
                    }
                    [point[0] / count, point[1] / count, point[2] / count]
                })
                .collect();
            (averaged, 1.0 - (1.0 / (count + 1.0)))
        };
        let skill = Skill {
            name: name.to_string(),
            trajectory,
            confidence,
            execution_count: 0,
        };
        self.skills.insert(name.to_string(), skill.clone());
        skill
    }
    /// Execute a named skill: bump its execution count and hand back a
    /// clone of its trajectory. Returns `None` when the skill is unknown.
    pub fn execute_skill(&mut self, name: &str) -> Option<Vec<[f64; 3]>> {
        let skill = self.skills.get_mut(name)?;
        skill.execution_count += 1;
        Some(skill.trajectory.clone())
    }
    /// Adjust a skill's confidence by `feedback` (positive or negative),
    /// clamping the result to [0.0, 1.0]. Unknown names are ignored.
    pub fn improve_skill(&mut self, name: &str, feedback: f64) {
        if let Some(skill) = self.skills.get_mut(name) {
            let adjusted = skill.confidence + feedback;
            skill.confidence = adjusted.clamp(0.0, 1.0);
        }
    }
    /// Look up a skill by name.
    pub fn get(&self, name: &str) -> Option<&Skill> {
        self.skills.get(name)
    }
    /// Number of skills in the library.
    pub fn len(&self) -> usize {
        self.skills.len()
    }
    /// True when the library holds no skills.
    pub fn is_empty(&self) -> bool {
        self.skills.is_empty()
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    /// Wrap raw points into a `Demonstration` with sequential timestamps
    /// and empty metadata.
    fn demo(pts: Vec<[f64; 3]>) -> Demonstration {
        let n = pts.len();
        Demonstration {
            trajectory: pts,
            timestamps: (0..n as i64).collect(),
            metadata: String::new(),
        }
    }
    // A single demonstration is reproduced verbatim as the skill trajectory.
    #[test]
    fn test_learn_single_demo() {
        let mut lib = SkillLibrary::new();
        let skill = lib.learn_from_demonstration("wave", &[demo(vec![[1.0, 2.0, 3.0]])]);
        assert_eq!(skill.trajectory, vec![[1.0, 2.0, 3.0]]);
        assert!(skill.confidence > 0.0);
        assert_eq!(lib.len(), 1);
    }
    // Multiple demos are averaged element-wise: (0+2)/2 = 1, (2+4)/2 = 3.
    #[test]
    fn test_learn_multiple_demos_averages() {
        let mut lib = SkillLibrary::new();
        let d1 = demo(vec![[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]);
        let d2 = demo(vec![[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]);
        let skill = lib.learn_from_demonstration("reach", &[d1, d2]);
        assert_eq!(skill.trajectory.len(), 2);
        assert!((skill.trajectory[0][0] - 1.0).abs() < 1e-9);
        assert!((skill.trajectory[1][0] - 3.0).abs() < 1e-9);
    }
    // Each successful execution bumps execution_count by one.
    #[test]
    fn test_execute_increments_count() {
        let mut lib = SkillLibrary::new();
        lib.learn_from_demonstration("grab", &[demo(vec![[0.0, 0.0, 0.0]])]);
        let traj = lib.execute_skill("grab");
        assert!(traj.is_some());
        assert_eq!(lib.get("grab").unwrap().execution_count, 1);
        lib.execute_skill("grab");
        assert_eq!(lib.get("grab").unwrap().execution_count, 2);
    }
    #[test]
    fn test_execute_missing_skill() {
        let mut lib = SkillLibrary::new();
        assert!(lib.execute_skill("nonexistent").is_none());
    }
    #[test]
    fn test_improve_skill() {
        let mut lib = SkillLibrary::new();
        lib.learn_from_demonstration("push", &[demo(vec![[1.0, 0.0, 0.0]])]);
        let initial = lib.get("push").unwrap().confidence;
        lib.improve_skill("push", 0.1);
        assert!(lib.get("push").unwrap().confidence > initial);
    }
    // Feedback beyond the [0, 1] range is clamped at both ends.
    #[test]
    fn test_improve_skill_clamp() {
        let mut lib = SkillLibrary::new();
        lib.learn_from_demonstration("pull", &[demo(vec![[0.0, 0.0, 0.0]])]);
        lib.improve_skill("pull", 10.0);
        assert!((lib.get("pull").unwrap().confidence - 1.0).abs() < 1e-9);
        lib.improve_skill("pull", -20.0);
        assert!((lib.get("pull").unwrap().confidence).abs() < 1e-9);
    }
    // Averaging truncates to the shortest demonstration.
    #[test]
    fn test_different_length_demos() {
        let mut lib = SkillLibrary::new();
        let d1 = demo(vec![[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]);
        let d2 = demo(vec![[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]);
        let skill = lib.learn_from_demonstration("mixed", &[d1, d2]);
        // Uses min length = 2
        assert_eq!(skill.trajectory.len(), 2);
    }
    // Confidence is 1 - 1/(n + 1), so it grows with the demo count n.
    #[test]
    fn test_confidence_increases_with_more_demos() {
        let mut lib = SkillLibrary::new();
        let s1 = lib.learn_from_demonstration("s1", &[demo(vec![[0.0, 0.0, 0.0]])]);
        let s2 = lib.learn_from_demonstration(
            "s2",
            &[
                demo(vec![[0.0, 0.0, 0.0]]),
                demo(vec![[1.0, 1.0, 1.0]]),
                demo(vec![[2.0, 2.0, 2.0]]),
            ],
        );
        assert!(s2.confidence > s1.confidence);
    }
}

View File

@@ -0,0 +1,428 @@
//! Multi-robot swarm coordination.
//!
//! Provides formation computation, capability-based task assignment, and
//! simple majority consensus for distributed decision making.
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
// ---------------------------------------------------------------------------
// Formation
// ---------------------------------------------------------------------------
/// Types of spatial formations a swarm can adopt.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FormationType {
    /// Robots spaced evenly along the X axis through the centre.
    Line,
    /// Robots spaced evenly around a circle.
    Circle,
    /// Robots arranged row-major in a near-square grid.
    Grid,
    /// Caller-supplied positions, used verbatim.
    Custom(Vec<[f64; 3]>),
}
/// A formation specification: type, spacing, and center point.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Formation {
    /// Geometric shape to arrange the robots into.
    pub formation_type: FormationType,
    /// Inter-robot spacing (cell pitch for grids, arc spacing for circles).
    pub spacing: f64,
    /// World-frame centre of the formation.
    pub center: [f64; 3],
}
// ---------------------------------------------------------------------------
// Robots & tasks
// ---------------------------------------------------------------------------
/// Capabilities advertised by a single robot in the swarm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RobotCapabilities {
    /// Unique robot identifier; also the key in the coordinator's map.
    pub id: u64,
    /// Maximum travel speed, used to estimate task completion time.
    pub max_speed: f64,
    /// Carrying capacity. NOTE(review): not consulted by any method
    /// visible in this file — confirm intended use.
    pub payload: f64,
    /// Sensor names; task `required_capabilities` are matched against these.
    pub sensors: Vec<String>,
}
/// A task to be assigned to one or more robots.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SwarmTask {
    /// Unique task identifier.
    pub id: u64,
    /// Human-readable description.
    pub description: String,
    /// World-frame target location of the task.
    pub location: [f64; 3],
    /// Sensor names a robot must advertise to be eligible.
    pub required_capabilities: Vec<String>,
    /// Scheduling priority; higher values are assigned first.
    pub priority: u8,
}
/// The result of assigning a task to a robot.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TaskAssignment {
    /// Robot that received the task.
    pub robot_id: u64,
    /// Task that was assigned.
    pub task_id: u64,
    /// Estimated time to complete (distance / max speed; infinite for a
    /// zero-speed robot).
    pub estimated_completion: f64,
}
// ---------------------------------------------------------------------------
// Consensus
// ---------------------------------------------------------------------------
/// The result of a consensus vote among swarm members.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ConsensusResult {
    /// The proposal text that was voted on.
    pub proposal: String,
    /// Number of robots that voted in favour.
    pub votes_for: usize,
    /// Number of robots that voted against.
    pub votes_against: usize,
    /// Whether the `for` fraction met the configured threshold.
    pub accepted: bool,
}
// ---------------------------------------------------------------------------
// Config
// ---------------------------------------------------------------------------
/// Configuration for the swarm coordinator.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SwarmConfig {
    /// Hard cap on the number of registered robots.
    pub max_robots: usize,
    /// Radio range. NOTE(review): stored but not consulted by any method
    /// visible in this file — confirm intended use.
    pub communication_range: f64,
    /// Minimum fraction of `for` votes required to accept a proposal.
    pub consensus_threshold: f64,
}
impl Default for SwarmConfig {
    /// Defaults: up to 10 robots, 50.0 communication range, and a
    /// simple-majority (0.5) consensus threshold.
    fn default() -> Self {
        let max_robots = 10;
        let communication_range = 50.0;
        let consensus_threshold = 0.5;
        Self {
            max_robots,
            communication_range,
            consensus_threshold,
        }
    }
}
// ---------------------------------------------------------------------------
// Coordinator
// ---------------------------------------------------------------------------
/// Coordinates a swarm of robots for formation, task assignment, and consensus.
#[derive(Debug, Clone)]
pub struct SwarmCoordinator {
    /// Tuning knobs: capacity, range, and vote threshold.
    config: SwarmConfig,
    /// Registered robots keyed by their unique id.
    robots: HashMap<u64, RobotCapabilities>,
}
impl SwarmCoordinator {
    /// Create a new coordinator with the given configuration and no
    /// registered robots.
    pub fn new(config: SwarmConfig) -> Self {
        Self {
            config,
            robots: HashMap::new(),
        }
    }
    /// Register a robot's capabilities, keyed by `capabilities.id`.
    ///
    /// Returns `false` without registering once `max_robots` robots are
    /// already present; otherwise inserts (replacing any existing entry
    /// with the same id) and returns `true`.
    pub fn register_robot(&mut self, capabilities: RobotCapabilities) -> bool {
        if self.robots.len() >= self.config.max_robots {
            return false;
        }
        self.robots.insert(capabilities.id, capabilities);
        true
    }
    /// Number of registered robots.
    pub fn robot_count(&self) -> usize {
        self.robots.len()
    }
    /// Assign tasks to robots using a greedy capability-matching strategy.
    ///
    /// Tasks are processed in priority order (highest first; equal
    /// priorities keep input order because the sort is stable). Each task
    /// is assigned to the lowest-id unassigned robot that possesses all
    /// required capabilities.
    ///
    /// Fix: robots are scanned in ascending id order. The previous code
    /// iterated the `HashMap` directly, so when several robots qualified
    /// the winner depended on hash iteration order and could change from
    /// run to run.
    pub fn assign_tasks(&self, tasks: &[SwarmTask]) -> Vec<TaskAssignment> {
        let mut sorted_tasks: Vec<&SwarmTask> = tasks.iter().collect();
        sorted_tasks.sort_by(|a, b| b.priority.cmp(&a.priority));
        // Deterministic scan order over robots.
        let mut robot_ids: Vec<u64> = self.robots.keys().copied().collect();
        robot_ids.sort_unstable();
        let mut assigned_robots: HashSet<u64> = HashSet::new();
        let mut assignments = Vec::new();
        for task in &sorted_tasks {
            for id in &robot_ids {
                if assigned_robots.contains(id) {
                    continue;
                }
                let caps = &self.robots[id];
                let has_caps = task
                    .required_capabilities
                    .iter()
                    .all(|req| caps.sensors.contains(req));
                if !has_caps {
                    continue;
                }
                // Estimate completion as straight-line distance from the
                // origin divided by speed (robot positions are not modelled).
                let [dx, dy, dz] = task.location;
                let dist = (dx * dx + dy * dy + dz * dz).sqrt();
                let est = if caps.max_speed > 0.0 {
                    dist / caps.max_speed
                } else {
                    // A robot that cannot move never finishes.
                    f64::INFINITY
                };
                assignments.push(TaskAssignment {
                    robot_id: *id,
                    task_id: task.id,
                    estimated_completion: est,
                });
                assigned_robots.insert(*id);
                break;
            }
        }
        assignments
    }
    /// Compute target positions for each robot given a formation spec.
    ///
    /// Produces one position per registered robot (`Custom` positions are
    /// returned verbatim regardless of robot count) and an empty vector
    /// when no robots are registered.
    pub fn compute_formation(&self, formation: &Formation) -> Vec<[f64; 3]> {
        let n = self.robots.len();
        if n == 0 {
            return Vec::new();
        }
        match &formation.formation_type {
            // Evenly spaced along the X axis, centred on `center`.
            FormationType::Line => (0..n)
                .map(|i| {
                    let offset = (i as f64 - (n as f64 - 1.0) / 2.0) * formation.spacing;
                    [
                        formation.center[0] + offset,
                        formation.center[1],
                        formation.center[2],
                    ]
                })
                .collect(),
            // Radius chosen so the circumference yields `spacing` between
            // adjacent robots.
            FormationType::Circle => {
                let radius = formation.spacing * n as f64 / (2.0 * std::f64::consts::PI);
                (0..n)
                    .map(|i| {
                        let angle = 2.0 * std::f64::consts::PI * i as f64 / n as f64;
                        [
                            formation.center[0] + radius * angle.cos(),
                            formation.center[1] + radius * angle.sin(),
                            formation.center[2],
                        ]
                    })
                    .collect()
            }
            // Row-major near-square grid growing in +X/+Y from `center`.
            FormationType::Grid => {
                let cols = (n as f64).sqrt().ceil() as usize;
                (0..n)
                    .map(|i| {
                        let row = i / cols;
                        let col = i % cols;
                        [
                            formation.center[0] + col as f64 * formation.spacing,
                            formation.center[1] + row as f64 * formation.spacing,
                            formation.center[2],
                        ]
                    })
                    .collect()
            }
            FormationType::Custom(positions) => positions.clone(),
        }
    }
    /// Run a simple majority consensus vote among all registered robots.
    ///
    /// Each robot "votes" deterministically based on its ID parity (a
    /// placeholder for real voting logic). The proposal is accepted when
    /// the fraction of `for` votes meets the threshold; an empty swarm
    /// always rejects.
    pub fn propose_consensus(&self, proposal: &str) -> ConsensusResult {
        let total = self.robots.len();
        if total == 0 {
            return ConsensusResult {
                proposal: proposal.to_string(),
                votes_for: 0,
                votes_against: 0,
                accepted: false,
            };
        }
        // Deterministic placeholder vote: even IDs vote for, odd against.
        let votes_for = self.robots.keys().filter(|id| *id % 2 == 0).count();
        let votes_against = total - votes_for;
        let ratio = votes_for as f64 / total as f64;
        ConsensusResult {
            proposal: proposal.to_string(),
            votes_for,
            votes_against,
            accepted: ratio >= self.config.consensus_threshold,
        }
    }
    /// Read-only access to the configuration.
    pub fn config(&self) -> &SwarmConfig {
        &self.config
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    /// Build a robot with unit speed, payload 5, and the given sensors.
    fn make_robot(id: u64, sensors: Vec<&str>) -> RobotCapabilities {
        RobotCapabilities {
            id,
            max_speed: 1.0,
            payload: 5.0,
            sensors: sensors.into_iter().map(String::from).collect(),
        }
    }
    /// Build a task at (3, 4, 0) — distance 5 from the origin — with the
    /// given required capabilities and priority.
    fn make_task(id: u64, caps: Vec<&str>, priority: u8) -> SwarmTask {
        SwarmTask {
            id,
            description: format!("task_{}", id),
            location: [3.0, 4.0, 0.0],
            required_capabilities: caps.into_iter().map(String::from).collect(),
            priority,
        }
    }
    #[test]
    fn test_register_robot() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        assert!(coord.register_robot(make_robot(1, vec!["lidar"])));
        assert_eq!(coord.robot_count(), 1);
    }
    // Registration beyond max_robots is refused.
    #[test]
    fn test_register_respects_max() {
        let mut coord = SwarmCoordinator::new(SwarmConfig {
            max_robots: 1,
            ..SwarmConfig::default()
        });
        assert!(coord.register_robot(make_robot(1, vec![])));
        assert!(!coord.register_robot(make_robot(2, vec![])));
    }
    // Only the robot advertising the required sensor gets the task.
    #[test]
    fn test_assign_tasks_capability_match() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        coord.register_robot(make_robot(1, vec!["camera"]));
        coord.register_robot(make_robot(2, vec!["lidar"]));
        let tasks = vec![make_task(10, vec!["lidar"], 5)];
        let assignments = coord.assign_tasks(&tasks);
        assert_eq!(assignments.len(), 1);
        assert_eq!(assignments[0].robot_id, 2);
        assert_eq!(assignments[0].task_id, 10);
    }
    #[test]
    fn test_assign_no_capable_robot() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        coord.register_robot(make_robot(1, vec!["camera"]));
        let tasks = vec![make_task(10, vec!["sonar"], 5)];
        let assignments = coord.assign_tasks(&tasks);
        assert!(assignments.is_empty());
    }
    #[test]
    fn test_line_formation() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        coord.register_robot(make_robot(1, vec![]));
        coord.register_robot(make_robot(2, vec![]));
        coord.register_robot(make_robot(3, vec![]));
        let formation = Formation {
            formation_type: FormationType::Line,
            spacing: 2.0,
            center: [0.0, 0.0, 0.0],
        };
        let positions = coord.compute_formation(&formation);
        assert_eq!(positions.len(), 3);
    }
    #[test]
    fn test_circle_formation() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        for i in 0..4 {
            coord.register_robot(make_robot(i, vec![]));
        }
        let formation = Formation {
            formation_type: FormationType::Circle,
            spacing: 2.0,
            center: [0.0, 0.0, 0.0],
        };
        let positions = coord.compute_formation(&formation);
        assert_eq!(positions.len(), 4);
    }
    // Placeholder voting: even ids vote for, odd against; 2/3 >= 0.5.
    #[test]
    fn test_consensus_accepted() {
        let mut coord = SwarmCoordinator::new(SwarmConfig {
            consensus_threshold: 0.5,
            ..SwarmConfig::default()
        });
        // Even IDs vote for.
        coord.register_robot(make_robot(2, vec![]));
        coord.register_robot(make_robot(4, vec![]));
        coord.register_robot(make_robot(5, vec![]));
        let result = coord.propose_consensus("explore area B");
        assert_eq!(result.votes_for, 2);
        assert_eq!(result.votes_against, 1);
        assert!(result.accepted);
    }
    // 1/3 `for` votes is below the 0.8 threshold.
    #[test]
    fn test_consensus_rejected() {
        let mut coord = SwarmCoordinator::new(SwarmConfig {
            consensus_threshold: 0.8,
            ..SwarmConfig::default()
        });
        coord.register_robot(make_robot(1, vec![])); // odd -> against
        coord.register_robot(make_robot(3, vec![])); // odd -> against
        coord.register_robot(make_robot(2, vec![])); // even -> for
        let result = coord.propose_consensus("attack");
        assert!(!result.accepted);
    }
    #[test]
    fn test_consensus_empty_swarm() {
        let coord = SwarmCoordinator::new(SwarmConfig::default());
        let result = coord.propose_consensus("noop");
        assert!(!result.accepted);
        assert_eq!(result.votes_for, 0);
    }
    #[test]
    fn test_grid_formation() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        for i in 0..4 {
            coord.register_robot(make_robot(i, vec![]));
        }
        let formation = Formation {
            formation_type: FormationType::Grid,
            spacing: 1.0,
            center: [0.0, 0.0, 0.0],
        };
        let positions = coord.compute_formation(&formation);
        assert_eq!(positions.len(), 4);
        // 4 robots => 2x2 grid
        assert!((positions[0][0] - 0.0).abs() < 1e-9);
        assert!((positions[1][0] - 1.0).abs() < 1e-9);
    }
    // Custom formations are passed through verbatim.
    #[test]
    fn test_custom_formation() {
        let mut coord = SwarmCoordinator::new(SwarmConfig::default());
        coord.register_robot(make_robot(1, vec![]));
        let custom_pos = vec![[10.0, 20.0, 0.0]];
        let formation = Formation {
            formation_type: FormationType::Custom(custom_pos.clone()),
            spacing: 0.0,
            center: [0.0, 0.0, 0.0],
        };
        let positions = coord.compute_formation(&formation);
        assert_eq!(positions, custom_pos);
    }
}

View File

@@ -0,0 +1,247 @@
//! Internal world representation with object tracking, occupancy grid,
//! and linear state prediction.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// An object being tracked in the world model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrackedObject {
    /// Unique object identifier.
    pub id: u64,
    /// Last observed position (world frame).
    pub position: [f64; 3],
    /// Last observed velocity; used for constant-velocity prediction.
    pub velocity: [f64; 3],
    /// Timestamp of the last observation (module docs use microseconds).
    pub last_seen: i64,
    /// Detection confidence, nominally in [0, 1] (not enforced here).
    pub confidence: f64,
    /// Semantic class label.
    pub label: String,
}
/// Predicted future state of a tracked object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PredictedState {
    /// Extrapolated position after `time_horizon` seconds.
    pub position: [f64; 3],
    /// Source confidence scaled down by the prediction horizon.
    pub confidence: f64,
    /// Seconds into the future this prediction refers to.
    pub time_horizon: f64,
}
// ---------------------------------------------------------------------------
// World model
// ---------------------------------------------------------------------------
/// Maintains a spatial model of the environment: tracked objects and a 2-D
/// occupancy grid.
#[derive(Debug, Clone)]
pub struct WorldModel {
    /// Tracked objects keyed by id.
    tracked_objects: HashMap<u64, TrackedObject>,
    /// Square occupancy grid, stored row-major and indexed `occupancy[y][x]`.
    occupancy: Vec<Vec<f32>>,
    /// Number of cells along each axis.
    grid_size: usize,
    /// Real-world size of each cell.
    grid_resolution: f64,
}
impl WorldModel {
    /// Create a new world model with a square occupancy grid.
    ///
    /// * `grid_size` -- number of cells along each axis.
    /// * `resolution` -- real-world size of each cell.
    ///
    /// All cells start free (0.0) and no objects are tracked.
    pub fn new(grid_size: usize, resolution: f64) -> Self {
        Self {
            tracked_objects: HashMap::new(),
            occupancy: vec![vec![0.0_f32; grid_size]; grid_size],
            grid_size,
            grid_resolution: resolution,
        }
    }
    /// Insert or update a tracked object (keyed by `obj.id`).
    pub fn update_object(&mut self, obj: TrackedObject) {
        self.tracked_objects.insert(obj.id, obj);
    }
    /// Remove objects that have not been observed for longer than `max_age`
    /// (microseconds). Returns the number of removed objects.
    pub fn remove_stale_objects(&mut self, current_time: i64, max_age: i64) -> usize {
        let before = self.tracked_objects.len();
        // Keep an object only while its age is within the allowed window.
        self.tracked_objects
            .retain(|_, obj| (current_time - obj.last_seen) <= max_age);
        before - self.tracked_objects.len()
    }
    /// Predict the future state of the object with the given ID using
    /// constant-velocity extrapolation over `dt` seconds.
    ///
    /// Confidence decays hyperbolically with `dt` (see below); returns
    /// `None` for an unknown id.
    pub fn predict_state(&self, object_id: u64, dt: f64) -> Option<PredictedState> {
        let obj = self.tracked_objects.get(&object_id)?;
        let predicted_pos = [
            obj.position[0] + obj.velocity[0] * dt,
            obj.position[1] + obj.velocity[1] * dt,
            obj.position[2] + obj.velocity[2] * dt,
        ];
        // Hyperbolic decay 1 / (1 + dt/5): 0.5 at dt = 5 s, 1/3 at
        // dt = 10 s — not a "halves every 5 s" exponential.
        let decay = (1.0 + dt / 5.0).recip();
        Some(PredictedState {
            position: predicted_pos,
            confidence: obj.confidence * decay,
            time_horizon: dt,
        })
    }
    /// Set the occupancy value at grid cell `(x, y)`.
    ///
    /// Values are typically in `[0.0, 1.0]` where 0 is free and 1 is
    /// occupied. Out-of-bounds writes are silently ignored.
    pub fn update_occupancy(&mut self, x: usize, y: usize, value: f32) {
        if x < self.grid_size && y < self.grid_size {
            // Row-major storage: outer index is y.
            self.occupancy[y][x] = value;
        }
    }
    /// Read the occupancy value at grid cell `(x, y)`; `None` when the
    /// cell is outside the grid.
    pub fn get_occupancy(&self, x: usize, y: usize) -> Option<f32> {
        if x < self.grid_size && y < self.grid_size {
            Some(self.occupancy[y][x])
        } else {
            None
        }
    }
    /// Check whether the straight-line path between two grid cells is free
    /// (all sampled cells have occupancy < 0.5).
    ///
    /// Samples the segment at `max(|dx|, |dy|) + 1` evenly spaced points,
    /// rounding each to a cell (Bresenham-like; corner-touching cells on
    /// diagonals may be skipped). Returns `false` if any sample is
    /// occupied or falls outside the grid.
    pub fn is_path_clear(&self, from: [usize; 2], to: [usize; 2]) -> bool {
        let steps = {
            let dx = (to[0] as isize - from[0] as isize).unsigned_abs();
            let dy = (to[1] as isize - from[1] as isize).unsigned_abs();
            // At least one step so identical endpoints are still checked.
            dx.max(dy).max(1)
        };
        for i in 0..=steps {
            let t = i as f64 / steps as f64;
            let x = (from[0] as f64 + t * (to[0] as f64 - from[0] as f64)).round() as usize;
            let y = (from[1] as f64 + t * (to[1] as f64 - from[1] as f64)).round() as usize;
            if x >= self.grid_size || y >= self.grid_size {
                return false;
            }
            if self.occupancy[y][x] >= 0.5 {
                return false;
            }
        }
        true
    }
    /// Number of currently tracked objects.
    pub fn object_count(&self) -> usize {
        self.tracked_objects.len()
    }
    /// Retrieve a tracked object by ID.
    pub fn get_object(&self, id: u64) -> Option<&TrackedObject> {
        self.tracked_objects.get(&id)
    }
    /// Grid size (cells per axis).
    pub fn grid_size(&self) -> usize {
        self.grid_size
    }
    /// Grid cell resolution in world units.
    pub fn grid_resolution(&self) -> f64 {
        self.grid_resolution
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    /// Build a tracked object with fixed 0.9 confidence.
    fn sample_object(id: u64, pos: [f64; 3], vel: [f64; 3], last_seen: i64) -> TrackedObject {
        TrackedObject {
            id,
            position: pos,
            velocity: vel,
            last_seen,
            confidence: 0.9,
            label: format!("obj_{}", id),
        }
    }
    #[test]
    fn test_update_and_get_object() {
        let mut wm = WorldModel::new(10, 0.1);
        wm.update_object(sample_object(1, [1.0, 2.0, 3.0], [0.0; 3], 100));
        assert_eq!(wm.object_count(), 1);
        let obj = wm.get_object(1).unwrap();
        assert_eq!(obj.position, [1.0, 2.0, 3.0]);
    }
    // Object 1 (age 500) exceeds max_age 200; object 2 (age 100) survives.
    #[test]
    fn test_remove_stale_objects() {
        let mut wm = WorldModel::new(10, 0.1);
        wm.update_object(sample_object(1, [0.0; 3], [0.0; 3], 100));
        wm.update_object(sample_object(2, [0.0; 3], [0.0; 3], 500));
        let removed = wm.remove_stale_objects(600, 200);
        assert_eq!(removed, 1);
        assert!(wm.get_object(1).is_none());
        assert!(wm.get_object(2).is_some());
    }
    // Constant-velocity: x = 0 + 1.0 * 2.0 = 2.0; confidence decayed.
    #[test]
    fn test_predict_state() {
        let mut wm = WorldModel::new(10, 0.1);
        wm.update_object(sample_object(1, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 0));
        let pred = wm.predict_state(1, 2.0).unwrap();
        assert!((pred.position[0] - 2.0).abs() < 1e-9);
        assert!((pred.time_horizon - 2.0).abs() < 1e-9);
        assert!(pred.confidence < 0.9); // Decayed
    }
    #[test]
    fn test_predict_missing_object() {
        let wm = WorldModel::new(10, 0.1);
        assert!(wm.predict_state(99, 1.0).is_none());
    }
    #[test]
    fn test_occupancy_update_and_read() {
        let mut wm = WorldModel::new(5, 0.5);
        wm.update_occupancy(2, 3, 0.8);
        assert!((wm.get_occupancy(2, 3).unwrap() - 0.8).abs() < f32::EPSILON);
        assert!((wm.get_occupancy(0, 0).unwrap()).abs() < f32::EPSILON);
    }
    #[test]
    fn test_occupancy_out_of_bounds() {
        let mut wm = WorldModel::new(5, 0.5);
        wm.update_occupancy(10, 10, 1.0); // Should be silently ignored.
        assert!(wm.get_occupancy(10, 10).is_none());
    }
    // A single occupied cell on the straight line blocks the path.
    #[test]
    fn test_path_clear() {
        let mut wm = WorldModel::new(10, 0.1);
        assert!(wm.is_path_clear([0, 0], [9, 0]));
        wm.update_occupancy(5, 0, 1.0);
        assert!(!wm.is_path_clear([0, 0], [9, 0]));
    }
    #[test]
    fn test_path_clear_diagonal() {
        let wm = WorldModel::new(10, 0.1);
        assert!(wm.is_path_clear([0, 0], [9, 9]));
    }
    #[test]
    fn test_grid_properties() {
        let wm = WorldModel::new(20, 0.05);
        assert_eq!(wm.grid_size(), 20);
        assert!((wm.grid_resolution() - 0.05).abs() < 1e-9);
    }
}

View File

@@ -0,0 +1,976 @@
//! Robotics domain for cross-domain transfer learning.
//!
//! Implements [`ruvector_domain_expansion::Domain`] so that robotics tasks
//! (perception, planning, skill learning) participate in the domain-expansion
//! engine's transfer-learning pipeline alongside Rust synthesis, structured
//! planning, and tool orchestration.
//!
//! ## Task categories
//!
//! | Category | Description |
//! |---|---|
//! | `PointCloudClustering` | Cluster a synthetic point cloud into objects |
//! | `ObstacleAvoidance` | Plan a collision-free path through obstacles |
//! | `SceneGraphConstruction` | Build a scene graph from a set of objects |
//! | `SkillSequencing` | Select and sequence learned motor skills |
//! | `SwarmFormation` | Assign robots to formation positions |
//!
//! ## Transfer synergies
//!
//! - **Planning ↔ Robotics**: Both decompose goals into ordered steps with
//! resource constraints. Robotics adds spatial reasoning.
//! - **Tool Orchestration ↔ Robotics**: Swarm coordination is structurally
//! similar to multi-tool pipeline coordination.
//! - **Rust Synthesis ↔ Robotics**: Algorithmic solutions (search, sort,
//! graph traversal) directly appear in perception and planning kernels.
use rand::Rng;
use ruvector_domain_expansion::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
use serde::{Deserialize, Serialize};
const EMBEDDING_DIM: usize = 64;
// ---------------------------------------------------------------------------
// Task specification types
// ---------------------------------------------------------------------------
/// Categories of robotics tasks.
///
/// Each variant has a dedicated generator (`gen_*`) and scorer (`score_*`)
/// on `RoboticsDomain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RoboticsCategory {
    /// Cluster a point cloud into distinct objects.
    PointCloudClustering,
    /// Plan a path that avoids all obstacles.
    ObstacleAvoidance,
    /// Build a scene graph from detected objects.
    SceneGraphConstruction,
    /// Select and order skills to achieve a goal.
    SkillSequencing,
    /// Assign N robots to formation positions.
    SwarmFormation,
}
/// A synthetic spherical obstacle for avoidance tasks (also reused as a
/// cluster seed or scene object by the other generators).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskObstacle {
    /// World-frame centre of the sphere.
    pub center: [f64; 3],
    /// Sphere radius; waypoints closer than this count as collisions.
    pub radius: f64,
}
/// A skill reference for sequencing tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSkill {
    /// Skill identifier referenced by solutions.
    pub name: String,
    /// Effects that must already hold before this skill can run.
    pub preconditions: Vec<String>,
    /// Effects produced on completion (`"<name>_done"` by construction).
    pub effects: Vec<String>,
    /// Relative execution cost.
    pub cost: f32,
}
/// Specification for a robotics task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoboticsTaskSpec {
    /// What kind of task this is; determines how it is scored.
    pub category: RoboticsCategory,
    /// Human-readable task statement.
    pub description: String,
    /// Number of points (clustering), obstacles (avoidance), objects (scene
    /// graph), skills (sequencing), or robots (formation).
    pub size: usize,
    /// Spatial extent of the task environment.
    pub world_bounds: [f64; 3],
    /// Obstacles (used by avoidance and scene graph tasks; cluster seeds
    /// for clustering tasks).
    pub obstacles: Vec<TaskObstacle>,
    /// Skills (used by sequencing tasks).
    pub skills: Vec<TaskSkill>,
    /// Start position (avoidance).
    pub start: Option<[f64; 3]>,
    /// Goal position (avoidance).
    pub goal: Option<[f64; 3]>,
    /// Desired formation type name (swarm), e.g. `"circle"` or `"grid"`.
    pub formation: Option<String>,
}
/// A parsed robotics solution.
///
/// Only the field(s) relevant to the task's category are expected to be
/// populated; the rest may stay empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoboticsSolution {
    /// Waypoints for path or formation positions.
    pub waypoints: Vec<[f64; 3]>,
    /// Cluster assignments (point index → cluster id).
    pub cluster_ids: Vec<usize>,
    /// Ordered skill names.
    pub skill_sequence: Vec<String>,
    /// Scene graph edges as (from, to) pairs.
    pub edges: Vec<(usize, usize)>,
}
// ---------------------------------------------------------------------------
// Domain implementation
// ---------------------------------------------------------------------------
/// Robotics domain for the domain-expansion engine.
pub struct RoboticsDomain {
    /// Stable identifier (`"robotics"`) used by the engine.
    id: DomainId,
}
impl RoboticsDomain {
/// Create a new robotics domain.
pub fn new() -> Self {
Self {
id: DomainId("robotics".to_string()),
}
}
// -- task generators ---------------------------------------------------
    /// Generate a point-cloud clustering task.
    ///
    /// Difficulty buckets (<0.3, <0.7, else) control the cluster count,
    /// points per cluster, and spread. Each cluster seed is recorded as a
    /// `TaskObstacle` (centre + radius) so the scorer can recover the
    /// expected group count from `spec.obstacles`.
    fn gen_clustering(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec {
        let num_clusters = if difficulty < 0.3 { 2 } else if difficulty < 0.7 { 5 } else { 10 };
        let pts_per_cluster = if difficulty < 0.3 { 10 } else { 20 };
        let spread = if difficulty < 0.5 { 0.5 } else { 2.0 };
        let mut obstacles = Vec::new();
        for _ in 0..num_clusters {
            // Three RNG draws per cluster: x, y in [-10, 10), z in [0, 5).
            obstacles.push(TaskObstacle {
                center: [
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(0.0..5.0),
                ],
                radius: spread,
            });
        }
        RoboticsTaskSpec {
            category: RoboticsCategory::PointCloudClustering,
            description: format!(
                "Cluster {} points into {} groups (spread={:.1}).",
                num_clusters * pts_per_cluster,
                num_clusters,
                spread,
            ),
            size: num_clusters * pts_per_cluster,
            world_bounds: [20.0, 20.0, 5.0],
            obstacles,
            skills: Vec::new(),
            start: None,
            goal: None,
            formation: None,
        }
    }
    /// Generate an obstacle-avoidance task.
    ///
    /// Difficulty buckets control the obstacle count (3/8/15). Obstacle
    /// centres are drawn from the arena interior ([1, 9) in x and y,
    /// z = 0); start is fixed at the origin and goal at the far corner.
    fn gen_avoidance(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec {
        let num_obstacles = if difficulty < 0.3 { 3 } else if difficulty < 0.7 { 8 } else { 15 };
        let mut obstacles = Vec::new();
        for _ in 0..num_obstacles {
            // Three RNG draws per obstacle: x, y centre, then radius.
            obstacles.push(TaskObstacle {
                center: [
                    rng.gen_range(1.0..9.0),
                    rng.gen_range(1.0..9.0),
                    0.0,
                ],
                radius: rng.gen_range(0.3..1.5),
            });
        }
        RoboticsTaskSpec {
            category: RoboticsCategory::ObstacleAvoidance,
            description: format!(
                "Plan a collision-free path through {} obstacles.",
                num_obstacles,
            ),
            size: num_obstacles,
            world_bounds: [10.0, 10.0, 1.0],
            obstacles,
            skills: Vec::new(),
            start: Some([0.0, 0.0, 0.0]),
            goal: Some([10.0, 10.0, 0.0]),
            formation: None,
        }
    }
    /// Generate a scene-graph construction task.
    ///
    /// Difficulty buckets control the object count (3/8/15). Each object
    /// is stored as a `TaskObstacle` (random centre + radius) in a
    /// 20x20x5 world; the scorer checks edge coverage over these objects.
    fn gen_scene_graph(&self, difficulty: f32, rng: &mut impl Rng) -> RoboticsTaskSpec {
        let num_objects = if difficulty < 0.3 { 3 } else if difficulty < 0.7 { 8 } else { 15 };
        let mut obstacles = Vec::new();
        for _ in 0..num_objects {
            // Four RNG draws per object: x, y, z centre, then radius.
            obstacles.push(TaskObstacle {
                center: [
                    rng.gen_range(0.0..20.0),
                    rng.gen_range(0.0..20.0),
                    rng.gen_range(0.0..5.0),
                ],
                radius: rng.gen_range(0.5..2.0),
            });
        }
        RoboticsTaskSpec {
            category: RoboticsCategory::SceneGraphConstruction,
            description: format!(
                "Build a scene graph with spatial relations for {} objects.",
                num_objects,
            ),
            size: num_objects,
            world_bounds: [20.0, 20.0, 5.0],
            obstacles,
            skills: Vec::new(),
            start: None,
            goal: None,
            formation: None,
        }
    }
fn gen_skill_sequencing(&self, difficulty: f32, _rng: &mut impl Rng) -> RoboticsTaskSpec {
let skill_names = if difficulty < 0.3 {
vec!["approach", "grasp", "lift"]
} else if difficulty < 0.7 {
vec!["scan", "approach", "align", "grasp", "lift", "place"]
} else {
vec![
"scan", "classify", "approach", "align", "grasp",
"lift", "navigate", "place", "verify", "retreat",
]
};
let skills: Vec<TaskSkill> = skill_names
.iter()
.enumerate()
.map(|(i, &name)| TaskSkill {
name: name.to_string(),
preconditions: if i > 0 {
vec![format!("{}_done", skill_names[i - 1])]
} else {
Vec::new()
},
effects: vec![format!("{}_done", name)],
cost: (i as f32 + 1.0) * 0.5,
})
.collect();
RoboticsTaskSpec {
category: RoboticsCategory::SkillSequencing,
description: format!(
"Sequence {} skills to achieve a pick-and-place goal.",
skills.len(),
),
size: skills.len(),
world_bounds: [10.0, 10.0, 3.0],
obstacles: Vec::new(),
skills,
start: None,
goal: None,
formation: None,
}
}
fn gen_swarm_formation(&self, difficulty: f32, _rng: &mut impl Rng) -> RoboticsTaskSpec {
let num_robots = if difficulty < 0.3 { 4 } else if difficulty < 0.7 { 8 } else { 16 };
let formation = if difficulty < 0.5 { "circle" } else { "grid" };
RoboticsTaskSpec {
category: RoboticsCategory::SwarmFormation,
description: format!(
"Assign {} robots to a {} formation.",
num_robots, formation,
),
size: num_robots,
world_bounds: [20.0, 20.0, 1.0],
obstacles: Vec::new(),
skills: Vec::new(),
start: None,
goal: None,
formation: Some(formation.to_string()),
}
}
// -- evaluation helpers ------------------------------------------------
fn score_clustering(&self, spec: &RoboticsTaskSpec, sol: &RoboticsSolution) -> Evaluation {
let expected_clusters = spec.obstacles.len();
let mut notes = Vec::new();
if sol.cluster_ids.is_empty() {
return Evaluation::zero(vec!["No cluster assignments".into()]);
}
let actual_clusters = *sol.cluster_ids.iter().max().unwrap_or(&0) + 1;
let cluster_accuracy = if expected_clusters > 0 {
1.0 - ((actual_clusters as f32 - expected_clusters as f32).abs()
/ expected_clusters as f32)
.min(1.0)
} else {
0.0
};
let correctness = cluster_accuracy;
let efficiency = if sol.cluster_ids.len() == spec.size { 1.0 } else { 0.5 };
let elegance = if actual_clusters <= expected_clusters * 2 { 0.8 } else { 0.3 };
if (actual_clusters as i32 - expected_clusters as i32).unsigned_abs() > 2 {
notes.push(format!(
"Expected ~{} clusters, got {}",
expected_clusters, actual_clusters,
));
}
Evaluation {
score: (0.6 * correctness + 0.25 * efficiency + 0.15 * elegance).clamp(0.0, 1.0),
correctness,
efficiency,
elegance,
constraint_results: Vec::new(),
notes,
}
}
    /// Score an obstacle-avoidance path.
    ///
    /// Correctness combines a base term for starting/ending near the
    /// endpoints (0.6 vs 0.2) with up to 0.4 for a low collision
    /// fraction; efficiency penalises long paths; elegance rewards a
    /// collision-free path.
    fn score_avoidance(&self, spec: &RoboticsTaskSpec, sol: &RoboticsSolution) -> Evaluation {
        let mut notes = Vec::new();
        if sol.waypoints.is_empty() {
            return Evaluation::zero(vec!["Empty path".into()]);
        }
        // Check start/goal proximity (within 1.0 of both endpoints).
        let start = spec.start.unwrap_or([0.0; 3]);
        let goal = spec.goal.unwrap_or([10.0, 10.0, 0.0]);
        let start_dist = dist3(&sol.waypoints[0], &start);
        let goal_dist = sol.waypoints.last().map_or(f64::MAX, |wp| dist3(wp, &goal));
        let reaches_goal = start_dist < 1.0 && goal_dist < 1.0;
        // Count every (waypoint, obstacle) pair that intersects.
        let mut collisions = 0;
        for wp in &sol.waypoints {
            for obs in &spec.obstacles {
                if dist3(wp, &obs.center) < obs.radius {
                    collisions += 1;
                }
            }
        }
        // The if-else is one operand of `+`: base (0.6 or 0.2) plus a
        // 0..=0.4 bonus scaled by the collision-free pair fraction.
        let correctness = if reaches_goal { 0.6 } else { 0.2 }
            + (1.0 - (collisions as f32 / (sol.waypoints.len() * spec.obstacles.len()).max(1) as f32).min(1.0)) * 0.4;
        let efficiency = 1.0 - (sol.waypoints.len() as f32 / 100.0).min(1.0);
        let elegance = if collisions == 0 { 0.9 } else { 0.3 };
        if collisions > 0 {
            notes.push(format!("{} collision(s) detected", collisions));
        }
        if !reaches_goal {
            notes.push("Path does not reach goal".into());
        }
        Evaluation {
            score: (0.6 * correctness + 0.25 * efficiency + 0.15 * elegance).clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
    /// Score a scene graph by node coverage and edge economy.
    ///
    /// Correctness is the fraction of expected objects that appear as an
    /// edge endpoint; efficiency penalises more edges than a complete
    /// graph would need; elegance rewards >= 80% coverage.
    fn score_scene_graph(&self, spec: &RoboticsTaskSpec, sol: &RoboticsSolution) -> Evaluation {
        let expected_objects = spec.obstacles.len();
        let mut notes = Vec::new();
        if sol.edges.is_empty() && expected_objects > 1 {
            notes.push("No edges in scene graph".into());
        }
        // Collect every node index that appears as an edge endpoint.
        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for &(a, b) in &sol.edges {
            seen.insert(a);
            seen.insert(b);
        }
        let coverage = if expected_objects > 0 {
            seen.len() as f32 / expected_objects as f32
        } else {
            1.0
        };
        let correctness = coverage.min(1.0);
        // Edge budget: n*(n-1)/2, the undirected complete-graph count.
        let max_edges = if expected_objects >= 2 {
            expected_objects * (expected_objects - 1) / 2
        } else {
            0
        };
        let efficiency = if sol.edges.len() <= max_edges {
            0.9
        } else {
            0.5
        };
        let elegance = if coverage >= 0.8 { 0.8 } else { 0.4 };
        Evaluation {
            score: (0.6 * correctness + 0.25 * efficiency + 0.15 * elegance).clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
    /// Score a skill sequence by coverage and dependency ordering.
    ///
    /// Each skill's preconditions (`"<skill>_done"` effects) must be
    /// produced by a skill appearing earlier in the sequence; violations
    /// reduce correctness proportionally.
    fn score_skill_sequence(&self, spec: &RoboticsTaskSpec, sol: &RoboticsSolution) -> Evaluation {
        let mut notes = Vec::new();
        if sol.skill_sequence.is_empty() {
            return Evaluation::zero(vec!["Empty skill sequence".into()]);
        }
        // Check dependency ordering.
        let mut violations = 0;
        for (i, name) in sol.skill_sequence.iter().enumerate() {
            if let Some(skill) = spec.skills.iter().find(|s| &s.name == name) {
                for pre in &skill.preconditions {
                    // Precondition must appear earlier. Effects are named
                    // "<skill>_done", so stripping the suffix recovers the
                    // producing skill's name.
                    let pre_skill = pre.trim_end_matches("_done");
                    let pre_pos = sol.skill_sequence.iter().position(|s| s == pre_skill);
                    if let Some(pp) = pre_pos {
                        if pp >= i {
                            violations += 1;
                            notes.push(format!("{} before its precondition {}", name, pre_skill));
                        }
                    } else {
                        violations += 1;
                        notes.push(format!("Missing precondition {} for {}", pre_skill, name));
                    }
                }
            }
        }
        let expected_skills = spec.skills.len();
        let coverage = sol.skill_sequence.len() as f32 / expected_skills.max(1) as f32;
        let dep_penalty = violations as f32 / expected_skills.max(1) as f32;
        let correctness = (coverage.min(1.0) * (1.0 - dep_penalty.min(1.0))).max(0.0);
        let efficiency = if sol.skill_sequence.len() <= expected_skills + 2 { 0.9 } else { 0.5 };
        let elegance = if violations == 0 { 0.9 } else { 0.3 };
        Evaluation {
            score: (0.6 * correctness + 0.25 * efficiency + 0.15 * elegance).clamp(0.0, 1.0),
            correctness,
            efficiency,
            elegance,
            constraint_results: Vec::new(),
            notes,
        }
    }
fn score_formation(&self, spec: &RoboticsTaskSpec, sol: &RoboticsSolution) -> Evaluation {
let expected_robots = spec.size;
let mut notes = Vec::new();
if sol.waypoints.is_empty() {
return Evaluation::zero(vec!["No formation positions".into()]);
}
let correctness = (sol.waypoints.len() as f32 / expected_robots.max(1) as f32).min(1.0);
// Check positions are within bounds.
let bounds = &spec.world_bounds;
let in_bounds = sol
.waypoints
.iter()
.filter(|w| {
w[0].abs() <= bounds[0] && w[1].abs() <= bounds[1] && w[2].abs() <= bounds[2]
})
.count() as f32
/ sol.waypoints.len().max(1) as f32;
let efficiency = in_bounds;
// Check for collisions between robots (min spacing).
let mut too_close = 0;
for i in 0..sol.waypoints.len() {
for j in (i + 1)..sol.waypoints.len() {
if dist3(&sol.waypoints[i], &sol.waypoints[j]) < 0.5 {
too_close += 1;
}
}
}
let elegance = if too_close == 0 { 0.9 } else { 0.4 };
if too_close > 0 {
notes.push(format!("{} robot pair(s) too close (<0.5m)", too_close));
}
Evaluation {
score: (0.6 * correctness + 0.25 * efficiency + 0.15 * elegance).clamp(0.0, 1.0),
correctness,
efficiency,
elegance,
constraint_results: Vec::new(),
notes,
}
}
// -- embedding ---------------------------------------------------------
    /// Project a solution into the fixed `EMBEDDING_DIM` cross-domain feature space.
    ///
    /// Features 0-7 describe the parsed solution structure, 8-23 count
    /// robotics keywords in the free text, 32-39 count keywords shared with
    /// the planning domain, and 48-55 keywords shared with tool
    /// orchestration; the remaining slots stay zero. The vector is
    /// L2-normalised at the end.
    fn extract_features(&self, solution: &Solution) -> Vec<f32> {
        let content = &solution.content;
        let mut features = vec![0.0f32; EMBEDDING_DIM];
        // Parse the structured solution if present.
        // Falls back to an all-empty solution when neither `data` nor `content`
        // deserialises, so text-only solutions still get keyword features.
        let sol: RoboticsSolution = serde_json::from_value(solution.data.clone())
            .or_else(|_| serde_json::from_str(content))
            .unwrap_or(RoboticsSolution {
                waypoints: Vec::new(),
                cluster_ids: Vec::new(),
                skill_sequence: Vec::new(),
                edges: Vec::new(),
            });
        // Feature 0-7: Solution structure.
        // Divisors are rough "expected maximum" scales; values above them are
        // not clamped here — the final normalisation bounds the vector.
        features[0] = sol.waypoints.len() as f32 / 50.0;
        features[1] = sol.cluster_ids.len() as f32 / 100.0;
        features[2] = sol.skill_sequence.len() as f32 / 20.0;
        features[3] = sol.edges.len() as f32 / 50.0;
        // Unique clusters.
        let unique_clusters: std::collections::HashSet<&usize> = sol.cluster_ids.iter().collect();
        features[4] = unique_clusters.len() as f32 / 20.0;
        // Unique skills.
        let unique_skills: std::collections::HashSet<&String> = sol.skill_sequence.iter().collect();
        features[5] = unique_skills.len() as f32 / 20.0;
        // Spatial extent of waypoints.
        if !sol.waypoints.is_empty() {
            let max_dist = sol
                .waypoints
                .iter()
                .map(|w| (w[0] * w[0] + w[1] * w[1] + w[2] * w[2]).sqrt() as f32)
                .fold(0.0f32, f32::max);
            features[6] = max_dist / 50.0;
        }
        // Feature 8-15: Text-based features (cross-domain compatible).
        features[8] = content.matches("cluster").count() as f32 / 5.0;
        features[9] = content.matches("obstacle").count() as f32 / 5.0;
        features[10] = content.matches("path").count() as f32 / 5.0;
        features[11] = content.matches("scene").count() as f32 / 3.0;
        features[12] = content.matches("formation").count() as f32 / 3.0;
        features[13] = content.matches("skill").count() as f32 / 5.0;
        features[14] = content.matches("robot").count() as f32 / 5.0;
        features[15] = content.matches("point").count() as f32 / 10.0;
        // Feature 16-23: Spatial reasoning indicators.
        features[16] = content.matches("distance").count() as f32 / 5.0;
        features[17] = content.matches("position").count() as f32 / 5.0;
        features[18] = content.matches("radius").count() as f32 / 3.0;
        features[19] = content.matches("collision").count() as f32 / 3.0;
        features[20] = content.matches("adjacent").count() as f32 / 3.0;
        features[21] = content.matches("near").count() as f32 / 3.0;
        features[22] = content.matches("velocity").count() as f32 / 3.0;
        features[23] = content.matches("trajectory").count() as f32 / 3.0;
        // Feature 32-39: Planning overlap (cross-domain with PlanningDomain).
        features[32] = content.matches("allocate").count() as f32 / 3.0;
        features[33] = content.matches("schedule").count() as f32 / 3.0;
        features[34] = content.matches("constraint").count() as f32 / 3.0;
        features[35] = content.matches("goal").count() as f32 / 3.0;
        features[36] = content.matches("precondition").count() as f32 / 3.0;
        features[37] = content.matches("parallel").count() as f32 / 3.0;
        features[38] = content.matches("sequence").count() as f32 / 3.0;
        features[39] = content.matches("assign").count() as f32 / 3.0;
        // Feature 48-55: Orchestration overlap (cross-domain with ToolOrchestration).
        features[48] = content.matches("pipeline").count() as f32 / 3.0;
        features[49] = content.matches("sensor").count() as f32 / 3.0;
        features[50] = content.matches("fuse").count() as f32 / 2.0;
        features[51] = content.matches("detect").count() as f32 / 3.0;
        features[52] = content.matches("track").count() as f32 / 3.0;
        features[53] = content.matches("coordinate").count() as f32 / 3.0;
        features[54] = content.matches("merge").count() as f32 / 3.0;
        features[55] = content.matches("update").count() as f32 / 3.0;
        // Normalize to unit length.
        // The epsilon guard keeps an all-zero feature vector as-is instead of
        // producing NaNs from a division by zero.
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for f in &mut features {
                *f /= norm;
            }
        }
        features
    }
}
impl Default for RoboticsDomain {
fn default() -> Self {
Self::new()
}
}
impl Domain for RoboticsDomain {
    /// Stable domain identifier ("robotics", per the tests).
    fn id(&self) -> &DomainId {
        &self.id
    }
    /// Human-readable domain name.
    fn name(&self) -> &str {
        "Cognitive Robotics"
    }
    /// Generate `count` random tasks, sampling the five task categories
    /// uniformly (20% each) at the given (clamped) difficulty.
    fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec<Task> {
        let mut rng = rand::thread_rng();
        let difficulty = difficulty.clamp(0.0, 1.0);
        (0..count)
            .map(|i| {
                // Uniform roll in [0,1) selects the task category.
                let roll: f32 = rng.gen();
                let spec = if roll < 0.2 {
                    self.gen_clustering(difficulty, &mut rng)
                } else if roll < 0.4 {
                    self.gen_avoidance(difficulty, &mut rng)
                } else if roll < 0.6 {
                    self.gen_scene_graph(difficulty, &mut rng)
                } else if roll < 0.8 {
                    self.gen_skill_sequencing(difficulty, &mut rng)
                } else {
                    self.gen_swarm_formation(difficulty, &mut rng)
                };
                Task {
                    id: format!("robotics_{}_d{:.0}", i, difficulty * 100.0),
                    domain_id: self.id.clone(),
                    difficulty,
                    spec: serde_json::to_value(&spec).unwrap_or_default(),
                    constraints: Vec::new(),
                }
            })
            .collect()
    }
    /// Parse the task spec and solution, then dispatch to the per-category
    /// scorer. A malformed spec yields a zero evaluation; a malformed
    /// solution degrades to an empty solution (scored, not rejected).
    fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation {
        let spec: RoboticsTaskSpec = match serde_json::from_value(task.spec.clone()) {
            Ok(s) => s,
            Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]),
        };
        let sol: RoboticsSolution = serde_json::from_value(solution.data.clone())
            .or_else(|_| serde_json::from_str(&solution.content))
            .unwrap_or(RoboticsSolution {
                waypoints: Vec::new(),
                cluster_ids: Vec::new(),
                skill_sequence: Vec::new(),
                edges: Vec::new(),
            });
        match spec.category {
            RoboticsCategory::PointCloudClustering => self.score_clustering(&spec, &sol),
            RoboticsCategory::ObstacleAvoidance => self.score_avoidance(&spec, &sol),
            RoboticsCategory::SceneGraphConstruction => self.score_scene_graph(&spec, &sol),
            RoboticsCategory::SkillSequencing => self.score_skill_sequence(&spec, &sol),
            RoboticsCategory::SwarmFormation => self.score_formation(&spec, &sol),
        }
    }
    /// Embed a solution into the shared cross-domain feature space.
    fn embed(&self, solution: &Solution) -> DomainEmbedding {
        let features = self.extract_features(solution);
        DomainEmbedding::new(features, self.id.clone())
    }
    fn embedding_dim(&self) -> usize {
        EMBEDDING_DIM
    }
    /// Build a simple baseline solution for each task category; returns
    /// `None` only if the task spec fails to deserialise.
    fn reference_solution(&self, task: &Task) -> Option<Solution> {
        let spec: RoboticsTaskSpec = serde_json::from_value(task.spec.clone()).ok()?;
        let sol = match spec.category {
            RoboticsCategory::PointCloudClustering => {
                // Assign each point-group to its own cluster.
                // Integer division maps index i to one of `num_clusters` buckets.
                let num_clusters = spec.obstacles.len().max(1);
                let cluster_ids: Vec<usize> = (0..spec.size)
                    .map(|i| i * num_clusters / spec.size.max(1))
                    .collect();
                RoboticsSolution {
                    waypoints: Vec::new(),
                    cluster_ids,
                    skill_sequence: Vec::new(),
                    edges: Vec::new(),
                }
            }
            RoboticsCategory::ObstacleAvoidance => {
                // Straight-line path (naive reference).
                let start = spec.start.unwrap_or([0.0; 3]);
                let goal = spec.goal.unwrap_or([10.0, 10.0, 0.0]);
                let steps = 10;
                // Linear interpolation start -> goal, inclusive of both ends.
                let waypoints: Vec<[f64; 3]> = (0..=steps)
                    .map(|s| {
                        let t = s as f64 / steps as f64;
                        [
                            start[0] + (goal[0] - start[0]) * t,
                            start[1] + (goal[1] - start[1]) * t,
                            start[2] + (goal[2] - start[2]) * t,
                        ]
                    })
                    .collect();
                RoboticsSolution {
                    waypoints,
                    cluster_ids: Vec::new(),
                    skill_sequence: Vec::new(),
                    edges: Vec::new(),
                }
            }
            RoboticsCategory::SceneGraphConstruction => {
                // Connect all pairs within distance 10.
                let mut edges = Vec::new();
                for i in 0..spec.obstacles.len() {
                    for j in (i + 1)..spec.obstacles.len() {
                        let d = dist3(&spec.obstacles[i].center, &spec.obstacles[j].center);
                        if d < 10.0 {
                            edges.push((i, j));
                        }
                    }
                }
                RoboticsSolution {
                    waypoints: Vec::new(),
                    cluster_ids: Vec::new(),
                    skill_sequence: Vec::new(),
                    edges,
                }
            }
            RoboticsCategory::SkillSequencing => {
                // Use the skills in spec order — presumably topologically valid
                // for generated tasks; TODO confirm against the generator.
                let skill_sequence: Vec<String> =
                    spec.skills.iter().map(|s| s.name.clone()).collect();
                RoboticsSolution {
                    waypoints: Vec::new(),
                    cluster_ids: Vec::new(),
                    skill_sequence,
                    edges: Vec::new(),
                }
            }
            RoboticsCategory::SwarmFormation => {
                // Circle formation.
                // Robots evenly spaced on a radius-5 circle in the z=0 plane.
                let n = spec.size;
                let waypoints: Vec<[f64; 3]> = (0..n)
                    .map(|i| {
                        let angle = 2.0 * std::f64::consts::PI * i as f64 / n as f64;
                        [5.0 * angle.cos(), 5.0 * angle.sin(), 0.0]
                    })
                    .collect();
                RoboticsSolution {
                    waypoints,
                    cluster_ids: Vec::new(),
                    skill_sequence: Vec::new(),
                    edges: Vec::new(),
                }
            }
        };
        // Mirror the structured solution into both `content` and `data`.
        let content = serde_json::to_string_pretty(&sol).ok()?;
        Some(Solution {
            task_id: task.id.clone(),
            content,
            data: serde_json::to_value(&sol).ok()?,
        })
    }
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Euclidean distance between two 3-D points.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(p, q)| (p - q) * (p - q))
        .sum::<f64>()
        .sqrt()
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // Domain identity metadata is fixed.
    #[test]
    fn test_robotics_domain_id() {
        let domain = RoboticsDomain::new();
        assert_eq!(domain.id().0, "robotics");
        assert_eq!(domain.name(), "Cognitive Robotics");
    }
    // Task generation yields the requested count tagged with this domain's id,
    // across low/medium/high difficulty.
    #[test]
    fn test_generate_tasks_all_difficulties() {
        let domain = RoboticsDomain::new();
        for &d in &[0.1, 0.5, 0.9] {
            let tasks = domain.generate_tasks(10, d);
            assert_eq!(tasks.len(), 10);
            for task in &tasks {
                assert_eq!(task.domain_id, *domain.id());
            }
        }
    }
    // Every generated task must have a baseline reference solution.
    #[test]
    fn test_reference_solution_exists() {
        let domain = RoboticsDomain::new();
        let tasks = domain.generate_tasks(20, 0.5);
        for task in &tasks {
            let ref_sol = domain.reference_solution(task);
            assert!(ref_sol.is_some(), "Reference solution missing for {}", task.id);
        }
    }
    // Evaluating a reference solution always yields a score in [0, 1].
    #[test]
    fn test_evaluate_reference_solutions() {
        let domain = RoboticsDomain::new();
        let tasks = domain.generate_tasks(20, 0.3);
        for task in &tasks {
            if let Some(solution) = domain.reference_solution(task) {
                let eval = domain.evaluate(task, &solution);
                assert!(
                    eval.score >= 0.0 && eval.score <= 1.0,
                    "Score out of range for {}: {}",
                    task.id,
                    eval.score,
                );
            }
        }
    }
    // Embeddings always have the advertised fixed dimensionality.
    #[test]
    fn test_embedding_dimension() {
        let domain = RoboticsDomain::new();
        assert_eq!(domain.embedding_dim(), EMBEDDING_DIM);
        let sol = Solution {
            task_id: "test".into(),
            content: "cluster points into groups near obstacles using distance threshold".into(),
            data: serde_json::Value::Null,
        };
        let embedding = domain.embed(&sol);
        assert_eq!(embedding.dim, EMBEDDING_DIM);
        assert_eq!(embedding.vector.len(), EMBEDDING_DIM);
    }
    // Embeddings use the shared 64-dim space and are self-similar under cosine.
    #[test]
    fn test_cross_domain_embedding_compatibility() {
        let domain = RoboticsDomain::new();
        let robotics_sol = Solution {
            task_id: "r1".into(),
            content: "plan path through obstacles avoiding collision with distance checks".into(),
            data: serde_json::Value::Null,
        };
        let robotics_emb = domain.embed(&robotics_sol);
        // Embedding should be same dimension as other domains (64).
        assert_eq!(robotics_emb.dim, 64);
        // Cosine similarity with itself should be 1.0.
        let self_sim = robotics_emb.cosine_similarity(&robotics_emb);
        assert!((self_sim - 1.0).abs() < 1e-4);
    }
    // A correctly ordered skill sequence must outscore the reversed one under
    // the dependency-violation penalty in score_skill_sequence.
    #[test]
    fn test_score_skill_ordering_violation() {
        let domain = RoboticsDomain::new();
        let spec = RoboticsTaskSpec {
            category: RoboticsCategory::SkillSequencing,
            description: "test".into(),
            size: 3,
            world_bounds: [10.0, 10.0, 3.0],
            obstacles: Vec::new(),
            // Linear dependency chain: approach -> grasp -> lift, wired via
            // "<skill>_done" effect/precondition names.
            skills: vec![
                TaskSkill {
                    name: "approach".into(),
                    preconditions: Vec::new(),
                    effects: vec!["approach_done".into()],
                    cost: 1.0,
                },
                TaskSkill {
                    name: "grasp".into(),
                    preconditions: vec!["approach_done".into()],
                    effects: vec!["grasp_done".into()],
                    cost: 1.0,
                },
                TaskSkill {
                    name: "lift".into(),
                    preconditions: vec!["grasp_done".into()],
                    effects: vec!["lift_done".into()],
                    cost: 1.0,
                },
            ],
            start: None,
            goal: None,
            formation: None,
        };
        // Correct ordering.
        let good = RoboticsSolution {
            waypoints: Vec::new(),
            cluster_ids: Vec::new(),
            skill_sequence: vec!["approach".into(), "grasp".into(), "lift".into()],
            edges: Vec::new(),
        };
        let good_eval = domain.score_skill_sequence(&spec, &good);
        assert!(good_eval.correctness > 0.5);
        // Bad ordering (reversed).
        let bad = RoboticsSolution {
            waypoints: Vec::new(),
            cluster_ids: Vec::new(),
            skill_sequence: vec!["lift".into(), "grasp".into(), "approach".into()],
            edges: Vec::new(),
        };
        let bad_eval = domain.score_skill_sequence(&spec, &bad);
        assert!(bad_eval.score < good_eval.score);
    }
    // Integration: the domain registers into the expansion engine and serves
    // task generation and embedding through it.
    #[test]
    fn test_engine_with_robotics_domain() {
        use ruvector_domain_expansion::DomainExpansionEngine;
        let mut engine = DomainExpansionEngine::new();
        engine.register_domain(Box::new(RoboticsDomain::new()));
        let ids = engine.domain_ids();
        // 3 built-in + robotics = 4.
        assert_eq!(ids.len(), 4);
        assert!(ids.iter().any(|id| id.0 == "robotics"));
        // Generate tasks from robotics domain.
        let domain_id = DomainId("robotics".into());
        let tasks = engine.generate_tasks(&domain_id, 5, 0.5);
        assert_eq!(tasks.len(), 5);
        // Embed a robotics solution.
        let sol = Solution {
            task_id: "r".into(),
            content: "navigate robot through obstacle field using sensor fusion pipeline".into(),
            data: serde_json::Value::Null,
        };
        let emb = engine.embed(&domain_id, &sol);
        assert!(emb.is_some());
        assert_eq!(emb.unwrap().dim, 64);
    }
    // Integration: bandit outcomes recorded against the planning domain can be
    // transferred to seed arm priors for the robotics domain.
    #[test]
    fn test_transfer_from_planning_to_robotics() {
        use ruvector_domain_expansion::transfer::{ArmId, ContextBucket};
        use ruvector_domain_expansion::DomainExpansionEngine;
        let mut engine = DomainExpansionEngine::new();
        engine.register_domain(Box::new(RoboticsDomain::new()));
        let planning_id = DomainId("structured_planning".into());
        let robotics_id = DomainId("robotics".into());
        let bucket = ContextBucket {
            difficulty_tier: "medium".into(),
            category: "spatial".into(),
        };
        // Train on planning domain.
        for _ in 0..30 {
            engine.thompson.record_outcome(
                &planning_id,
                bucket.clone(),
                ArmId("greedy".into()),
                0.85,
                1.0,
            );
        }
        // Transfer to robotics.
        engine.initiate_transfer(&planning_id, &robotics_id);
        // Verify transfer priors are seeded.
        let arm = engine.select_arm(&robotics_id, &bucket);
        assert!(arm.is_some());
    }
}

View File

@@ -0,0 +1,61 @@
//! # ruvector-robotics
//!
//! Unified cognitive robotics platform built on ruvector's vector database,
//! graph neural networks, and self-learning infrastructure.
//!
//! ## Modules
//!
//! - [`bridge`]: Core robotics types, converters, spatial indexing, and perception pipeline
//! - [`perception`]: Scene graph construction, obstacle detection, trajectory prediction
//! - [`cognitive`]: Cognitive architecture with behavior trees, memory, skills, and swarm intelligence
//! - [`mcp`]: Model Context Protocol tool registrations for agentic robotics
//!
//! ## Quick Start
//!
//! ```rust
//! use ruvector_robotics::bridge::{Point3D, PointCloud, SpatialIndex};
//!
//! // Create a point cloud from sensor data
//! let cloud = PointCloud::new(
//! vec![Point3D::new(1.0, 2.0, 3.0), Point3D::new(4.0, 5.0, 6.0)],
//! 1000,
//! );
//!
//! // Index and search
//! let mut index = SpatialIndex::new(3);
//! index.insert_point_cloud(&cloud);
//! let nearest = index.search_nearest(&[2.0, 3.0, 4.0], 1).unwrap();
//! assert_eq!(nearest.len(), 1);
//! ```
pub mod bridge;
pub mod cognitive;
pub mod mcp;
pub mod perception;
pub mod planning;
/// Cross-domain transfer learning integration with `ruvector-domain-expansion`.
///
/// Requires the `domain-expansion` feature flag.
#[cfg(feature = "domain-expansion")]
pub mod domain_expansion;
/// RVF packaging for robotics data.
///
/// Bridges point clouds, scene graphs, trajectories, Gaussian splats, and
/// obstacles into the RuVector Format (`.rvf`) for persistence and similarity
/// search. Requires the `rvf` feature flag.
#[cfg(feature = "rvf")]
pub mod rvf;
// Convenience re-exports of the most commonly used types.
pub use bridge::{
BridgeConfig, DistanceMetric, OccupancyGrid, Obstacle as BridgeObstacle, Point3D, PointCloud,
Pose, Quaternion, RobotState, SceneEdge, SceneGraph, SceneObject, SensorFrame, SpatialIndex,
Trajectory,
};
pub use cognitive::{BehaviorNode, BehaviorStatus, BehaviorTree, CognitiveCore, CognitiveState};
pub use perception::{
ObstacleDetector, PerceptionConfig, PerceptionPipeline, SceneGraphBuilder,
};
pub use planning::{GridPath, VelocityCommand};

View File

@@ -0,0 +1,350 @@
//! MCP tool execution engine.
//!
//! [`ToolExecutor`] wires up the perception pipeline, spatial index, and
//! memory system to actually *execute* tool requests, turning the schema-only
//! registry into a working tool backend.
use std::time::Instant;
use crate::bridge::{Point3D, PointCloud, SceneObject, SpatialIndex};
use crate::mcp::{ToolRequest, ToolResponse};
use crate::perception::PerceptionPipeline;
/// Stateful executor that handles incoming [`ToolRequest`]s by dispatching to
/// the appropriate subsystem.
pub struct ToolExecutor {
    /// Perception subsystem backing the obstacle/scene/trajectory tools.
    pipeline: PerceptionPipeline,
    /// 3-D spatial index backing `spatial_search` and `insert_points`.
    index: SpatialIndex,
}
impl ToolExecutor {
    /// Create a new executor with default subsystem configurations.
    pub fn new() -> Self {
        Self {
            // Thresholds: 0.5 / 2.0 — semantics defined by PerceptionPipeline;
            // presumably clustering and grouping distances — TODO confirm.
            pipeline: PerceptionPipeline::with_thresholds(0.5, 2.0),
            index: SpatialIndex::new(3),
        }
    }
    /// Execute a tool request and return a response with timing.
    ///
    /// Dispatches on `tool_name`; an unrecognised name yields a failed
    /// response. Latency covers the full handler call, in microseconds.
    pub fn execute(&mut self, request: &ToolRequest) -> ToolResponse {
        let start = Instant::now();
        let result = match request.tool_name.as_str() {
            "detect_obstacles" => self.handle_detect_obstacles(request),
            "build_scene_graph" => self.handle_build_scene_graph(request),
            "predict_trajectory" => self.handle_predict_trajectory(request),
            "focus_attention" => self.handle_focus_attention(request),
            "detect_anomalies" => self.handle_detect_anomalies(request),
            "spatial_search" => self.handle_spatial_search(request),
            "insert_points" => self.handle_insert_points(request),
            other => Err(format!("unknown tool: {other}")),
        };
        let latency_us = start.elapsed().as_micros() as u64;
        match result {
            Ok(value) => ToolResponse::ok(value, latency_us),
            Err(msg) => ToolResponse::err(msg, latency_us),
        }
    }
    /// Access the internal spatial index (e.g. for testing).
    pub fn index(&self) -> &SpatialIndex {
        &self.index
    }
    // -- handlers -----------------------------------------------------------
    /// Detect obstacles in a point cloud around `robot_position`, searching
    /// out to `max_distance` (default 20.0).
    fn handle_detect_obstacles(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let cloud = parse_point_cloud(req, "point_cloud_json")?;
        let pos = parse_position(req, "robot_position")?;
        let max_dist = req
            .arguments
            .get("max_distance")
            .and_then(|v| v.as_f64())
            .unwrap_or(20.0);
        let obstacles = self
            .pipeline
            .detect_obstacles(&cloud, pos, max_dist)
            .map_err(|e| e.to_string())?;
        serde_json::to_value(&obstacles).map_err(|e| e.to_string())
    }
    /// Build a scene graph over the given objects, linking pairs within
    /// `max_edge_distance` (default 5.0).
    fn handle_build_scene_graph(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let objects: Vec<SceneObject> = parse_json_arg(req, "objects_json")?;
        let max_edge = req
            .arguments
            .get("max_edge_distance")
            .and_then(|v| v.as_f64())
            .unwrap_or(5.0);
        let graph = self
            .pipeline
            .build_scene_graph(&objects, max_edge)
            .map_err(|e| e.to_string())?;
        serde_json::to_value(&graph).map_err(|e| e.to_string())
    }
    /// Predict a trajectory from `position` / `velocity` over `steps`
    /// (default 10) timesteps of `dt` seconds (default 0.1).
    fn handle_predict_trajectory(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let pos = parse_position(req, "position")?;
        let vel = parse_position(req, "velocity")?;
        let steps = req
            .arguments
            .get("steps")
            .and_then(|v| v.as_u64())
            .unwrap_or(10) as usize;
        let dt = req
            .arguments
            .get("dt")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.1);
        let traj = self
            .pipeline
            .predict_trajectory(pos, vel, steps, dt)
            .map_err(|e| e.to_string())?;
        serde_json::to_value(&traj).map_err(|e| e.to_string())
    }
    /// Restrict a point cloud to a sphere of `radius` around `center`.
    /// Unlike the other handlers, `radius` is required.
    fn handle_focus_attention(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let cloud = parse_point_cloud(req, "point_cloud_json")?;
        let center = parse_position(req, "center")?;
        let radius = req
            .arguments
            .get("radius")
            .and_then(|v| v.as_f64())
            .ok_or("missing 'radius'")?;
        let focused = self
            .pipeline
            .focus_attention(&cloud, center, radius)
            .map_err(|e| e.to_string())?;
        serde_json::to_value(&focused).map_err(|e| e.to_string())
    }
    /// Run anomaly detection over a point cloud.
    fn handle_detect_anomalies(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let cloud = parse_point_cloud(req, "point_cloud_json")?;
        let anomalies = self
            .pipeline
            .detect_anomalies(&cloud)
            .map_err(|e| e.to_string())?;
        serde_json::to_value(&anomalies).map_err(|e| e.to_string())
    }
    /// k-nearest-neighbour search against the spatial index; returns
    /// `[{"index": .., "distance": ..}, ...]`. Non-numeric query elements
    /// are silently dropped by the filter_map.
    fn handle_spatial_search(
        &self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let query: Vec<f32> = req
            .arguments
            .get("query")
            .and_then(|v| v.as_array())
            .map(|a| a.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect())
            .ok_or("missing 'query'")?;
        let k = req
            .arguments
            .get("k")
            .and_then(|v| v.as_u64())
            .unwrap_or(5) as usize;
        let results = self
            .index
            .search_nearest(&query, k)
            .map_err(|e| e.to_string())?;
        let pairs: Vec<serde_json::Value> = results
            .iter()
            .map(|(idx, dist)| serde_json::json!({"index": idx, "distance": dist}))
            .collect();
        Ok(serde_json::json!(pairs))
    }
    /// Insert raw points into the spatial index; reports how many were
    /// inserted and the new index size. The only handler that mutates state.
    fn handle_insert_points(
        &mut self,
        req: &ToolRequest,
    ) -> std::result::Result<serde_json::Value, String> {
        let points: Vec<Point3D> = parse_json_arg(req, "points_json")?;
        // Second argument to PointCloud::new is presumably a timestamp or
        // capacity hint; 0 here vs. 1000 in the tests — TODO confirm.
        let cloud = PointCloud::new(points, 0);
        self.index.insert_point_cloud(&cloud);
        Ok(serde_json::json!({"inserted": cloud.len(), "total": self.index.len()}))
    }
}
impl Default for ToolExecutor {
fn default() -> Self {
Self::new()
}
}
// -- argument parsers -------------------------------------------------------
/// Parse a [`PointCloud`] tool argument that may arrive either as a
/// JSON-encoded string or as an inline JSON value.
fn parse_point_cloud(
    req: &ToolRequest,
    key: &str,
) -> std::result::Result<PointCloud, String> {
    let raw = req
        .arguments
        .get(key)
        .ok_or_else(|| format!("missing '{key}'"))?;
    match raw.as_str() {
        // String payload: decode the embedded JSON document.
        Some(s) => serde_json::from_str(s).map_err(|e| format!("invalid point cloud JSON: {e}")),
        // Inline payload: deserialize the value directly.
        None => {
            serde_json::from_value(raw.clone()).map_err(|e| format!("invalid point cloud: {e}"))
        }
    }
}
/// Parse an `[x, y, z]` position argument into `[f64; 3]`.
///
/// Accepts any JSON array with at least three leading numeric elements;
/// extra elements are ignored.
///
/// # Errors
/// Returns a message naming `key` when the argument is missing, too short,
/// or contains a non-numeric element. (Previously non-numeric elements
/// produced a bare "non-numeric" with no key or index, unlike every other
/// parser in this module.)
fn parse_position(
    req: &ToolRequest,
    key: &str,
) -> std::result::Result<[f64; 3], String> {
    let arr = req
        .arguments
        .get(key)
        .and_then(|v| v.as_array())
        .ok_or_else(|| format!("missing '{key}'"))?;
    if arr.len() < 3 {
        return Err(format!("'{key}' must have at least 3 elements"));
    }
    let mut out = [0.0f64; 3];
    for (i, slot) in out.iter_mut().enumerate() {
        // Name the key and index in the error so the caller can tell which
        // argument (and which component) was malformed.
        *slot = arr[i]
            .as_f64()
            .ok_or_else(|| format!("'{key}' element {i} is non-numeric"))?;
    }
    Ok(out)
}
/// Deserialize a tool argument into `T`, accepting either a JSON-encoded
/// string or an inline JSON value.
fn parse_json_arg<T: serde::de::DeserializeOwned>(
    req: &ToolRequest,
    key: &str,
) -> std::result::Result<T, String> {
    let raw = req
        .arguments
        .get(key)
        .ok_or_else(|| format!("missing '{key}'"))?;
    match raw.as_str() {
        // String payload: decode the embedded JSON document.
        Some(s) => serde_json::from_str(s).map_err(|e| format!("invalid JSON for '{key}': {e}")),
        // Inline payload: deserialize the value directly.
        None => serde_json::from_value(raw.clone()).map_err(|e| format!("invalid '{key}': {e}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    // Build a ToolRequest from a JSON object of arguments.
    fn make_request(tool: &str, args: serde_json::Value) -> ToolRequest {
        let arguments: HashMap<String, serde_json::Value> =
            serde_json::from_value(args).unwrap();
        ToolRequest { tool_name: tool.to_string(), arguments }
    }
    // A tight cluster of points near the robot should be detected successfully.
    #[test]
    fn test_detect_obstacles() {
        let mut exec = ToolExecutor::new();
        let cloud = PointCloud::new(
            vec![
                Point3D::new(2.0, 0.0, 0.0),
                Point3D::new(2.1, 0.0, 0.0),
                Point3D::new(2.0, 0.1, 0.0),
            ],
            1000,
        );
        let cloud_json = serde_json::to_string(&cloud).unwrap();
        let req = make_request("detect_obstacles", serde_json::json!({
            "point_cloud_json": cloud_json,
            "robot_position": [0.0, 0.0, 0.0],
        }));
        let resp = exec.execute(&req);
        assert!(resp.success);
    }
    // Trajectory prediction returns exactly `steps` waypoints.
    #[test]
    fn test_predict_trajectory() {
        let mut exec = ToolExecutor::new();
        let req = make_request("predict_trajectory", serde_json::json!({
            "position": [0.0, 0.0, 0.0],
            "velocity": [1.0, 0.0, 0.0],
            "steps": 5,
            "dt": 0.5,
        }));
        let resp = exec.execute(&req);
        assert!(resp.success);
        let traj = resp.result;
        assert_eq!(traj["waypoints"].as_array().unwrap().len(), 5);
    }
    // insert_points grows the index; spatial_search then returns k results.
    #[test]
    fn test_insert_and_search() {
        let mut exec = ToolExecutor::new();
        // Insert points
        let points = vec![
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(10.0, 0.0, 0.0),
        ];
        let points_json = serde_json::to_string(&points).unwrap();
        let req = make_request("insert_points", serde_json::json!({
            "points_json": points_json,
        }));
        let resp = exec.execute(&req);
        assert!(resp.success);
        assert_eq!(resp.result["total"], 3);
        // Search
        let req = make_request("spatial_search", serde_json::json!({
            "query": [1.0, 0.0, 0.0],
            "k": 2,
        }));
        let resp = exec.execute(&req);
        assert!(resp.success);
        let results = resp.result.as_array().unwrap();
        assert_eq!(results.len(), 2);
    }
    // Unknown tool names fail with an explanatory error, not a panic.
    #[test]
    fn test_unknown_tool() {
        let mut exec = ToolExecutor::new();
        let req = make_request("nonexistent", serde_json::json!({}));
        let resp = exec.execute(&req);
        assert!(!resp.success);
        assert!(resp.error.unwrap().contains("unknown tool"));
    }
    // Two objects 2.0 apart with a 5.0 edge threshold produce exactly one edge.
    #[test]
    fn test_build_scene_graph() {
        let mut exec = ToolExecutor::new();
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];
        let objects_json = serde_json::to_string(&objects).unwrap();
        let req = make_request("build_scene_graph", serde_json::json!({
            "objects_json": objects_json,
            "max_edge_distance": 5.0,
        }));
        let resp = exec.execute(&req);
        assert!(resp.success);
        assert_eq!(resp.result["edges"].as_array().unwrap().len(), 1);
    }
}

View File

@@ -0,0 +1,683 @@
//! MCP tool registrations for agentic robotics.
//!
//! Provides a registry of robotics tools that can be exposed via MCP servers.
//! This is a lightweight, dependency-free implementation that models tool
//! definitions, categories, and JSON schema generation without pulling in an
//! external MCP SDK.
pub mod executor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ---------------------------------------------------------------------------
// Parameter types
// ---------------------------------------------------------------------------
/// JSON Schema type for a tool parameter.
///
/// Serialized in lowercase to match JSON Schema's `type` keyword values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParamType {
    String,
    Number,
    Integer,
    Boolean,
    Array,
    Object,
}
impl ParamType {
    /// The JSON Schema `type` keyword string for this parameter type.
    fn as_schema_str(self) -> &'static str {
        match self {
            Self::String => "string",
            Self::Number => "number",
            Self::Integer => "integer",
            Self::Boolean => "boolean",
            Self::Array => "array",
            Self::Object => "object",
        }
    }
}
/// A single parameter accepted by a tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Parameter name as it appears in the tool's input schema.
    pub name: String,
    /// Human-readable description emitted into the schema.
    pub description: String,
    /// JSON Schema type of the parameter value.
    pub param_type: ParamType,
    /// Whether the parameter is listed in the schema's `required` array.
    pub required: bool,
}
impl ToolParameter {
    /// Build a parameter description, taking ownership of the string inputs.
    pub fn new(name: &str, description: &str, param_type: ParamType, required: bool) -> Self {
        Self {
            name: name.to_owned(),
            description: description.to_owned(),
            param_type,
            required,
        }
    }
}
// ---------------------------------------------------------------------------
// Tool categories
// ---------------------------------------------------------------------------
/// High-level category that a tool belongs to.
///
/// Used by the registry to group and filter tools (see `list_by_category`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ToolCategory {
    Perception,
    Navigation,
    Cognition,
    Swarm,
    Memory,
    Planning,
}
// ---------------------------------------------------------------------------
// Tool definition
// ---------------------------------------------------------------------------
/// Complete definition of a single MCP-exposed tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Unique tool name; also the registry key.
    pub name: String,
    /// Human-readable description emitted into the MCP schema.
    pub description: String,
    /// Parameters the tool accepts, in declaration order.
    pub parameters: Vec<ToolParameter>,
    /// Grouping category for registry filtering.
    pub category: ToolCategory,
}
impl ToolDefinition {
    /// Construct a tool definition, taking ownership of the string inputs.
    pub fn new(
        name: &str,
        description: &str,
        parameters: Vec<ToolParameter>,
        category: ToolCategory,
    ) -> Self {
        Self {
            name: name.to_owned(),
            description: description.to_owned(),
            parameters,
            category,
        }
    }
    /// Convert this tool definition into its JSON Schema representation.
    ///
    /// Each parameter becomes a `properties` entry; required parameters are
    /// additionally listed by name in the `required` array.
    fn to_schema(&self) -> serde_json::Value {
        let mut properties = serde_json::Map::new();
        let mut required: Vec<serde_json::Value> = Vec::new();
        for param in &self.parameters {
            properties.insert(
                param.name.clone(),
                serde_json::json!({
                    "type": param.param_type.as_schema_str(),
                    "description": param.description,
                }),
            );
            if param.required {
                required.push(serde_json::Value::String(param.name.clone()));
            }
        }
        serde_json::json!({
            "name": self.name,
            "description": self.description,
            "inputSchema": {
                "type": "object",
                "properties": properties,
                "required": required,
            }
        })
    }
}
// ---------------------------------------------------------------------------
// Request / Response
// ---------------------------------------------------------------------------
/// A request to invoke a tool by name with JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolRequest {
    /// Name of the tool to invoke (matched against the registry/executor).
    pub tool_name: String,
    /// Argument name -> JSON value map, as supplied by the MCP client.
    pub arguments: HashMap<String, serde_json::Value>,
}
/// The result of a tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResponse {
    /// Whether the invocation succeeded.
    pub success: bool,
    /// Tool output on success; `null` on failure.
    pub result: serde_json::Value,
    /// Error message on failure; `None` on success.
    pub error: Option<String>,
    /// Wall-clock handler latency in microseconds.
    pub latency_us: u64,
}
impl ToolResponse {
    /// Convenience constructor for a successful response.
    pub fn ok(result: serde_json::Value, latency_us: u64) -> Self {
        Self {
            latency_us,
            result,
            error: None,
            success: true,
        }
    }
    /// Convenience constructor for a failed response.
    pub fn err(message: impl Into<String>, latency_us: u64) -> Self {
        Self {
            latency_us,
            result: serde_json::Value::Null,
            error: Some(message.into()),
            success: false,
        }
    }
}
// ---------------------------------------------------------------------------
// Registry
// ---------------------------------------------------------------------------
/// Registry of robotics tools exposed to MCP clients.
///
/// Call [`RoboticsToolRegistry::new`] to get a registry pre-populated with all
/// built-in tools, or start from [`RoboticsToolRegistry::empty`] and register
/// tools manually.
#[derive(Debug, Clone)]
pub struct RoboticsToolRegistry {
    /// Tool name -> definition; names are unique (later registrations win).
    tools: HashMap<String, ToolDefinition>,
}
impl Default for RoboticsToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl RoboticsToolRegistry {
    /// Create a registry pre-populated with all built-in robotics tools.
    pub fn new() -> Self {
        // Start empty, then install the built-in tool set.
        let mut registry = Self { tools: HashMap::new() };
        registry.register_defaults();
        registry
    }
    /// Create an empty registry with no tools registered.
    pub fn empty() -> Self {
        Self { tools: HashMap::new() }
    }
    /// Register a single tool. Overwrites any existing tool with the same name.
    pub fn register_tool(&mut self, tool: ToolDefinition) {
        self.tools.insert(tool.name.clone(), tool);
    }
    /// List every registered tool (unordered).
    pub fn list_tools(&self) -> Vec<&ToolDefinition> {
        self.tools.values().collect()
    }
    /// Look up a tool by its exact name.
    pub fn get_tool(&self, name: &str) -> Option<&ToolDefinition> {
        self.tools.get(name)
    }
    /// Return all tools belonging to the given category.
    pub fn list_by_category(&self, category: ToolCategory) -> Vec<&ToolDefinition> {
        self.tools.values().filter(|t| t.category == category).collect()
    }
    /// Produce a full MCP-compatible JSON schema describing every tool.
    pub fn to_mcp_schema(&self) -> serde_json::Value {
        let mut tools: Vec<serde_json::Value> =
            self.tools.values().map(|t| t.to_schema()).collect();
        // Sort by name for deterministic output.
        tools.sort_by(|a, b| {
            let na = a.get("name").and_then(|v| v.as_str()).unwrap_or("");
            let nb = b.get("name").and_then(|v| v.as_str()).unwrap_or("");
            na.cmp(nb)
        });
        serde_json::json!({ "tools": tools })
    }
    // -- default tool registration ------------------------------------------
    /// Installs the 15 built-in tools.
    ///
    /// Category counts: 6 perception, 1 navigation, 4 cognition, 2 memory,
    /// 1 planning, 1 swarm. The unit tests below assert these counts and the
    /// full name list -- keep them in sync when adding or removing tools.
    fn register_defaults(&mut self) {
        self.register_tool(ToolDefinition::new(
            "detect_obstacles",
            "Detect obstacles in a point cloud relative to the robot position",
            vec![
                ToolParameter::new(
                    "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
                ),
                ToolParameter::new(
                    "robot_position", "Robot [x,y,z] position", ParamType::Array, true,
                ),
                ToolParameter::new(
                    "max_distance", "Maximum detection distance in meters", ParamType::Number, false,
                ),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "build_scene_graph",
            "Build a scene graph from detected objects with spatial edges",
            vec![
                ToolParameter::new(
                    "objects_json", "JSON array of scene objects", ParamType::String, true,
                ),
                ToolParameter::new(
                    "max_edge_distance", "Maximum edge distance in meters", ParamType::Number, false,
                ),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "predict_trajectory",
            "Predict future trajectory from current position and velocity",
            vec![
                ToolParameter::new("position", "Current [x,y,z] position", ParamType::Array, true),
                ToolParameter::new("velocity", "Current [vx,vy,vz] velocity", ParamType::Array, true),
                ToolParameter::new("steps", "Number of prediction steps", ParamType::Integer, true),
                ToolParameter::new("dt", "Time step in seconds", ParamType::Number, false),
            ],
            ToolCategory::Navigation,
        ));
        self.register_tool(ToolDefinition::new(
            "focus_attention",
            "Extract a region of interest from a point cloud by center and radius",
            vec![
                ToolParameter::new(
                    "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
                ),
                ToolParameter::new("center", "Focus center [x,y,z]", ParamType::Array, true),
                ToolParameter::new("radius", "Attention radius in meters", ParamType::Number, true),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "detect_anomalies",
            "Detect anomalous points in a point cloud using statistical analysis",
            vec![
                ToolParameter::new(
                    "point_cloud_json", "JSON-encoded point cloud", ParamType::String, true,
                ),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "spatial_search",
            "Search for nearest neighbours in the spatial index",
            vec![
                ToolParameter::new("query", "Query vector [x,y,z]", ParamType::Array, true),
                ToolParameter::new("k", "Number of neighbours to return", ParamType::Integer, true),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "insert_points",
            "Insert points into the spatial index for later retrieval",
            vec![
                ToolParameter::new(
                    "points_json", "JSON array of [x,y,z] points", ParamType::String, true,
                ),
            ],
            ToolCategory::Perception,
        ));
        self.register_tool(ToolDefinition::new(
            "store_memory",
            "Store a vector in episodic memory with an importance score",
            vec![
                ToolParameter::new("key", "Unique memory key", ParamType::String, true),
                ToolParameter::new("data", "Data vector to store", ParamType::Array, true),
                ToolParameter::new(
                    "importance", "Importance weight 0.0-1.0", ParamType::Number, false,
                ),
            ],
            ToolCategory::Memory,
        ));
        self.register_tool(ToolDefinition::new(
            "recall_memory",
            "Recall the k most similar memories to a query vector",
            vec![
                ToolParameter::new(
                    "query", "Query vector for similarity search", ParamType::Array, true,
                ),
                ToolParameter::new("k", "Number of memories to recall", ParamType::Integer, true),
            ],
            ToolCategory::Memory,
        ));
        self.register_tool(ToolDefinition::new(
            "learn_skill",
            "Learn a new skill from demonstration trajectories",
            vec![
                ToolParameter::new("name", "Skill name identifier", ParamType::String, true),
                ToolParameter::new(
                    "demonstrations_json",
                    "JSON array of demonstration trajectories",
                    ParamType::String,
                    true,
                ),
            ],
            ToolCategory::Cognition,
        ));
        self.register_tool(ToolDefinition::new(
            "execute_skill",
            "Execute a previously learned skill by name",
            vec![
                ToolParameter::new("name", "Name of the skill to execute", ParamType::String, true),
            ],
            ToolCategory::Cognition,
        ));
        self.register_tool(ToolDefinition::new(
            "plan_behavior",
            "Generate a behavior tree plan for a given goal and preconditions",
            vec![
                ToolParameter::new("goal", "Goal description", ParamType::String, true),
                ToolParameter::new(
                    "conditions_json",
                    "JSON object of current conditions",
                    ParamType::String,
                    false,
                ),
            ],
            ToolCategory::Planning,
        ));
        self.register_tool(ToolDefinition::new(
            "coordinate_swarm",
            "Coordinate a multi-robot swarm for a given task",
            vec![
                ToolParameter::new(
                    "task_json", "JSON-encoded task specification", ParamType::String, true,
                ),
            ],
            ToolCategory::Swarm,
        ));
        self.register_tool(ToolDefinition::new(
            "update_world_model",
            "Update the internal world model with a new or changed object",
            vec![
                ToolParameter::new(
                    "object_json", "JSON-encoded object to upsert", ParamType::String, true,
                ),
            ],
            ToolCategory::Cognition,
        ));
        self.register_tool(ToolDefinition::new(
            "get_world_state",
            "Retrieve the current world model state, optionally filtered by object id",
            vec![
                ToolParameter::new(
                    "object_id", "Optional object id to filter", ParamType::Integer, false,
                ),
            ],
            ToolCategory::Cognition,
        ));
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // Built-in tool count; must track `register_defaults`.
    #[test]
    fn test_registry_has_15_default_tools() {
        let registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_tools().len(), 15);
    }
    #[test]
    fn test_list_tools_returns_all() {
        let registry = RoboticsToolRegistry::new();
        let tools = registry.list_tools();
        let mut names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        names.sort();
        // Alphabetical list of every built-in tool name.
        let expected = vec![
            "build_scene_graph",
            "coordinate_swarm",
            "detect_anomalies",
            "detect_obstacles",
            "execute_skill",
            "focus_attention",
            "get_world_state",
            "insert_points",
            "learn_skill",
            "plan_behavior",
            "predict_trajectory",
            "recall_memory",
            "spatial_search",
            "store_memory",
            "update_world_model",
        ];
        assert_eq!(names, expected);
    }
    #[test]
    fn test_get_tool_by_name() {
        let registry = RoboticsToolRegistry::new();
        let tool = registry.get_tool("detect_obstacles").unwrap();
        assert_eq!(tool.category, ToolCategory::Perception);
        assert_eq!(tool.parameters.len(), 3);
        assert!(tool.parameters.iter().any(|p| p.name == "point_cloud_json" && p.required));
        let tool = registry.get_tool("predict_trajectory").unwrap();
        assert_eq!(tool.category, ToolCategory::Navigation);
        assert_eq!(tool.parameters.len(), 4);
        assert!(registry.get_tool("nonexistent").is_none());
    }
    #[test]
    fn test_list_by_category_perception() {
        let registry = RoboticsToolRegistry::new();
        let perception = registry.list_by_category(ToolCategory::Perception);
        assert_eq!(perception.len(), 6);
        for tool in &perception {
            assert_eq!(tool.category, ToolCategory::Perception);
        }
    }
    // Per-category counts; must track `register_defaults`.
    #[test]
    fn test_list_by_category_counts() {
        let registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_by_category(ToolCategory::Perception).len(), 6);
        assert_eq!(registry.list_by_category(ToolCategory::Navigation).len(), 1);
        assert_eq!(registry.list_by_category(ToolCategory::Cognition).len(), 4);
        assert_eq!(registry.list_by_category(ToolCategory::Memory).len(), 2);
        assert_eq!(registry.list_by_category(ToolCategory::Planning).len(), 1);
        assert_eq!(registry.list_by_category(ToolCategory::Swarm).len(), 1);
    }
    #[test]
    fn test_to_mcp_schema_valid_json() {
        let registry = RoboticsToolRegistry::new();
        let schema = registry.to_mcp_schema();
        let tools = schema.get("tools").unwrap().as_array().unwrap();
        assert_eq!(tools.len(), 15);
        // Tools are sorted by name.
        let names: Vec<&str> = tools
            .iter()
            .map(|t| t.get("name").unwrap().as_str().unwrap())
            .collect();
        let mut sorted = names.clone();
        sorted.sort();
        assert_eq!(names, sorted);
        // Each tool has the expected schema shape.
        for tool in tools {
            assert!(tool.get("name").unwrap().is_string());
            assert!(tool.get("description").unwrap().is_string());
            let input = tool.get("inputSchema").unwrap();
            assert_eq!(input.get("type").unwrap().as_str().unwrap(), "object");
            assert!(input.get("properties").unwrap().is_object());
            assert!(input.get("required").unwrap().is_array());
        }
    }
    // Optional parameters must not appear in the "required" list.
    #[test]
    fn test_schema_required_fields() {
        let registry = RoboticsToolRegistry::new();
        let schema = registry.to_mcp_schema();
        let tools = schema["tools"].as_array().unwrap();
        let obs = tools.iter().find(|t| t["name"] == "detect_obstacles").unwrap();
        let required = obs["inputSchema"]["required"].as_array().unwrap();
        let req_names: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(req_names.contains(&"point_cloud_json"));
        assert!(req_names.contains(&"robot_position"));
        assert!(!req_names.contains(&"max_distance"));
    }
    #[test]
    fn test_tool_request_serialization() {
        let mut args = HashMap::new();
        args.insert("k".to_string(), serde_json::json!(5));
        args.insert("query".to_string(), serde_json::json!([1.0, 2.0, 3.0]));
        let req = ToolRequest { tool_name: "spatial_search".to_string(), arguments: args };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: ToolRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_name, "spatial_search");
        assert_eq!(deserialized.arguments["k"], serde_json::json!(5));
    }
    #[test]
    fn test_tool_response_ok() {
        let resp = ToolResponse::ok(serde_json::json!({"obstacles": 3}), 420);
        assert!(resp.success);
        assert!(resp.error.is_none());
        assert_eq!(resp.latency_us, 420);
        assert_eq!(resp.result["obstacles"], 3);
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: ToolResponse = serde_json::from_str(&json).unwrap();
        assert!(deserialized.success);
    }
    #[test]
    fn test_tool_response_err() {
        let resp = ToolResponse::err("something went wrong", 100);
        assert!(!resp.success);
        assert_eq!(resp.error.as_deref(), Some("something went wrong"));
        assert!(resp.result.is_null());
    }
    #[test]
    fn test_register_custom_tool() {
        let mut registry = RoboticsToolRegistry::new();
        assert_eq!(registry.list_tools().len(), 15);
        let custom = ToolDefinition::new(
            "my_custom_tool",
            "A custom tool for testing",
            vec![ToolParameter::new("input", "The input data", ParamType::String, true)],
            ToolCategory::Cognition,
        );
        registry.register_tool(custom);
        assert_eq!(registry.list_tools().len(), 16);
        let tool = registry.get_tool("my_custom_tool").unwrap();
        assert_eq!(tool.description, "A custom tool for testing");
        assert_eq!(tool.parameters.len(), 1);
    }
    // Re-registering an existing name replaces the definition, not appends.
    #[test]
    fn test_register_overwrites_existing() {
        let mut registry = RoboticsToolRegistry::new();
        let replacement = ToolDefinition::new(
            "detect_obstacles",
            "Replaced description",
            vec![],
            ToolCategory::Perception,
        );
        registry.register_tool(replacement);
        assert_eq!(registry.list_tools().len(), 15);
        let tool = registry.get_tool("detect_obstacles").unwrap();
        assert_eq!(tool.description, "Replaced description");
        assert!(tool.parameters.is_empty());
    }
    #[test]
    fn test_empty_registry() {
        let registry = RoboticsToolRegistry::empty();
        assert_eq!(registry.list_tools().len(), 0);
        assert!(registry.get_tool("detect_obstacles").is_none());
    }
    #[test]
    fn test_param_type_serde_roundtrip() {
        let types = vec![
            ParamType::String,
            ParamType::Number,
            ParamType::Integer,
            ParamType::Boolean,
            ParamType::Array,
            ParamType::Object,
        ];
        for pt in types {
            let json = serde_json::to_string(&pt).unwrap();
            let deserialized: ParamType = serde_json::from_str(&json).unwrap();
            assert_eq!(pt, deserialized);
        }
    }
    #[test]
    fn test_tool_category_serde_roundtrip() {
        let categories = vec![
            ToolCategory::Perception,
            ToolCategory::Navigation,
            ToolCategory::Cognition,
            ToolCategory::Swarm,
            ToolCategory::Memory,
            ToolCategory::Planning,
        ];
        for cat in categories {
            let json = serde_json::to_string(&cat).unwrap();
            let deserialized: ToolCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(cat, deserialized);
        }
    }
    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let tool = ToolDefinition::new(
            "test_tool",
            "A tool for testing",
            vec![
                ToolParameter::new("a", "param a", ParamType::String, true),
                ToolParameter::new("b", "param b", ParamType::Number, false),
            ],
            ToolCategory::Navigation,
        );
        let json = serde_json::to_string(&tool).unwrap();
        let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(tool, deserialized);
    }
    #[test]
    fn test_default_trait() {
        let registry = RoboticsToolRegistry::default();
        assert_eq!(registry.list_tools().len(), 15);
    }
}

View File

@@ -0,0 +1,159 @@
//! Shared spatial-hash clustering with union-find.
//!
//! Used by the obstacle detector, scene graph builder, and perception pipeline
//! to avoid duplicating the same algorithm.
//!
//! ## Optimizations
//!
//! - Union-by-rank prevents tree degeneration, keeping `find` near O(α(n)).
//! - Path halving in `find` for efficient path compression.
//! - `#[inline]` on hot helpers to ensure inlining in tight loops.
use crate::bridge::{Point3D, PointCloud};
use std::collections::HashMap;
/// Cluster a point cloud using spatial hashing and union-find over adjacent cells.
///
/// Points are binned into a 3-D grid with the given `cell_size`. Cells that
/// share a face, edge, or corner (26-neighbourhood) are merged via union-find.
/// Returns the resulting groups as separate point vectors.
pub fn cluster_point_cloud(cloud: &PointCloud, cell_size: f64) -> Vec<Vec<Point3D>> {
    // Degenerate inputs produce no clusters.
    if cell_size <= 0.0 || cloud.points.is_empty() {
        return Vec::new();
    }
    // 1. Bin every point index into its grid cell.
    let mut cell_map: HashMap<(i64, i64, i64), Vec<usize>> = HashMap::new();
    for (idx, point) in cloud.points.iter().enumerate() {
        cell_map.entry(cell_key(point, cell_size)).or_default().push(idx);
    }
    // 2. Give each occupied cell a dense index for the union-find arrays.
    let cell_idx: HashMap<(i64, i64, i64), usize> = cell_map
        .keys()
        .enumerate()
        .map(|(i, &key)| (key, i))
        .collect();
    let mut parent: Vec<usize> = (0..cell_idx.len()).collect();
    let mut rank: Vec<u8> = vec![0; cell_idx.len()];
    // Merge each occupied cell with every occupied cell in its 26-neighbourhood
    // (the dz == dy == dx == 0 case is a harmless self-union).
    for (&(cx, cy, cz), &a) in &cell_idx {
        for dx in -1..=1_i64 {
            for dy in -1..=1_i64 {
                for dz in -1..=1_i64 {
                    if let Some(&b) = cell_idx.get(&(cx + dx, cy + dy, cz + dz)) {
                        uf_union(&mut parent, &mut rank, a, b);
                    }
                }
            }
        }
    }
    // 3. Collect points under their cluster's root representative.
    let mut groups: HashMap<usize, Vec<Point3D>> = HashMap::new();
    for (key, point_indices) in &cell_map {
        let root = uf_find(&mut parent, cell_idx[key]);
        let bucket = groups.entry(root).or_default();
        bucket.extend(point_indices.iter().map(|&pi| cloud.points[pi]));
    }
    groups.into_values().collect()
}
/// Compute the grid cell key for a point by flooring each coordinate into
/// `cell_size`-wide bins (floor keeps negative coordinates in the right bin).
#[inline]
fn cell_key(p: &Point3D, cell_size: f64) -> (i64, i64, i64) {
    let bin = |coord: f64| (coord / cell_size).floor() as i64;
    (bin(p.x as f64), bin(p.y as f64), bin(p.z as f64))
}
/// Path-compressing find (path halving).
///
/// Returns the root representative of `i`, pointing each visited node at its
/// grandparent along the way so later lookups stay near O(α(n)).
#[inline]
fn uf_find(parent: &mut [usize], mut i: usize) -> usize {
    loop {
        let p = parent[i];
        if p == i {
            return i;
        }
        // Path halving: hop straight to the grandparent.
        let grandparent = parent[p];
        parent[i] = grandparent;
        i = grandparent;
    }
}
/// Union by rank: attaches the shorter tree under the taller root.
#[inline]
fn uf_union(parent: &mut [usize], rank: &mut [u8], a: usize, b: usize) {
    let root_a = uf_find(parent, a);
    let root_b = uf_find(parent, b);
    if root_a == root_b {
        // Already in the same set; nothing to merge.
        return;
    }
    if rank[root_a] < rank[root_b] {
        parent[root_a] = root_b;
    } else {
        // Taller (or equal-rank) tree keeps its root; a tie grows the rank.
        parent[root_b] = root_a;
        if rank[root_a] == rank[root_b] {
            rank[root_a] += 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Helper: build a PointCloud from raw XYZ triples (timestamp 0).
    fn make_cloud(pts: &[[f32; 3]]) -> PointCloud {
        let points: Vec<Point3D> = pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect();
        PointCloud::new(points, 0)
    }
    #[test]
    fn test_empty_cloud() {
        let cloud = PointCloud::default();
        let clusters = cluster_point_cloud(&cloud, 1.0);
        assert!(clusters.is_empty());
    }
    #[test]
    fn test_single_cluster() {
        // Three nearby points land in adjacent 0.5-unit cells -> one cluster.
        let cloud = make_cloud(&[
            [1.0, 1.0, 0.0],
            [1.1, 1.0, 0.0],
            [1.0, 1.1, 0.0],
        ]);
        let clusters = cluster_point_cloud(&cloud, 0.5);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 3);
    }
    #[test]
    fn test_two_clusters() {
        // Two pairs far apart cannot share neighbouring cells.
        let cloud = make_cloud(&[
            [0.0, 0.0, 0.0],
            [0.1, 0.0, 0.0],
            [10.0, 10.0, 0.0],
            [10.1, 10.0, 0.0],
        ]);
        let clusters = cluster_point_cloud(&cloud, 0.5);
        assert_eq!(clusters.len(), 2);
    }
    #[test]
    fn test_negative_coordinates() {
        // floor()-based cell keys must bin negative coordinates correctly.
        let cloud = make_cloud(&[
            [-1.0, -1.0, 0.0],
            [-0.9, -1.0, 0.0],
            [1.0, 1.0, 0.0],
        ]);
        let clusters = cluster_point_cloud(&cloud, 0.5);
        assert_eq!(clusters.len(), 2);
    }
}

View File

@@ -0,0 +1,175 @@
//! Configuration types for the perception pipeline.
//!
//! Provides tuning knobs for scene-graph construction, obstacle detection,
//! and the top-level perception configuration that bundles them together.
use serde::{Deserialize, Serialize};
// ---------------------------------------------------------------------------
// Scene-graph configuration
// ---------------------------------------------------------------------------
/// Tuning parameters for scene-graph construction from point clouds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SceneGraphConfig {
    /// Maximum distance between two points to be considered part of the same
    /// cluster (metres). Default: 0.5.
    pub cluster_radius: f64,
    /// Minimum number of points required to form a valid cluster / object.
    /// Default: 3.
    pub min_cluster_size: usize,
    /// Hard cap on the number of objects the builder will emit. Default: 256.
    pub max_objects: usize,
    /// Maximum centre-to-centre distance for two objects to be connected by an
    /// edge in the scene graph (metres). Default: 5.0.
    pub edge_distance_threshold: f64,
}
impl Default for SceneGraphConfig {
fn default() -> Self {
Self {
cluster_radius: 0.5,
min_cluster_size: 3,
max_objects: 256,
edge_distance_threshold: 5.0,
}
}
}
// ---------------------------------------------------------------------------
// Obstacle configuration
// ---------------------------------------------------------------------------
/// Tuning parameters for the obstacle detector.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ObstacleConfig {
    /// Minimum number of points that must fall inside a cluster to be
    /// considered an obstacle. Default: 3.
    pub min_obstacle_size: usize,
    /// Maximum range from the robot within which obstacles are detected
    /// (metres). Default: 20.0.
    pub max_detection_range: f64,
    /// Extra padding added around every detected obstacle (metres).
    /// Default: 0.2.
    pub safety_margin: f64,
}
impl Default for ObstacleConfig {
fn default() -> Self {
Self {
min_obstacle_size: 3,
max_detection_range: 20.0,
safety_margin: 0.2,
}
}
}
// ---------------------------------------------------------------------------
// Top-level perception configuration
// ---------------------------------------------------------------------------
/// Aggregated configuration for the full perception pipeline.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PerceptionConfig {
    /// Scene-graph construction parameters.
    pub scene_graph: SceneGraphConfig,
    /// Obstacle-detection parameters.
    pub obstacle: ObstacleConfig,
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults here must match the Default impls above.
    #[test]
    fn test_scene_graph_config_defaults() {
        let cfg = SceneGraphConfig::default();
        assert!((cfg.cluster_radius - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.min_cluster_size, 3);
        assert_eq!(cfg.max_objects, 256);
        assert!((cfg.edge_distance_threshold - 5.0).abs() < f64::EPSILON);
    }
    #[test]
    fn test_obstacle_config_defaults() {
        let cfg = ObstacleConfig::default();
        assert_eq!(cfg.min_obstacle_size, 3);
        assert!((cfg.max_detection_range - 20.0).abs() < f64::EPSILON);
        assert!((cfg.safety_margin - 0.2).abs() < f64::EPSILON);
    }
    #[test]
    fn test_perception_config_defaults() {
        let cfg = PerceptionConfig::default();
        assert_eq!(cfg.scene_graph, SceneGraphConfig::default());
        assert_eq!(cfg.obstacle, ObstacleConfig::default());
    }
    #[test]
    fn test_scene_graph_config_serde_roundtrip() {
        let cfg = SceneGraphConfig {
            cluster_radius: 1.0,
            min_cluster_size: 5,
            max_objects: 128,
            edge_distance_threshold: 3.0,
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let restored: SceneGraphConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, restored);
    }
    #[test]
    fn test_obstacle_config_serde_roundtrip() {
        let cfg = ObstacleConfig {
            min_obstacle_size: 10,
            max_detection_range: 50.0,
            safety_margin: 0.5,
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let restored: ObstacleConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, restored);
    }
    #[test]
    fn test_perception_config_serde_roundtrip() {
        let cfg = PerceptionConfig::default();
        let json = serde_json::to_string_pretty(&cfg).unwrap();
        let restored: PerceptionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, restored);
    }
    #[test]
    fn test_config_clone_equality() {
        let a = PerceptionConfig::default();
        let b = a.clone();
        assert_eq!(a, b);
    }
    #[test]
    fn test_config_debug_format() {
        let cfg = PerceptionConfig::default();
        let dbg = format!("{:?}", cfg);
        assert!(dbg.contains("PerceptionConfig"));
        assert!(dbg.contains("SceneGraphConfig"));
        assert!(dbg.contains("ObstacleConfig"));
    }
    #[test]
    fn test_custom_perception_config() {
        let cfg = PerceptionConfig {
            scene_graph: SceneGraphConfig {
                cluster_radius: 2.0,
                min_cluster_size: 10,
                max_objects: 64,
                edge_distance_threshold: 10.0,
            },
            obstacle: ObstacleConfig {
                min_obstacle_size: 5,
                max_detection_range: 100.0,
                safety_margin: 1.0,
            },
        };
        assert!((cfg.scene_graph.cluster_radius - 2.0).abs() < f64::EPSILON);
        assert_eq!(cfg.obstacle.min_obstacle_size, 5);
    }
}

View File

@@ -0,0 +1,782 @@
//! Perception subsystem: scene graph construction, obstacle detection, and pipeline.
//!
//! This module sits on top of [`crate::bridge`] types and provides higher-level
//! perception building blocks used by the cognitive architecture.
pub mod clustering;
pub mod config;
pub mod obstacle_detector;
pub mod scene_graph;
pub mod sensor_fusion;
pub use config::{ObstacleConfig, PerceptionConfig, SceneGraphConfig};
pub use obstacle_detector::{ClassifiedObstacle, DetectedObstacle, ObstacleClass, ObstacleDetector};
pub use scene_graph::PointCloudSceneGraphBuilder;
use serde::{Deserialize, Serialize};
use crate::bridge::{
Obstacle, Point3D, PointCloud, SceneEdge, SceneGraph, SceneObject, Trajectory,
};
// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------
/// Errors emitted by perception pipeline operations.
#[derive(Debug, thiserror::Error)]
pub enum PerceptionError {
    /// A caller-supplied argument failed validation (e.g. non-positive
    /// `radius`, `dt`, or `max_edge_distance`, or zero `steps`).
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// A processing stage could not produce a result.
    #[error("Processing failed: {0}")]
    ProcessingFailed(String),
}
/// Convenience alias used throughout the perception module.
pub type Result<T> = std::result::Result<T, PerceptionError>;
// ---------------------------------------------------------------------------
// Anomaly type
// ---------------------------------------------------------------------------
/// A point-cloud anomaly detected via z-score outlier analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Location of the anomalous point as [x, y, z].
    pub position: [f64; 3],
    /// Outlier score produced by the z-score analysis.
    pub score: f64,
    /// Free-text description of the anomaly.
    pub description: String,
    /// Detection timestamp; presumably microseconds like
    /// `PointCloud::timestamp_us` -- confirm with the detector.
    pub timestamp: i64,
}
// ---------------------------------------------------------------------------
// SceneGraphBuilder
// ---------------------------------------------------------------------------
/// Builds a [`SceneGraph`] from detected obstacles or raw point clouds.
///
/// The builder clusters scene objects, computes spatial edges between nearby
/// objects, and produces a timestamped scene graph.
#[derive(Debug, Clone)]
pub struct SceneGraphBuilder {
    /// Maximum centre-to-centre distance for two objects to share an edge.
    edge_distance_threshold: f64,
    /// Hard cap on the number of objects kept in a built graph.
    max_objects: usize,
}
impl Default for SceneGraphBuilder {
fn default() -> Self {
Self {
edge_distance_threshold: 5.0,
max_objects: 256,
}
}
}
impl SceneGraphBuilder {
    /// Create a new builder with explicit parameters.
    pub fn new(edge_distance_threshold: f64, max_objects: usize) -> Self {
        Self {
            edge_distance_threshold,
            max_objects,
        }
    }
    /// Build a scene graph from a list of [`SceneObject`]s.
    ///
    /// Edges are created between every pair of objects whose centers are within
    /// `edge_distance_threshold`. Objects beyond `max_objects` are dropped.
    pub fn build(&self, mut objects: Vec<SceneObject>, timestamp: i64) -> SceneGraph {
        objects.truncate(self.max_objects);
        // Compare squared distances so sqrt is only paid for accepted edges.
        let threshold_sq = self.edge_distance_threshold * self.edge_distance_threshold;
        let mut edges = Vec::new();
        for i in 0..objects.len() {
            for j in (i + 1)..objects.len() {
                let dx = objects[i].center[0] - objects[j].center[0];
                let dy = objects[i].center[1] - objects[j].center[1];
                let dz = objects[i].center[2] - objects[j].center[2];
                let dist_sq = dx * dx + dy * dy + dz * dz;
                if dist_sq <= threshold_sq {
                    // Only compute sqrt for edges that pass the filter.
                    let dist = dist_sq.sqrt();
                    // Fixed relation bands (metres).
                    let relation = if dist < 1.0 {
                        "adjacent".to_string()
                    } else if dist < 3.0 {
                        "near".to_string()
                    } else {
                        "visible".to_string()
                    };
                    edges.push(SceneEdge {
                        from: objects[i].id,
                        to: objects[j].id,
                        distance: dist,
                        relation,
                    });
                }
            }
        }
        SceneGraph::new(objects, edges, timestamp)
    }
    /// Build a scene graph from detected obstacles.
    ///
    /// Objects get ids 0..n in input order; confidence falls off with the
    /// obstacle's minimum distance and never drops below 0.1.
    pub fn build_from_obstacles(
        &self,
        obstacles: &[DetectedObstacle],
        timestamp: i64,
    ) -> SceneGraph {
        let objects: Vec<SceneObject> = obstacles
            .iter()
            .enumerate()
            .map(|(i, obs)| {
                let mut obj = SceneObject::new(i, obs.center, obs.extent);
                obj.confidence = 1.0 - (obs.min_distance as f32 / 30.0).min(0.9);
                obj.label = format!("obstacle_{}", i);
                obj
            })
            .collect();
        self.build(objects, timestamp)
    }
    /// Merge two scene graphs into one, re-computing edges.
    ///
    /// Objects from `a` keep their ids; objects from `b` are shifted past the
    /// highest id present in `a`, so merged ids cannot collide even when `a`
    /// holds sparse / non-contiguous ids. (Previously the shift was
    /// `a.objects.len()`, which collides for non-contiguous ids; for the
    /// contiguous 0..n ids produced by `build_from_obstacles` both choices
    /// are identical.)
    pub fn merge(&self, a: &SceneGraph, b: &SceneGraph) -> SceneGraph {
        let mut objects = a.objects.clone();
        // Offset must clear the *largest* id in `a`, not just its count.
        let offset = a.objects.iter().map(|o| o.id + 1).max().unwrap_or(0);
        for obj in &b.objects {
            let mut new_obj = obj.clone();
            new_obj.id += offset;
            objects.push(new_obj);
        }
        // The merged graph carries the newer of the two timestamps.
        let timestamp = a.timestamp.max(b.timestamp);
        self.build(objects, timestamp)
    }
}
// ---------------------------------------------------------------------------
// PerceptionPipeline
// ---------------------------------------------------------------------------
/// End-to-end perception pipeline that processes sensor frames into scene
/// graphs and obstacle lists.
///
/// Supports two construction modes:
/// - [`PerceptionPipeline::new`] for config-driven construction
/// - [`PerceptionPipeline::with_thresholds`] for threshold-driven
///   construction (obstacle cell-size and anomaly z-score)
#[derive(Debug, Clone)]
pub struct PerceptionPipeline {
    /// Detects and classifies obstacles in incoming point clouds.
    detector: ObstacleDetector,
    /// Builds scene graphs from detected obstacles.
    graph_builder: SceneGraphBuilder,
    /// Monotonic count of frames handled by `process`.
    frames_processed: u64,
    /// Clustering cell size used by `detect_obstacles` (floored at 0.1 there).
    obstacle_threshold: f64,
    /// Z-score threshold for anomaly detection.
    anomaly_threshold: f64,
}
impl PerceptionPipeline {
/// Create a new pipeline with the given configuration.
pub fn new(config: PerceptionConfig) -> Self {
let obstacle_threshold = config.obstacle.safety_margin * 5.0;
let detector = ObstacleDetector::new(config.obstacle);
let graph_builder = SceneGraphBuilder::new(
config.scene_graph.edge_distance_threshold,
config.scene_graph.max_objects,
);
Self {
detector,
graph_builder,
frames_processed: 0,
obstacle_threshold: obstacle_threshold.max(0.5),
anomaly_threshold: 2.0,
}
}
/// Create a pipeline from explicit thresholds.
///
/// * `obstacle_threshold` -- clustering cell size for obstacle grouping.
/// * `anomaly_threshold` -- z-score threshold for anomaly detection.
pub fn with_thresholds(obstacle_threshold: f64, anomaly_threshold: f64) -> Self {
use crate::perception::config::{ObstacleConfig, SceneGraphConfig};
let obstacle_cfg = ObstacleConfig::default();
let scene_cfg = SceneGraphConfig::default();
let detector = ObstacleDetector::new(obstacle_cfg.clone());
let graph_builder = SceneGraphBuilder::new(
scene_cfg.edge_distance_threshold,
scene_cfg.max_objects,
);
Self {
detector,
graph_builder,
frames_processed: 0,
obstacle_threshold,
anomaly_threshold,
}
}
/// Process a point cloud relative to a robot position, returning
/// detected obstacles and a scene graph.
pub fn process(
&mut self,
cloud: &PointCloud,
robot_pos: &[f64; 3],
) -> (Vec<DetectedObstacle>, SceneGraph) {
self.frames_processed += 1;
let obstacles = self.detector.detect(cloud, robot_pos);
let graph = self.graph_builder.build_from_obstacles(&obstacles, cloud.timestamp_us);
(obstacles, graph)
}
/// Classify previously detected obstacles.
pub fn classify(
&self,
obstacles: &[DetectedObstacle],
) -> Vec<ClassifiedObstacle> {
self.detector.classify_obstacles(obstacles)
}
    /// Number of frames processed so far.
    ///
    /// Incremented once per [`PerceptionPipeline::process`] call.
    pub fn frames_processed(&self) -> u64 {
        self.frames_processed
    }
// -- Obstacle detection (bridge-level) ----------------------------------
/// Detect obstacles in `cloud` relative to `robot_position`.
///
/// Points further than `max_distance` from the robot are ignored.
/// Returns bridge-level [`Obstacle`] values sorted by distance.
pub fn detect_obstacles(
&self,
cloud: &PointCloud,
robot_position: [f64; 3],
max_distance: f64,
) -> Result<Vec<Obstacle>> {
if cloud.is_empty() {
return Ok(Vec::new());
}
let cell_size = self.obstacle_threshold.max(0.1);
let clusters = clustering::cluster_point_cloud(cloud, cell_size);
let mut obstacles: Vec<Obstacle> = Vec::new();
let mut next_id: u64 = 0;
for cluster in &clusters {
if cluster.len() < 2 {
continue;
}
let (center, radius) = Self::bounding_sphere(cluster);
let dist = Self::dist_3d(&center, &robot_position);
if dist > max_distance {
continue;
}
let confidence = (cluster.len() as f32 / cloud.points.len() as f32)
.clamp(0.1, 1.0);
obstacles.push(Obstacle {
id: next_id,
position: center,
distance: dist,
radius,
label: format!("obstacle_{}", next_id),
confidence,
});
next_id += 1;
}
obstacles.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
for (i, obs) in obstacles.iter_mut().enumerate() {
obs.id = i as u64;
}
Ok(obstacles)
}
// -- Scene-graph construction -------------------------------------------
/// Build a scene graph from pre-classified objects.
///
/// Every unordered pair of objects whose centres lie within
/// `max_edge_distance` gets an edge, labelled "adjacent" (closest third
/// of the threshold), "near" (middle third) or "far" (outer third).
pub fn build_scene_graph(
    &self,
    objects: &[SceneObject],
    max_edge_distance: f64,
) -> Result<SceneGraph> {
    if max_edge_distance <= 0.0 {
        return Err(PerceptionError::InvalidInput(
            "max_edge_distance must be positive".to_string(),
        ));
    }
    // Compare squared distances; take the sqrt only for accepted pairs.
    let limit_sq = max_edge_distance * max_edge_distance;
    let mut edges: Vec<SceneEdge> = Vec::new();
    for (i, a) in objects.iter().enumerate() {
        for b in &objects[i + 1..] {
            let dx = a.center[0] - b.center[0];
            let dy = a.center[1] - b.center[1];
            let dz = a.center[2] - b.center[2];
            let dist_sq = dx * dx + dy * dy + dz * dz;
            if dist_sq > limit_sq {
                continue;
            }
            let dist = dist_sq.sqrt();
            let relation = if dist < max_edge_distance * 0.33 {
                "adjacent"
            } else if dist < max_edge_distance * 0.66 {
                "near"
            } else {
                "far"
            };
            edges.push(SceneEdge {
                from: a.id,
                to: b.id,
                distance: dist,
                relation: relation.to_string(),
            });
        }
    }
    Ok(SceneGraph::new(objects.to_vec(), edges, 0))
}
// -- Trajectory prediction ----------------------------------------------
/// Predict a future trajectory via linear extrapolation.
///
/// Produces a [`Trajectory`] of `steps` waypoints spaced `dt` seconds
/// apart, starting one step after `position`. Confidence decays with the
/// total look-ahead time, floored at 0.1.
pub fn predict_trajectory(
    &self,
    position: [f64; 3],
    velocity: [f64; 3],
    steps: usize,
    dt: f64,
) -> Result<Trajectory> {
    if steps == 0 {
        return Err(PerceptionError::InvalidInput(
            "steps must be > 0".to_string(),
        ));
    }
    if dt <= 0.0 {
        return Err(PerceptionError::InvalidInput(
            "dt must be positive".to_string(),
        ));
    }
    // Elapsed times for each step: dt, 2*dt, ..., steps*dt.
    let times: Vec<f64> = (1..=steps).map(|i| i as f64 * dt).collect();
    let waypoints: Vec<[f64; 3]> = times
        .iter()
        .map(|&t| {
            [
                position[0] + velocity[0] * t,
                position[1] + velocity[1] * t,
                position[2] + velocity[2] * t,
            ]
        })
        .collect();
    // Timestamps are microseconds from "now".
    let timestamps: Vec<i64> = times.iter().map(|&t| (t * 1_000_000.0) as i64).collect();
    let confidence = (1.0 - (steps as f64 * dt * 0.1)).max(0.1);
    Ok(Trajectory::new(waypoints, timestamps, confidence))
}
// -- Attention focusing -------------------------------------------------
/// Filter points from `cloud` that lie within `radius` of `center`.
///
/// Returns an error for non-positive radii; otherwise returns copies of
/// the points inside the sphere (inclusive boundary).
pub fn focus_attention(
    &self,
    cloud: &PointCloud,
    center: [f64; 3],
    radius: f64,
) -> Result<Vec<Point3D>> {
    if radius <= 0.0 {
        return Err(PerceptionError::InvalidInput(
            "radius must be positive".to_string(),
        ));
    }
    // Squared-distance test avoids a sqrt per point.
    let radius_sq = radius * radius;
    let mut focused: Vec<Point3D> = Vec::new();
    for p in &cloud.points {
        let dx = p.x as f64 - center[0];
        let dy = p.y as f64 - center[1];
        let dz = p.z as f64 - center[2];
        if dx * dx + dy * dy + dz * dz <= radius_sq {
            focused.push(*p);
        }
    }
    Ok(focused)
}
// -- Anomaly detection --------------------------------------------------
/// Detect anomalous points using z-score outlier analysis.
///
/// Each point's distance to the cloud centroid is computed; points whose
/// z-score magnitude exceeds `anomaly_threshold` are reported. Returns an
/// empty list for clouds with fewer than two points or zero spread.
pub fn detect_anomalies(&self, cloud: &PointCloud) -> Result<Vec<Anomaly>> {
    if cloud.points.len() < 2 {
        // Not enough samples to estimate a spread.
        return Ok(Vec::new());
    }
    let n = cloud.points.len() as f64;
    // Pass 1: centroid of the cloud.
    let mut centroid = [0.0_f64; 3];
    for p in &cloud.points {
        centroid[0] += p.x as f64;
        centroid[1] += p.y as f64;
        centroid[2] += p.z as f64;
    }
    centroid[0] /= n;
    centroid[1] /= n;
    centroid[2] /= n;
    // Pass 2: per-point centroid distances plus running mean/variance via
    // Welford's online algorithm (numerically stable single pass).
    let mut distances: Vec<f64> = Vec::with_capacity(cloud.points.len());
    let mut mean = 0.0_f64;
    let mut m2 = 0.0_f64;
    for (i, p) in cloud.points.iter().enumerate() {
        let dx = p.x as f64 - centroid[0];
        let dy = p.y as f64 - centroid[1];
        let dz = p.z as f64 - centroid[2];
        let d = (dx * dx + dy * dy + dz * dz).sqrt();
        distances.push(d);
        let delta = d - mean;
        mean += delta / (i + 1) as f64;
        m2 += delta * (d - mean);
    }
    // Population standard deviation (divide by n, not n-1).
    let std_dev = (m2 / n).sqrt();
    if std_dev < f64::EPSILON {
        // All points equidistant from the centroid: no outliers possible.
        return Ok(Vec::new());
    }
    let mut anomalies = Vec::new();
    for (i, p) in cloud.points.iter().enumerate() {
        let z = (distances[i] - mean) / std_dev;
        if z.abs() > self.anomaly_threshold {
            anomalies.push(Anomaly {
                position: [p.x as f64, p.y as f64, p.z as f64],
                score: z.abs(),
                description: format!(
                    "outlier at ({:.2}, {:.2}, {:.2}) z={:.2}",
                    p.x, p.y, p.z, z
                ),
                timestamp: cloud.timestamp_us,
            });
        }
    }
    Ok(anomalies)
}
// -- private helpers ----------------------------------------------------
/// Centroid-centred bounding sphere of a non-empty point set: returns the
/// centroid and the distance to the farthest point from it.
fn bounding_sphere(points: &[Point3D]) -> ([f64; 3], f64) {
    debug_assert!(!points.is_empty(), "bounding_sphere called with empty slice");
    let n = points.len() as f64;
    let mut sums = [0.0_f64; 3];
    for p in points {
        sums[0] += p.x as f64;
        sums[1] += p.y as f64;
        sums[2] += p.z as f64;
    }
    let center = [sums[0] / n, sums[1] / n, sums[2] / n];
    // Track the maximum squared distance; take a single sqrt at the end.
    let mut max_sq = 0.0_f64;
    for p in points {
        let dx = p.x as f64 - center[0];
        let dy = p.y as f64 - center[1];
        let dz = p.z as f64 - center[2];
        max_sq = max_sq.max(dx * dx + dy * dy + dz * dz);
    }
    (center, max_sq.sqrt())
}
/// Euclidean distance between two 3-D points.
#[inline]
fn dist_3d(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let dx = a[0] - b[0];
    let dy = a[1] - b[1];
    let dz = a[2] - b[2];
    (dx * dx + dy * dy + dz * dz).sqrt()
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::Point3D;

    /// Test helper: wrap raw XYZ triples in a `PointCloud` stamped at 1000 µs.
    fn make_cloud(pts: &[[f32; 3]]) -> PointCloud {
        let points: Vec<Point3D> =
            pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect();
        PointCloud::new(points, 1000)
    }

    // -- SceneGraphBuilder (inline) -----------------------------------------
    #[test]
    fn test_scene_graph_builder_basic() {
        let builder = SceneGraphBuilder::default();
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [0.5, 0.5, 0.5]),
            SceneObject::new(1, [2.0, 0.0, 0.0], [0.5, 0.5, 0.5]),
            SceneObject::new(2, [100.0, 0.0, 0.0], [0.5, 0.5, 0.5]),
        ];
        let graph = builder.build(objects, 0);
        assert_eq!(graph.objects.len(), 3);
        // Only objects 0 and 1 are close enough to be linked.
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.edges[0].from, 0);
        assert_eq!(graph.edges[0].to, 1);
    }

    #[test]
    fn test_scene_graph_builder_merge() {
        let builder = SceneGraphBuilder::new(10.0, 256);
        let a = SceneGraph::new(
            vec![SceneObject::new(0, [0.0, 0.0, 0.0], [0.5, 0.5, 0.5])],
            vec![],
            100,
        );
        let b = SceneGraph::new(
            vec![SceneObject::new(0, [1.0, 0.0, 0.0], [0.5, 0.5, 0.5])],
            vec![],
            200,
        );
        let merged = builder.merge(&a, &b);
        // NOTE(review): merge appears to keep both objects despite equal ids
        // (unlike PointCloudSceneGraphBuilder::merge_scenes) — confirm intent.
        assert_eq!(merged.objects.len(), 2);
        assert_eq!(merged.timestamp, 200);
        assert!(!merged.edges.is_empty());
    }

    // -- PerceptionPipeline.process (config-driven) -------------------------
    #[test]
    fn test_perception_pipeline_process() {
        let config = PerceptionConfig::default();
        let mut pipeline = PerceptionPipeline::new(config);
        // Two well-separated clusters of three points each.
        let cloud = make_cloud(&[
            [1.0, 0.0, 0.0],
            [1.1, 0.0, 0.0],
            [1.2, 0.0, 0.0],
            [5.0, 5.0, 0.0],
            [5.1, 5.0, 0.0],
            [5.2, 5.0, 0.0],
        ]);
        let (obstacles, graph) = pipeline.process(&cloud, &[0.0, 0.0, 0.0]);
        assert!(!obstacles.is_empty());
        assert!(!graph.objects.is_empty());
        assert_eq!(pipeline.frames_processed(), 1);
    }

    // -- detect_obstacles ---------------------------------------------------
    #[test]
    fn test_detect_obstacles_empty() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let cloud = PointCloud::default();
        let result = pipe.detect_obstacles(&cloud, [0.0; 3], 10.0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_obstacles_single_cluster() {
        let pipe = PerceptionPipeline::with_thresholds(1.0, 2.0);
        // Three nearby points around x = 2: one cluster ~2 units away.
        let cloud = make_cloud(&[
            [2.0, 0.0, 0.0],
            [2.1, 0.0, 0.0],
            [2.0, 0.1, 0.0],
        ]);
        let obs = pipe.detect_obstacles(&cloud, [0.0; 3], 10.0).unwrap();
        assert_eq!(obs.len(), 1);
        assert!(obs[0].distance > 1.0);
        assert!(obs[0].distance < 3.0);
        assert!(!obs[0].label.is_empty());
    }

    #[test]
    fn test_detect_obstacles_filters_distant() {
        let pipe = PerceptionPipeline::with_thresholds(1.0, 2.0);
        // Cluster at ~50 units; max_distance of 5 must exclude it.
        let cloud = make_cloud(&[
            [50.0, 0.0, 0.0],
            [50.1, 0.0, 0.0],
            [50.0, 0.1, 0.0],
        ]);
        let obs = pipe.detect_obstacles(&cloud, [0.0; 3], 5.0).unwrap();
        assert!(obs.is_empty());
    }

    // -- build_scene_graph --------------------------------------------------
    #[test]
    fn test_build_scene_graph_basic() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];
        let graph = pipe.build_scene_graph(&objects, 5.0).unwrap();
        assert_eq!(graph.objects.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }

    #[test]
    fn test_build_scene_graph_invalid_distance() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // Negative threshold must be rejected.
        let result = pipe.build_scene_graph(&[], -1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_scene_graph_no_edges() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // 100 units apart with a threshold of 5: no edge expected.
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [100.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];
        let graph = pipe.build_scene_graph(&objects, 5.0).unwrap();
        assert!(graph.edges.is_empty());
    }

    // -- predict_trajectory -------------------------------------------------
    #[test]
    fn test_predict_trajectory_linear() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // Unit velocity along x, 1 s steps: waypoints at x = 1, 2, 3.
        let traj = pipe
            .predict_trajectory([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 3, 1.0)
            .unwrap();
        assert_eq!(traj.len(), 3);
        assert!((traj.waypoints[0][0] - 1.0).abs() < 1e-9);
        assert!((traj.waypoints[1][0] - 2.0).abs() < 1e-9);
        assert!((traj.waypoints[2][0] - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_predict_trajectory_zero_steps() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let result = pipe.predict_trajectory([0.0; 3], [1.0, 0.0, 0.0], 0, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_trajectory_negative_dt() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let result = pipe.predict_trajectory([0.0; 3], [1.0, 0.0, 0.0], 5, -0.1);
        assert!(result.is_err());
    }

    // -- focus_attention ----------------------------------------------------
    #[test]
    fn test_focus_attention_filters() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // Points at 0, 1 and 10 units; radius 2 keeps the first two.
        let cloud = make_cloud(&[
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [10.0, 0.0, 0.0],
        ]);
        let focused = pipe
            .focus_attention(&cloud, [0.0, 0.0, 0.0], 2.0)
            .unwrap();
        assert_eq!(focused.len(), 2);
    }

    #[test]
    fn test_focus_attention_invalid_radius() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let cloud = PointCloud::default();
        let result = pipe.focus_attention(&cloud, [0.0; 3], -1.0);
        assert!(result.is_err());
    }

    // -- detect_anomalies ---------------------------------------------------
    #[test]
    fn test_detect_anomalies_outlier() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // 20 points on a short line plus one far-away outlier.
        let mut pts: Vec<[f32; 3]> =
            (0..20).map(|i| [i as f32 * 0.1, 0.0, 0.0]).collect();
        pts.push([100.0, 100.0, 100.0]);
        let cloud = make_cloud(&pts);
        let anomalies = pipe.detect_anomalies(&cloud).unwrap();
        assert!(!anomalies.is_empty());
        assert!(anomalies.iter().any(|a| a.score > 2.0));
    }

    #[test]
    fn test_detect_anomalies_no_outliers() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // Identical points: zero spread, so no anomaly can be flagged.
        let cloud = make_cloud(&[
            [1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
        ]);
        let anomalies = pipe.detect_anomalies(&cloud).unwrap();
        assert!(anomalies.is_empty());
    }

    #[test]
    fn test_detect_anomalies_small_cloud() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        // Fewer than two points: analysis is skipped entirely.
        let cloud = make_cloud(&[[1.0, 0.0, 0.0]]);
        let anomalies = pipe.detect_anomalies(&cloud).unwrap();
        assert!(anomalies.is_empty());
    }

    // -- edge cases & integration ------------------------------------------
    #[test]
    fn test_pipeline_debug() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let dbg = format!("{:?}", pipe);
        assert!(dbg.contains("PerceptionPipeline"));
    }

    #[test]
    fn test_scene_graph_edge_relations() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [1.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(2, [6.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(3, [9.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];
        let graph = pipe.build_scene_graph(&objects, 10.0).unwrap();
        // d = 1 < 10 * 0.33, so objects 0 and 1 are "adjacent".
        let adj = graph.edges.iter().find(|e| e.from == 0 && e.to == 1);
        assert!(adj.is_some());
        assert_eq!(adj.unwrap().relation, "adjacent");
    }

    #[test]
    fn test_trajectory_timestamps_are_microseconds() {
        let pipe = PerceptionPipeline::with_thresholds(0.5, 2.0);
        let traj = pipe
            .predict_trajectory([0.0; 3], [1.0, 0.0, 0.0], 2, 0.5)
            .unwrap();
        // 0.5s = 500_000 us, 1.0s = 1_000_000 us
        assert_eq!(traj.timestamps[0], 500_000);
        assert_eq!(traj.timestamps[1], 1_000_000);
    }
}

View File

@@ -0,0 +1,376 @@
//! Obstacle detection from point clouds.
//!
//! Uses spatial-hash clustering to group nearby points into obstacle
//! candidates, then filters and classifies them based on geometry.
use crate::bridge::{Point3D, PointCloud};
use crate::perception::clustering;
use crate::perception::config::ObstacleConfig;
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------
/// Classification category for an obstacle.
///
/// Assigned by [`ObstacleDetector::classify_obstacles`] from the ratio of
/// the largest to smallest bounding-box extent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObstacleClass {
    /// Obstacle appears wall-like / elongated in at least one axis.
    Static,
    /// Compact obstacle that could be a moving object.
    Dynamic,
    /// Cannot determine class from geometry alone.
    Unknown,
}
/// Raw detection result before classification.
#[derive(Debug, Clone)]
pub struct DetectedObstacle {
    /// Centroid of the cluster.
    pub center: [f64; 3],
    /// Axis-aligned bounding-box half-extents (padded by the configured
    /// safety margin).
    pub extent: [f64; 3],
    /// Number of points in the cluster.
    pub point_count: usize,
    /// Distance from the cluster centroid to the robot position.
    pub min_distance: f64,
}
/// A detected obstacle with an attached classification.
#[derive(Debug, Clone)]
pub struct ClassifiedObstacle {
    /// The underlying geometric detection.
    pub obstacle: DetectedObstacle,
    /// Heuristic class derived from the obstacle's extents.
    pub class: ObstacleClass,
    /// Heuristic confidence in `class`, in (0.0, 1.0].
    pub confidence: f32,
}
// ---------------------------------------------------------------------------
// Detector
// ---------------------------------------------------------------------------
/// Detects and classifies obstacles from point-cloud data.
#[derive(Debug, Clone)]
pub struct ObstacleDetector {
    // Tuning knobs: minimum cluster size, detection range, safety margin.
    config: ObstacleConfig,
}
impl ObstacleDetector {
    /// Create a new detector with the given configuration.
    pub fn new(config: ObstacleConfig) -> Self {
        Self { config }
    }

    /// Detect obstacles in a point cloud relative to a robot position.
    ///
    /// Pipeline: spatial-hash clustering (cell size `safety_margin * 5`,
    /// floored at 0.5), drop clusters below `min_obstacle_size`, compute a
    /// bounding box + centroid per survivor, drop anything beyond
    /// `max_detection_range`, and return the rest sorted nearest-first.
    pub fn detect(
        &self,
        cloud: &PointCloud,
        robot_pos: &[f64; 3],
    ) -> Vec<DetectedObstacle> {
        if cloud.is_empty() {
            return Vec::new();
        }
        let cell = (self.config.safety_margin * 5.0).max(0.5);
        let mut found: Vec<DetectedObstacle> = Vec::new();
        for cluster in clustering::cluster_point_cloud(cloud, cell) {
            if cluster.len() < self.config.min_obstacle_size {
                continue;
            }
            if let Some(obstacle) = self.cluster_to_obstacle(&cluster, robot_pos) {
                if obstacle.min_distance <= self.config.max_detection_range {
                    found.push(obstacle);
                }
            }
        }
        // Distances are finite, so the Equal fallback is never exercised.
        found.sort_by(|lhs, rhs| {
            lhs.min_distance
                .partial_cmp(&rhs.min_distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        found
    }

    /// Classify a list of detected obstacles using simple geometric
    /// heuristics on the extent ratio (largest / smallest half-extent):
    ///
    /// * **Static** -- ratio > 3 (elongated, wall-like).
    /// * **Dynamic** -- ratio <= 2 (compact).
    /// * **Unknown** -- everything else.
    pub fn classify_obstacles(
        &self,
        obstacles: &[DetectedObstacle],
    ) -> Vec<ClassifiedObstacle> {
        let mut classified = Vec::with_capacity(obstacles.len());
        for obstacle in obstacles {
            let (class, confidence) = self.classify_single(obstacle);
            classified.push(ClassifiedObstacle {
                obstacle: obstacle.clone(),
                class,
                confidence,
            });
        }
        classified
    }

    // -- private helpers ----------------------------------------------------

    /// Bounding box + centroid of a cluster, padded by the safety margin.
    /// Returns `None` for an empty cluster.
    fn cluster_to_obstacle(
        &self,
        points: &[Point3D],
        robot_pos: &[f64; 3],
    ) -> Option<DetectedObstacle> {
        if points.is_empty() {
            return None;
        }
        let mut lo = [f64::MAX; 3];
        let mut hi = [f64::MIN; 3];
        let mut sum = [0.0_f64; 3];
        for p in points {
            let coords = [p.x as f64, p.y as f64, p.z as f64];
            for axis in 0..3 {
                lo[axis] = lo[axis].min(coords[axis]);
                hi[axis] = hi[axis].max(coords[axis]);
                sum[axis] += coords[axis];
            }
        }
        let n = points.len() as f64;
        let center = [sum[0] / n, sum[1] / n, sum[2] / n];
        let margin = self.config.safety_margin;
        let extent = [
            (hi[0] - lo[0]) / 2.0 + margin,
            (hi[1] - lo[1]) / 2.0 + margin,
            (hi[2] - lo[2]) / 2.0 + margin,
        ];
        let dx = center[0] - robot_pos[0];
        let dy = center[1] - robot_pos[1];
        let dz = center[2] - robot_pos[2];
        Some(DetectedObstacle {
            center,
            extent,
            point_count: points.len(),
            min_distance: (dx * dx + dy * dy + dz * dz).sqrt(),
        })
    }

    /// Classify one obstacle from its extent ratio; see `classify_obstacles`.
    fn classify_single(&self, obstacle: &DetectedObstacle) -> (ObstacleClass, f32) {
        let ext = &obstacle.extent;
        let largest = ext[0].max(ext[1]).max(ext[2]);
        let smallest = ext[0].min(ext[1]).min(ext[2]);
        if smallest < f64::EPSILON {
            // Degenerate (flat) box: the ratio is meaningless.
            return (ObstacleClass::Unknown, 0.3);
        }
        let ratio = largest / smallest;
        if ratio > 3.0 {
            // Elongated -- likely a wall or static structure.
            let confidence = ((ratio / 10.0).min(1.0) as f32).max(0.6);
            (ObstacleClass::Static, confidence)
        } else if ratio <= 2.0 {
            // Compact -- possibly a moving object.
            (ObstacleClass::Dynamic, (1.0 - (ratio - 1.0) / 2.0).max(0.5) as f32)
        } else {
            (ObstacleClass::Unknown, 0.4)
        }
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: wrap raw XYZ triples in a `PointCloud` stamped at t=0.
    fn make_cloud(raw: &[[f32; 3]]) -> PointCloud {
        let points: Vec<Point3D> = raw.iter().map(|p| Point3D::new(p[0], p[1], p[2])).collect();
        PointCloud::new(points, 0)
    }

    #[test]
    fn test_detect_empty_cloud() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let cloud = PointCloud::default();
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_single_cluster() {
        let det = ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size: 3,
            max_detection_range: 100.0,
            safety_margin: 0.1,
        });
        // Four tightly-grouped points: one cluster.
        let cloud = make_cloud(&[
            [1.0, 1.0, 0.0],
            [1.1, 1.0, 0.0],
            [1.0, 1.1, 0.0],
            [1.1, 1.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert_eq!(result.len(), 1);
        assert!(result[0].min_distance > 0.0);
        assert_eq!(result[0].point_count, 4);
    }

    #[test]
    fn test_detect_filters_by_range() {
        let det = ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size: 3,
            max_detection_range: 1.0,
            safety_margin: 0.1,
        });
        // Cluster at ~10 units away -- should be filtered out.
        let cloud = make_cloud(&[
            [10.0, 0.0, 0.0],
            [10.1, 0.0, 0.0],
            [10.0, 0.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_filters_small_clusters() {
        let det = ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size: 5,
            max_detection_range: 100.0,
            safety_margin: 0.1,
        });
        // Only 3 points -- below minimum.
        let cloud = make_cloud(&[
            [1.0, 1.0, 0.0],
            [1.1, 1.0, 0.0],
            [1.0, 1.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_detect_sorted_by_distance() {
        let det = ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size: 3,
            max_detection_range: 100.0,
            safety_margin: 0.1,
        });
        let cloud = make_cloud(&[
            // Far cluster
            [10.0, 0.0, 0.0],
            [10.1, 0.0, 0.0],
            [10.0, 0.1, 0.0],
            // Near cluster
            [1.0, 0.0, 0.0],
            [1.1, 0.0, 0.0],
            [1.0, 0.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[0.0, 0.0, 0.0]);
        assert!(result.len() >= 1);
        // Results must come back nearest-first regardless of input order.
        if result.len() >= 2 {
            assert!(result[0].min_distance <= result[1].min_distance);
        }
    }

    #[test]
    fn test_classify_static_obstacle() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        // Wall-like: very elongated in X, thin in Y and Z (ratio 20 > 3).
        let obstacle = DetectedObstacle {
            center: [5.0, 0.0, 0.0],
            extent: [10.0, 0.5, 0.5],
            point_count: 50,
            min_distance: 5.0,
        };
        let classified = det.classify_obstacles(&[obstacle]);
        assert_eq!(classified.len(), 1);
        assert_eq!(classified[0].class, ObstacleClass::Static);
        assert!(classified[0].confidence >= 0.5);
    }

    #[test]
    fn test_classify_dynamic_obstacle() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        // Compact: roughly equal extents (ratio 1 <= 2).
        let obstacle = DetectedObstacle {
            center: [3.0, 0.0, 0.0],
            extent: [1.0, 1.0, 1.0],
            point_count: 20,
            min_distance: 3.0,
        };
        let classified = det.classify_obstacles(&[obstacle]);
        assert_eq!(classified.len(), 1);
        assert_eq!(classified[0].class, ObstacleClass::Dynamic);
    }

    #[test]
    fn test_classify_unknown_obstacle() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        // Intermediate ratio (3.0 falls in the (2, 3] "unknown" band).
        let obstacle = DetectedObstacle {
            center: [5.0, 0.0, 0.0],
            extent: [3.0, 1.1, 1.0],
            point_count: 15,
            min_distance: 5.0,
        };
        let classified = det.classify_obstacles(&[obstacle]);
        assert_eq!(classified.len(), 1);
        assert_eq!(classified[0].class, ObstacleClass::Unknown);
    }

    #[test]
    fn test_classify_empty_list() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let classified = det.classify_obstacles(&[]);
        assert!(classified.is_empty());
    }

    #[test]
    fn test_obstacle_detector_debug() {
        let det = ObstacleDetector::new(ObstacleConfig::default());
        let dbg = format!("{:?}", det);
        assert!(dbg.contains("ObstacleDetector"));
    }

    #[test]
    fn test_detect_two_separated_clusters() {
        let det = ObstacleDetector::new(ObstacleConfig {
            min_obstacle_size: 3,
            max_detection_range: 200.0,
            safety_margin: 0.1,
        });
        let cloud = make_cloud(&[
            // Cluster A around (0, 0, 0)
            [0.0, 0.0, 0.0],
            [0.1, 0.0, 0.0],
            [0.0, 0.1, 0.0],
            // Cluster B around (100, 100, 0) -- very far away
            [100.0, 100.0, 0.0],
            [100.1, 100.0, 0.0],
            [100.0, 100.1, 0.0],
        ]);
        let result = det.detect(&cloud, &[50.0, 50.0, 0.0]);
        assert_eq!(result.len(), 2);
    }
}

View File

@@ -0,0 +1,373 @@
//! Scene-graph construction from point clouds and object lists.
//!
//! The [`SceneGraphBuilder`] turns raw sensor data into a structured
//! [`SceneGraph`] of objects and spatial relationships.
use std::collections::HashMap;
use crate::bridge::{Point3D, PointCloud, SceneEdge, SceneGraph, SceneObject};
use crate::perception::clustering;
use crate::perception::config::SceneGraphConfig;
// ---------------------------------------------------------------------------
// Builder
// ---------------------------------------------------------------------------
/// Builds [`SceneGraph`] instances from point clouds or pre-classified
/// object lists using spatial-hash clustering with union-find.
#[derive(Debug, Clone)]
pub struct PointCloudSceneGraphBuilder {
    // Cluster radius, minimum cluster size, object cap, edge threshold.
    config: SceneGraphConfig,
}
impl PointCloudSceneGraphBuilder {
    /// Create a new builder with the given configuration.
    pub fn new(config: SceneGraphConfig) -> Self {
        Self { config }
    }

    /// Build a scene graph by clustering a raw point cloud.
    ///
    /// 1. Points are discretised into a spatial hash grid.
    /// 2. Adjacent cells are merged via union-find.
    /// 3. Each cluster of at least `min_cluster_size` points becomes a
    ///    `SceneObject` (capped at `max_objects`).
    /// 4. Edges link objects whose centres lie within
    ///    `edge_distance_threshold`.
    pub fn build_from_point_cloud(&self, cloud: &PointCloud) -> SceneGraph {
        if cloud.is_empty() {
            return SceneGraph::default();
        }
        let clusters = clustering::cluster_point_cloud(cloud, self.config.cluster_radius);
        let mut objects: Vec<SceneObject> = Vec::new();
        for pts in clusters {
            if pts.len() < self.config.min_cluster_size {
                continue;
            }
            if objects.len() >= self.config.max_objects {
                break;
            }
            // Ids are assigned in acceptance order.
            objects.push(Self::cluster_to_object(objects.len(), &pts));
        }
        objects.sort_by_key(|o| o.id);
        let edges = self.create_edges(&objects);
        SceneGraph::new(objects, edges, cloud.timestamp_us)
    }

    /// Build a scene graph from a pre-existing list of objects (truncated
    /// to `max_objects`), linking those within `edge_distance_threshold`.
    pub fn build_from_objects(&self, objects: &[SceneObject]) -> SceneGraph {
        let mut kept: Vec<SceneObject> = Vec::new();
        for obj in objects.iter().take(self.config.max_objects) {
            kept.push(obj.clone());
        }
        let edges = self.create_edges(&kept);
        SceneGraph::new(kept, edges, 0)
    }

    /// Merge multiple scene graphs into one, deduplicating objects that share
    /// the same `id` (first occurrence wins) and keeping the latest timestamp.
    pub fn merge_scenes(&self, scenes: &[SceneGraph]) -> SceneGraph {
        let mut by_id: HashMap<usize, SceneObject> = HashMap::new();
        let mut newest: i64 = 0;
        for scene in scenes {
            if scene.timestamp > newest {
                newest = scene.timestamp;
            }
            for obj in &scene.objects {
                // Keep the first occurrence of each id.
                by_id.entry(obj.id).or_insert_with(|| obj.clone());
            }
        }
        let mut merged: Vec<SceneObject> = by_id.into_values().collect();
        merged.sort_by_key(|o| o.id);
        merged.truncate(self.config.max_objects);
        let edges = self.create_edges(&merged);
        SceneGraph::new(merged, edges, newest)
    }

    // -- private helpers ----------------------------------------------------

    /// Convert a non-empty cluster into a `SceneObject` with its centroid,
    /// bounding-box half-extents and a generated `cluster_<id>` label.
    fn cluster_to_object(id: usize, points: &[Point3D]) -> SceneObject {
        debug_assert!(!points.is_empty(), "cluster_to_object called with empty slice");
        let mut lo = [f64::MAX; 3];
        let mut hi = [f64::MIN; 3];
        let mut sum = [0.0_f64; 3];
        for p in points {
            let coords = [p.x as f64, p.y as f64, p.z as f64];
            for axis in 0..3 {
                lo[axis] = lo[axis].min(coords[axis]);
                hi[axis] = hi[axis].max(coords[axis]);
                sum[axis] += coords[axis];
            }
        }
        let n = points.len() as f64;
        SceneObject {
            id,
            center: [sum[0] / n, sum[1] / n, sum[2] / n],
            extent: [
                (hi[0] - lo[0]) / 2.0,
                (hi[1] - lo[1]) / 2.0,
                (hi[2] - lo[2]) / 2.0,
            ],
            confidence: 1.0,
            label: format!("cluster_{}", id),
            velocity: None,
        }
    }

    /// Pairwise edges between objects within the configured threshold,
    /// labelled by distance band: adjacent / near / far.
    fn create_edges(&self, objects: &[SceneObject]) -> Vec<SceneEdge> {
        let threshold = self.config.edge_distance_threshold;
        let mut edges = Vec::new();
        for (i, a) in objects.iter().enumerate() {
            for b in &objects[i + 1..] {
                let d = Self::distance_3d(&a.center, &b.center);
                if d > threshold {
                    continue;
                }
                let relation = if d < threshold * 0.33 {
                    "adjacent"
                } else if d < threshold * 0.66 {
                    "near"
                } else {
                    "far"
                };
                edges.push(SceneEdge {
                    from: a.id,
                    to: b.id,
                    distance: d,
                    relation: relation.to_string(),
                });
            }
        }
        edges
    }

    /// Euclidean distance between two 3-D centres.
    fn distance_3d(a: &[f64; 3], b: &[f64; 3]) -> f64 {
        let mut acc = 0.0_f64;
        for axis in 0..3 {
            let d = a[axis] - b[axis];
            acc += d * d;
        }
        acc.sqrt()
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: wrap raw XYZ triples in a `PointCloud` stamped at 1000 µs.
    fn make_cloud(raw: &[[f32; 3]]) -> PointCloud {
        let points: Vec<Point3D> = raw.iter().map(|p| Point3D::new(p[0], p[1], p[2])).collect();
        PointCloud::new(points, 1000)
    }

    #[test]
    fn test_empty_cloud() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig::default());
        let graph = builder.build_from_point_cloud(&PointCloud::default());
        assert!(graph.objects.is_empty());
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_single_cluster() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            cluster_radius: 1.0,
            min_cluster_size: 3,
            max_objects: 10,
            edge_distance_threshold: 5.0,
        });
        let cloud = make_cloud(&[
            [0.0, 0.0, 0.0],
            [0.1, 0.0, 0.0],
            [0.0, 0.1, 0.0],
            [0.1, 0.1, 0.0],
        ]);
        let graph = builder.build_from_point_cloud(&cloud);
        assert_eq!(graph.objects.len(), 1);
        assert!(graph.edges.is_empty()); // Only one object, no edges.
    }

    #[test]
    fn test_room_point_cloud() {
        // Simulate a room with two walls (clusters far apart).
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            cluster_radius: 0.5,
            min_cluster_size: 3,
            max_objects: 10,
            edge_distance_threshold: 50.0,
        });
        let mut points = Vec::new();
        // Wall 1: cluster around (0, 0, 0)
        for i in 0..5 {
            points.push([i as f32 * 0.1, 0.0, 0.0]);
        }
        // Wall 2: cluster around (10, 0, 0)
        for i in 0..5 {
            points.push([10.0 + i as f32 * 0.1, 0.0, 0.0]);
        }
        let cloud = make_cloud(&points);
        let graph = builder.build_from_point_cloud(&cloud);
        assert_eq!(graph.objects.len(), 2);
        // Both walls should be connected since threshold is 50.
        assert!(!graph.edges.is_empty());
    }

    #[test]
    fn test_separated_clusters_no_edge() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            cluster_radius: 0.5,
            min_cluster_size: 3,
            max_objects: 10,
            edge_distance_threshold: 2.0,
        });
        let cloud = make_cloud(&[
            // Cluster A
            [0.0, 0.0, 0.0],
            [0.1, 0.0, 0.0],
            [0.0, 0.1, 0.0],
            // Cluster B -- far away (100 units)
            [100.0, 0.0, 0.0],
            [100.1, 0.0, 0.0],
            [100.0, 0.1, 0.0],
        ]);
        let graph = builder.build_from_point_cloud(&cloud);
        assert_eq!(graph.objects.len(), 2);
        // Should NOT have edges -- clusters are 100 units apart, threshold is 2.
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_build_from_objects() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            edge_distance_threshold: 5.0,
            ..SceneGraphConfig::default()
        });
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [3.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(2, [100.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];
        let graph = builder.build_from_objects(&objects);
        assert_eq!(graph.objects.len(), 3);
        // Objects 0 and 1 are 3.0 apart (within threshold),
        // object 2 is 100.0 away (outside threshold).
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.edges[0].from, 0);
        assert_eq!(graph.edges[0].to, 1);
    }

    #[test]
    fn test_merge_deduplication() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            edge_distance_threshold: 10.0,
            ..SceneGraphConfig::default()
        });
        let scene_a = SceneGraph::new(
            vec![
                SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
                SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            ],
            vec![],
            100,
        );
        let scene_b = SceneGraph::new(
            vec![
                SceneObject::new(1, [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // duplicate id
                SceneObject::new(2, [4.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            ],
            vec![],
            200,
        );
        let merged = builder.merge_scenes(&[scene_a, scene_b]);
        // Should have 3 unique objects: ids 0, 1, 2.
        assert_eq!(merged.objects.len(), 3);
        assert_eq!(merged.timestamp, 200);
    }

    #[test]
    fn test_merge_preserves_latest_timestamp() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig::default());
        let s1 = SceneGraph::new(vec![], vec![], 50);
        let s2 = SceneGraph::new(vec![], vec![], 300);
        let s3 = SceneGraph::new(vec![], vec![], 100);
        let merged = builder.merge_scenes(&[s1, s2, s3]);
        assert_eq!(merged.timestamp, 300);
    }

    #[test]
    fn test_edge_relations() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            edge_distance_threshold: 30.0,
            ..SceneGraphConfig::default()
        });
        let objects = vec![
            SceneObject::new(0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            SceneObject::new(1, [5.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~5 < 9.9 => adjacent
            SceneObject::new(2, [15.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~15 < 19.8 => near
            SceneObject::new(3, [25.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // ~25 < 30 => far
        ];
        let graph = builder.build_from_objects(&objects);
        // Check that adjacent relation exists for objects 0 and 1.
        let edge_0_1 = graph
            .edges
            .iter()
            .find(|e| e.from == 0 && e.to == 1);
        assert!(edge_0_1.is_some());
        assert_eq!(edge_0_1.unwrap().relation, "adjacent");
    }

    #[test]
    fn test_max_objects_cap() {
        let builder = PointCloudSceneGraphBuilder::new(SceneGraphConfig {
            cluster_radius: 0.5,
            min_cluster_size: 1,
            max_objects: 2,
            edge_distance_threshold: 100.0,
        });
        let cloud = make_cloud(&[
            [0.0, 0.0, 0.0],
            [50.0, 0.0, 0.0],
            [100.0, 0.0, 0.0],
        ]);
        let graph = builder.build_from_point_cloud(&cloud);
        // min_cluster_size=1, so each point is its own cluster.
        // max_objects=2, so at most 2 objects.
        assert!(graph.objects.len() <= 2);
    }
}

View File

@@ -0,0 +1,189 @@
//! Multi-sensor point cloud fusion.
//!
//! Aligns and merges point clouds from multiple sensors into a single
//! unified cloud, using nearest-timestamp matching and optional confidence
//! weighting.
use crate::bridge::{Point3D, PointCloud};
use serde::{Deserialize, Serialize};
/// Configuration for the sensor fusion module.
///
/// Consumed by [`fuse_clouds`]; the first cloud in the input slice acts
/// as the time reference for synchronisation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Maximum timestamp delta (µs) for two frames to be considered
    /// synchronised. Frames further apart are discarded.
    pub max_time_delta_us: i64,
    /// Whether to apply confidence weighting based on point density
    /// (intensities are scaled by `1 / sqrt(point_count)`).
    pub density_weighting: bool,
    /// Minimum voxel size for down-sampling the fused cloud. Set to 0.0
    /// to disable.
    pub voxel_size: f64,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
max_time_delta_us: 50_000, // 50 ms
density_weighting: false,
voxel_size: 0.0,
}
}
}
/// Fuse multiple point clouds into a single unified cloud.
///
/// The first cloud in the slice is the *reference*: any cloud whose
/// timestamp differs from it by more than `config.max_time_delta_us` is
/// skipped (the reference itself always qualifies, delta = 0). When
/// `config.voxel_size > 0`, the merged cloud is down-sampled via
/// voxel-grid filtering. The fused cloud carries the reference timestamp.
pub fn fuse_clouds(clouds: &[PointCloud], config: &FusionConfig) -> PointCloud {
    let Some(first) = clouds.first() else {
        return PointCloud::default();
    };
    let anchor_ts = first.timestamp_us;
    // Select clouds inside the sync window once, so we can size the output
    // buffers exactly before copying.
    let in_window: Vec<&PointCloud> = clouds
        .iter()
        .filter(|c| (c.timestamp_us - anchor_ts).abs() <= config.max_time_delta_us)
        .collect();
    let capacity: usize = in_window.iter().map(|c| c.points.len()).sum();
    let mut points: Vec<Point3D> = Vec::with_capacity(capacity);
    let mut intensities: Vec<f32> = Vec::with_capacity(capacity);
    for cloud in &in_window {
        points.extend_from_slice(&cloud.points);
        if config.density_weighting && !cloud.is_empty() {
            // Down-weight dense clouds: 1 / sqrt(point count).
            let weight = (cloud.points.len() as f32).sqrt().recip();
            intensities.extend(cloud.intensities.iter().map(|i| i * weight));
        } else {
            intensities.extend_from_slice(&cloud.intensities);
        }
    }
    if config.voxel_size > 0.0 && !points.is_empty() {
        let (vp, vi) = voxel_downsample(&points, &intensities, config.voxel_size);
        points = vp;
        intensities = vi;
    }
    let mut fused = PointCloud::new(points, anchor_ts);
    fused.intensities = intensities;
    fused
}
/// Simple voxel grid down-sampling: one representative point per voxel.
///
/// The representative is the running average of all points (and their
/// intensities) that fall into the voxel. Missing intensities default to
/// 1.0. Output order follows hash-map iteration and is unspecified.
fn voxel_downsample(
    points: &[Point3D],
    intensities: &[f32],
    cell_size: f64,
) -> (Vec<Point3D>, Vec<f32>) {
    use std::collections::HashMap;
    // Map a coordinate to its voxel index (floor handles negatives correctly).
    let cell_of = |v: f32| (v as f64 / cell_size).floor() as i64;
    // voxel key -> (mean point, mean intensity, sample count)
    let mut voxels: HashMap<(i64, i64, i64), (Point3D, f32, usize)> = HashMap::new();
    for (idx, p) in points.iter().enumerate() {
        let intensity = intensities.get(idx).copied().unwrap_or(1.0);
        let key = (cell_of(p.x), cell_of(p.y), cell_of(p.z));
        let slot = voxels.entry(key).or_insert((*p, intensity, 0));
        slot.2 += 1;
        // Incremental mean: mean += (sample - mean) / n. On the first sample
        // n == 1 and the slot already holds that sample, so this is a no-op.
        let n = slot.2 as f32;
        slot.0.x += (p.x - slot.0.x) / n;
        slot.0.y += (p.y - slot.0.y) / n;
        slot.0.z += (p.z - slot.0.z) / n;
        slot.1 += (intensity - slot.1) / n;
    }
    let mut out_pts = Vec::with_capacity(voxels.len());
    let mut out_int = Vec::with_capacity(voxels.len());
    for (pt, inten, _) in voxels.into_values() {
        out_pts.push(pt);
        out_int.push(inten);
    }
    (out_pts, out_int)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `PointCloud` from raw XYZ triples at timestamp `ts` (µs).
    fn make_cloud(pts: &[[f32; 3]], ts: i64) -> PointCloud {
        let points: Vec<Point3D> = pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect();
        PointCloud::new(points, ts)
    }

    #[test]
    fn test_fuse_empty() {
        // No input clouds => default (empty) cloud.
        let result = fuse_clouds(&[], &FusionConfig::default());
        assert!(result.is_empty());
    }

    #[test]
    fn test_fuse_single() {
        let c = make_cloud(&[[1.0, 0.0, 0.0]], 1000);
        let result = fuse_clouds(&[c], &FusionConfig::default());
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_fuse_two_clouds() {
        // 10 µs apart: well inside the default 50 ms window, both merged.
        let c1 = make_cloud(&[[1.0, 0.0, 0.0]], 1000);
        let c2 = make_cloud(&[[2.0, 0.0, 0.0]], 1010);
        let result = fuse_clouds(&[c1, c2], &FusionConfig::default());
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_fuse_skips_stale() {
        let c1 = make_cloud(&[[1.0, 0.0, 0.0]], 0);
        let c2 = make_cloud(&[[2.0, 0.0, 0.0]], 100_000); // 100ms apart
        let config = FusionConfig { max_time_delta_us: 50_000, ..Default::default() };
        let result = fuse_clouds(&[c1, c2], &config);
        assert_eq!(result.len(), 1); // c2 skipped
    }

    #[test]
    fn test_voxel_downsample() {
        // Two points in one 1.0-sized voxel collapse to a single
        // representative; the far point keeps its own voxel.
        let c1 = make_cloud(
            &[
                [0.0, 0.0, 0.0], [0.01, 0.01, 0.01], // same voxel
                [5.0, 5.0, 5.0], // different voxel
            ],
            0,
        );
        let config = FusionConfig { voxel_size: 1.0, ..Default::default() };
        let result = fuse_clouds(&[c1], &config);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_density_weighting() {
        let c1 = make_cloud(&[[1.0, 0.0, 0.0]], 0);
        let config = FusionConfig { density_weighting: true, ..Default::default() };
        let result = fuse_clouds(&[c1], &config);
        assert_eq!(result.len(), 1);
        // With 1 point, weight = 1/sqrt(1) = 1.0, so intensity unchanged.
        assert!((result.intensities[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_preserves_timestamp() {
        // The fused cloud carries the first (reference) cloud's timestamp.
        let c1 = make_cloud(&[[1.0, 0.0, 0.0]], 5000);
        let c2 = make_cloud(&[[2.0, 0.0, 0.0]], 5010);
        let result = fuse_clouds(&[c1, c2], &FusionConfig::default());
        assert_eq!(result.timestamp_us, 5000);
    }
}

View File

@@ -0,0 +1,402 @@
//! Motion planning: A\* grid search and potential-field velocity commands.
//!
//! Operates on the [`OccupancyGrid`](crate::bridge::OccupancyGrid) type from
//! the bridge module. Two planners are provided:
//!
//! - [`astar`]: discrete A\* on the occupancy grid returning a cell path.
//! - [`potential_field`]: continuous-space repulsive/attractive field producing
//! a velocity command.
use crate::bridge::OccupancyGrid;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A 2-D grid cell coordinate.
pub type Cell = (usize, usize);
/// Result of an A\* search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GridPath {
    /// Sequence of `(x, y)` cells from start to goal (inclusive at both ends).
    pub cells: Vec<Cell>,
    /// Total traversal cost in cell units (cardinal step = 1, diagonal = √2).
    pub cost: f64,
}
/// Errors from planning operations.
#[derive(Debug, thiserror::Error)]
pub enum PlanningError {
    /// The start cell lies outside the grid or is occupied.
    #[error("start cell ({0}, {1}) is out of bounds or occupied")]
    InvalidStart(usize, usize),
    /// The goal cell lies outside the grid or is occupied.
    #[error("goal cell ({0}, {1}) is out of bounds or occupied")]
    InvalidGoal(usize, usize),
    /// The search exhausted every reachable cell without reaching the goal.
    #[error("no feasible path found")]
    NoPath,
}

/// Convenience alias: planning operations fail with [`PlanningError`].
pub type Result<T> = std::result::Result<T, PlanningError>;
// ---------------------------------------------------------------------------
// A* search
// ---------------------------------------------------------------------------
/// Occupancy value above which a cell is considered blocked.
const OCCUPIED_THRESHOLD: f32 = 0.5;

/// Open-list entry: a cell together with its f-score (g + heuristic).
#[derive(PartialEq)]
struct AStarEntry {
    cell: Cell,
    f: f64,
}

// `f` is a sum of finite step costs and the octile heuristic, so it is
// never NaN in practice; falling back to `Equal` for incomparable floats
// keeps the ordering total as `BinaryHeap` requires.
impl Eq for AStarEntry {}

impl Ord for AStarEntry {
    // Deliberately reversed (`other` compared against `self`) so that the
    // standard max-heap `BinaryHeap` pops the SMALLEST f-score first,
    // turning it into the min-priority queue A* needs.
    fn cmp(&self, other: &Self) -> Ordering {
        other.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for AStarEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Run A\* on `grid`, returning the shortest [`GridPath`] from `start` to
/// `goal`. Cells with occupancy >= 0.5 are treated as impassable.
///
/// Diagonal moves cost √2, cardinal moves cost 1.
///
/// # Errors
/// [`PlanningError::InvalidStart`] / [`PlanningError::InvalidGoal`] when an
/// endpoint is out of bounds or occupied; [`PlanningError::NoPath`] when the
/// goal is unreachable.
pub fn astar(
    grid: &OccupancyGrid,
    start: Cell,
    goal: Cell,
) -> Result<GridPath> {
    if !cell_free(grid, start) {
        return Err(PlanningError::InvalidStart(start.0, start.1));
    }
    if !cell_free(grid, goal) {
        return Err(PlanningError::InvalidGoal(goal.0, goal.1));
    }
    if start == goal {
        return Ok(GridPath { cells: vec![start], cost: 0.0 });
    }
    let mut best_g: HashMap<Cell, f64> = HashMap::with_capacity(128);
    let mut parent: HashMap<Cell, Cell> = HashMap::with_capacity(128);
    let mut frontier = BinaryHeap::new();
    let mut expanded: HashSet<Cell> = HashSet::with_capacity(128);
    // Scratch buffer reused across expansions to avoid per-node allocation.
    let mut scratch: Vec<(usize, usize, f64)> = Vec::with_capacity(8);
    best_g.insert(start, 0.0);
    frontier.push(AStarEntry { cell: start, f: heuristic(start, goal) });
    while let Some(entry) = frontier.pop() {
        let cell = entry.cell;
        if cell == goal {
            return Ok(reconstruct_path(&parent, goal, &best_g));
        }
        // A cell may sit in the heap several times with outdated f-scores;
        // only the first pop (cheapest) is expanded.
        if !expanded.insert(cell) {
            continue;
        }
        let g_here = best_g[&cell];
        neighbors_into(grid, cell, &mut scratch);
        for &(nx, ny, step) in scratch.iter() {
            let next = (nx, ny);
            if expanded.contains(&next) {
                continue;
            }
            let candidate = g_here + step;
            let improves = best_g.get(&next).map_or(true, |&g| candidate < g);
            if improves {
                best_g.insert(next, candidate);
                parent.insert(next, cell);
                frontier.push(AStarEntry {
                    cell: next,
                    f: candidate + heuristic(next, goal),
                });
            }
        }
    }
    Err(PlanningError::NoPath)
}
/// True when `(x, y)` is inside the grid and its occupancy is below the
/// blocked threshold; out-of-bounds cells count as not free.
#[inline]
fn cell_free(grid: &OccupancyGrid, (x, y): Cell) -> bool {
    match grid.get(x, y) {
        Some(occupancy) => occupancy < OCCUPIED_THRESHOLD,
        None => false,
    }
}
/// Octile-distance heuristic: admissible for an 8-connected grid where
/// diagonal steps cost √2 and cardinal steps cost 1.
#[inline]
fn heuristic(a: (usize, usize), b: (usize, usize)) -> f64 {
    let dx = (a.0 as f64 - b.0 as f64).abs();
    let dy = (a.1 as f64 - b.1 as f64).abs();
    // Take as many diagonal steps as the smaller axis allows, then walk
    // the remainder straight.
    let diagonal = dx.min(dy);
    let straight = dx.max(dy) - diagonal;
    diagonal * std::f64::consts::SQRT_2 + straight
}
/// Write the free 8-connected neighbours of `cell` into `out`, reusing the
/// buffer to avoid per-expansion heap allocation. Each entry is
/// `(x, y, step_cost)` with cardinal cost 1 and diagonal cost √2.
#[inline]
fn neighbors_into(grid: &OccupancyGrid, (cx, cy): Cell, out: &mut Vec<(usize, usize, f64)>) {
    // All eight offsets, in the same order the original nested loops
    // produced them (row-major over dx, dy, skipping (0, 0)).
    const OFFSETS: [(isize, isize); 8] = [
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),           (0, 1),
        (1, -1), (1, 0), (1, 1),
    ];
    out.clear();
    for &(dx, dy) in &OFFSETS {
        // checked_add_signed rejects moves that would go below zero.
        let (Some(nx), Some(ny)) = (cx.checked_add_signed(dx), cy.checked_add_signed(dy)) else {
            continue;
        };
        if cell_free(grid, (nx, ny)) {
            let step = if dx == 0 || dy == 0 {
                1.0
            } else {
                std::f64::consts::SQRT_2
            };
            out.push((nx, ny, step));
        }
    }
}
/// Rebuild the start→goal cell sequence by following parent links back
/// from `goal`, and attach the goal's accumulated g-score as the cost.
fn reconstruct_path(
    came_from: &HashMap<Cell, Cell>,
    goal: Cell,
    g_score: &HashMap<Cell, f64>,
) -> GridPath {
    // Walk goal -> start via parent pointers, then flip into path order.
    let mut cells: Vec<Cell> =
        std::iter::successors(Some(goal), |c| came_from.get(c).copied()).collect();
    cells.reverse();
    let cost = g_score.get(&goal).copied().unwrap_or(0.0);
    GridPath { cells, cost }
}
// ---------------------------------------------------------------------------
// Potential field
// ---------------------------------------------------------------------------
/// Output of the potential field planner: a 3-D velocity command.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VelocityCommand {
    /// Velocity along X (same units as `PotentialFieldConfig::max_speed`, m/s).
    pub vx: f64,
    /// Velocity along Y (m/s).
    pub vy: f64,
    /// Velocity along Z (m/s).
    pub vz: f64,
}
/// Configuration for the potential field planner.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialFieldConfig {
    /// Attractive gain toward the goal (scales the linear pull).
    pub attractive_gain: f64,
    /// Repulsive gain away from obstacles.
    pub repulsive_gain: f64,
    /// Influence range for obstacles (metres); obstacles at or beyond this
    /// distance exert no repulsion.
    pub obstacle_influence: f64,
    /// Maximum output speed (m/s); the combined force vector is clamped to
    /// this norm.
    pub max_speed: f64,
}
impl Default for PotentialFieldConfig {
fn default() -> Self {
Self {
attractive_gain: 1.0,
repulsive_gain: 100.0,
obstacle_influence: 3.0,
max_speed: 2.0,
}
}
}
/// Compute a velocity command using attractive + repulsive potential fields.
///
/// * `robot` — current robot position `[x, y, z]`.
/// * `goal` — target position `[x, y, z]`.
/// * `obstacles` — positions of nearby obstacles.
///
/// The attractive term pulls linearly toward the goal; each obstacle inside
/// `config.obstacle_influence` adds a repulsive term that grows sharply as
/// the robot closes in. The resulting vector is clamped to
/// `config.max_speed`.
pub fn potential_field(
    robot: &[f64; 3],
    goal: &[f64; 3],
    obstacles: &[[f64; 3]],
    config: &PotentialFieldConfig,
) -> VelocityCommand {
    // Attractive component: linear pull toward the goal on each axis.
    let mut force = [
        config.attractive_gain * (goal[0] - robot[0]),
        config.attractive_gain * (goal[1] - robot[1]),
        config.attractive_gain * (goal[2] - robot[2]),
    ];
    // Repulsive component from every obstacle within influence range.
    for obs in obstacles {
        let delta = [
            robot[0] - obs[0],
            robot[1] - obs[1],
            robot[2] - obs[2],
        ];
        // Floor the distance at 1 cm so the 1/d terms stay finite.
        let dist = (delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2])
            .sqrt()
            .max(0.01);
        if dist < config.obstacle_influence {
            let strength =
                config.repulsive_gain * (1.0 / dist - 1.0 / config.obstacle_influence) / (dist * dist);
            for axis in 0..3 {
                force[axis] += strength * delta[axis] / dist;
            }
        }
    }
    // Clamp the combined vector to the configured speed limit.
    let speed = (force[0] * force[0] + force[1] * force[1] + force[2] * force[2]).sqrt();
    if speed > config.max_speed {
        let scale = config.max_speed / speed;
        for component in &mut force {
            *component *= scale;
        }
    }
    VelocityCommand { vx: force[0], vy: force[1], vz: force[2] }
}
// ---------------------------------------------------------------------------
// Path smoothing
// ---------------------------------------------------------------------------
/// Convert a [`GridPath`] (grid cells) to world-space waypoints using the
/// grid `resolution` and `origin`. Z is held constant at `origin[2]`
/// because the grid is planar.
pub fn path_to_waypoints(path: &GridPath, resolution: f64, origin: &[f64; 3]) -> Vec<[f64; 3]> {
    let mut waypoints = Vec::with_capacity(path.cells.len());
    for &(cx, cy) in &path.cells {
        waypoints.push([
            origin[0] + cx as f64 * resolution,
            origin[1] + cy as f64 * resolution,
            origin[2],
        ]);
    }
    waypoints
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a w×h grid for testing.
    /// NOTE(review): third argument is presumably the cell resolution and
    /// cells appear to start free — confirm against `OccupancyGrid::new`.
    fn free_grid(w: usize, h: usize) -> OccupancyGrid {
        OccupancyGrid::new(w, h, 1.0)
    }

    #[test]
    fn test_astar_straight_line() {
        // Five cardinal steps, cost 1 each.
        let grid = free_grid(10, 10);
        let path = astar(&grid, (0, 0), (5, 0)).unwrap();
        assert_eq!(*path.cells.first().unwrap(), (0, 0));
        assert_eq!(*path.cells.last().unwrap(), (5, 0));
        assert!((path.cost - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_astar_diagonal() {
        let grid = free_grid(10, 10);
        let path = astar(&grid, (0, 0), (3, 3)).unwrap();
        assert_eq!(*path.cells.last().unwrap(), (3, 3));
        // Pure diagonal = 3 * sqrt(2) ≈ 4.24
        assert!((path.cost - 3.0 * std::f64::consts::SQRT_2).abs() < 1e-6);
    }

    #[test]
    fn test_astar_same_cell() {
        // start == goal short-circuits to a single-cell, zero-cost path.
        let grid = free_grid(5, 5);
        let path = astar(&grid, (2, 2), (2, 2)).unwrap();
        assert_eq!(path.cells.len(), 1);
        assert!((path.cost).abs() < 1e-9);
    }

    #[test]
    fn test_astar_around_wall() {
        let mut grid = free_grid(10, 10);
        // Vertical wall at x=3 from y=0 to y=4.
        for y in 0..5 {
            grid.set(3, y, 1.0);
        }
        let path = astar(&grid, (1, 2), (5, 2)).unwrap();
        assert_eq!(*path.cells.last().unwrap(), (5, 2));
        // Path must go around the wall, so cost > 4 (straight line).
        assert!(path.cost > 4.0);
    }

    #[test]
    fn test_astar_blocked() {
        let mut grid = free_grid(5, 5);
        // Full wall across the grid.
        for y in 0..5 {
            grid.set(2, y, 1.0);
        }
        // Goal is unreachable => NoPath error.
        let result = astar(&grid, (0, 2), (4, 2));
        assert!(result.is_err());
    }

    #[test]
    fn test_astar_invalid_start() {
        // Start outside the 5x5 grid => InvalidStart error.
        let grid = free_grid(5, 5);
        let result = astar(&grid, (10, 10), (2, 2));
        assert!(result.is_err());
    }

    #[test]
    fn test_potential_field_towards_goal() {
        // Goal straight ahead on +X: pure forward pull, no lateral drift.
        let cmd = potential_field(
            &[0.0, 0.0, 0.0],
            &[5.0, 0.0, 0.0],
            &[],
            &PotentialFieldConfig::default(),
        );
        assert!(cmd.vx > 0.0);
        assert!(cmd.vy.abs() < 1e-9);
    }

    #[test]
    fn test_potential_field_obstacle_repulsion() {
        let cmd = potential_field(
            &[0.0, 0.0, 0.0],
            &[5.0, 0.0, 0.0],
            &[[1.0, 0.0, 0.0]],
            &PotentialFieldConfig::default(),
        );
        // Obstacle directly ahead — repulsion should reduce forward velocity.
        let cmd_no_obs = potential_field(
            &[0.0, 0.0, 0.0],
            &[5.0, 0.0, 0.0],
            &[],
            &PotentialFieldConfig::default(),
        );
        assert!(cmd.vx < cmd_no_obs.vx);
    }

    #[test]
    fn test_potential_field_max_speed() {
        // Far-away goal saturates the command at exactly max_speed.
        let config = PotentialFieldConfig { max_speed: 1.0, ..Default::default() };
        let cmd = potential_field(
            &[0.0, 0.0, 0.0],
            &[100.0, 100.0, 0.0],
            &[],
            &config,
        );
        let speed = (cmd.vx * cmd.vx + cmd.vy * cmd.vy + cmd.vz * cmd.vz).sqrt();
        assert!((speed - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_path_to_waypoints() {
        // Cells scale by resolution 0.5 from a zero origin.
        let path = GridPath {
            cells: vec![(0, 0), (1, 0), (2, 0)],
            cost: 2.0,
        };
        let wps = path_to_waypoints(&path, 0.5, &[0.0, 0.0, 0.0]);
        assert_eq!(wps.len(), 3);
        assert!((wps[1][0] - 0.5).abs() < 1e-9);
        assert!((wps[2][0] - 1.0).abs() < 1e-9);
    }
}

View File

@@ -0,0 +1,530 @@
//! RVF packaging for robotics data.
//!
//! Bridges the robotics crate with the [RuVector Format](crate) so that
//! point clouds, scene graphs, episodic memory, Gaussian splats, and
//! occupancy grids can be persisted, queried, and transferred as `.rvf`
//! files.
//!
//! Requires the `rvf` feature flag.
//!
//! # Quick Start
//!
//! ```ignore
//! use ruvector_robotics::rvf::RoboticsRvf;
//!
//! let mut rvf = RoboticsRvf::create("scene.rvf", 3)?;
//! rvf.pack_point_cloud(&cloud)?;
//! rvf.pack_scene_graph(&graph)?;
//! let similar = rvf.query_nearest(&[1.0, 2.0, 3.0], 5)?;
//! rvf.close()?;
//! ```
use std::path::Path;
use rvf_runtime::options::DistanceMetric;
use rvf_runtime::{
IngestResult, QueryOptions, RvfOptions, RvfStore, SearchResult,
};
use crate::bridge::{
GaussianConfig, Obstacle, PointCloud, SceneGraph, SceneObject, Trajectory,
};
use crate::bridge::gaussian::{gaussians_from_cloud, GaussianSplatCloud};
// ---------------------------------------------------------------------------
// Errors
// ---------------------------------------------------------------------------
/// Errors from RVF packaging operations.
#[derive(Debug, thiserror::Error)]
pub enum RvfPackError {
    /// Failure bubbled up from the underlying RVF store (stringified).
    #[error("rvf store error: {0}")]
    Store(String),
    /// The input collection had nothing to pack.
    #[error("empty data: {0}")]
    EmptyData(&'static str),
    /// The store's vector dimension does not match what the operation needs.
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    /// Filesystem-level failure.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

// Store errors are captured via their Debug representation at this
// boundary so the robotics crate does not re-export rvf_types.
impl From<rvf_types::RvfError> for RvfPackError {
    fn from(e: rvf_types::RvfError) -> Self {
        RvfPackError::Store(format!("{e:?}"))
    }
}

/// Convenience alias: packaging operations fail with [`RvfPackError`].
pub type Result<T> = std::result::Result<T, RvfPackError>;
// ---------------------------------------------------------------------------
// ID generation
// ---------------------------------------------------------------------------
// ID generation is handled by the `next_id` counter in `RoboticsRvf`.
// ---------------------------------------------------------------------------
// RoboticsRvf
// ---------------------------------------------------------------------------
/// High-level wrapper that packages robotics data into an RVF file.
///
/// Each robotics type is mapped to a vector encoding:
///
/// | Type | Dimension | Encoding |
/// |------|-----------|----------|
/// | Point cloud | 3 per point | `[x, y, z]` |
/// | Scene object | 9 | `[cx, cy, cz, ex, ey, ez, conf, vx, vy]` |
/// | Trajectory waypoint | 3 per step | `[x, y, z]` |
/// | Gaussian splat | 7 | `[cx, cy, cz, r, g, b, opacity]` |
/// | Obstacle | 6 | `[px, py, pz, dist, radius, conf]` |
pub struct RoboticsRvf {
    /// Underlying RVF vector store handle.
    store: RvfStore,
    /// Per-vector dimension the store was created/opened with.
    dimension: u16,
    /// Monotonic counter used to assign fresh vector ids on ingest.
    next_id: u64,
}
impl RoboticsRvf {
    /// Create a new `.rvf` file at `path` for robotics vector data.
    ///
    /// `dimension` is the per-vector size (e.g. 3 for raw point clouds,
    /// 7 for Gaussian splats, 9 for scene objects).
    pub fn create<P: AsRef<Path>>(path: P, dimension: u16) -> Result<Self> {
        let options = RvfOptions {
            dimension,
            metric: DistanceMetric::L2,
            ..Default::default()
        };
        let store = RvfStore::create(path.as_ref(), options)?;
        Ok(Self { store, dimension, next_id: 1 })
    }

    /// Open an existing `.rvf` file for read-write access.
    ///
    /// `next_id` starts at 1_000_000 so ids appended in this session are
    /// unlikely to collide with ids written at creation time (which start
    /// at 1).
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let store = RvfStore::open(path.as_ref())?;
        let dim = store.dimension();
        Ok(Self { store, dimension: dim, next_id: 1_000_000 })
    }

    /// Open an existing `.rvf` file for read-only queries.
    pub fn open_readonly<P: AsRef<Path>>(path: P) -> Result<Self> {
        let store = RvfStore::open_readonly(path.as_ref())?;
        let dim = store.dimension();
        Ok(Self { store, dimension: dim, next_id: 0 })
    }

    /// Current store status.
    pub fn status(&self) -> rvf_runtime::StoreStatus {
        self.store.status()
    }

    /// The vector dimension this store was created with.
    pub fn dimension(&self) -> u16 {
        self.dimension
    }

    // -- packing ----------------------------------------------------------

    /// Pack a [`PointCloud`] into the RVF store (dimension must be 3).
    ///
    /// # Errors
    /// [`RvfPackError::DimensionMismatch`] if the store is not
    /// 3-dimensional; [`RvfPackError::EmptyData`] if the cloud has no
    /// points.
    pub fn pack_point_cloud(&mut self, cloud: &PointCloud) -> Result<IngestResult> {
        self.check_dim(3)?;
        if cloud.is_empty() {
            return Err(RvfPackError::EmptyData("point cloud is empty"));
        }
        let vectors: Vec<Vec<f32>> = cloud
            .points
            .iter()
            .map(|p| vec![p.x, p.y, p.z])
            .collect();
        self.ingest_vectors(&vectors)
    }

    /// Pack scene objects into the RVF store (dimension must be 9).
    ///
    /// Each object is encoded as `[cx, cy, cz, ex, ey, ez, conf, vx, vy]`.
    /// Only the X/Y velocity components are stored; `vz` is dropped to keep
    /// the encoding at 9 dimensions.
    pub fn pack_scene_objects(&mut self, objects: &[SceneObject]) -> Result<IngestResult> {
        self.check_dim(9)?;
        if objects.is_empty() {
            return Err(RvfPackError::EmptyData("no scene objects"));
        }
        let vectors: Vec<Vec<f32>> = objects
            .iter()
            .map(|o| {
                // Objects without a velocity estimate encode zero velocity.
                let vel = o.velocity.unwrap_or([0.0; 3]);
                vec![
                    o.center[0] as f32,
                    o.center[1] as f32,
                    o.center[2] as f32,
                    o.extent[0] as f32,
                    o.extent[1] as f32,
                    o.extent[2] as f32,
                    o.confidence,
                    vel[0] as f32,
                    vel[1] as f32,
                ]
            })
            .collect();
        self.ingest_vectors(&vectors)
    }

    /// Pack a scene graph into the RVF store (dimension 9).
    ///
    /// Only the graph's objects are persisted; edges are not stored.
    pub fn pack_scene_graph(&mut self, graph: &SceneGraph) -> Result<IngestResult> {
        self.pack_scene_objects(&graph.objects)
    }

    /// Pack trajectory waypoints (dimension must be 3).
    pub fn pack_trajectory(&mut self, trajectory: &Trajectory) -> Result<IngestResult> {
        self.check_dim(3)?;
        if trajectory.is_empty() {
            return Err(RvfPackError::EmptyData("trajectory is empty"));
        }
        let vectors: Vec<Vec<f32>> = trajectory
            .waypoints
            .iter()
            .map(|wp| vec![wp[0] as f32, wp[1] as f32, wp[2] as f32])
            .collect();
        self.ingest_vectors(&vectors)
    }

    /// Convert a point cloud to Gaussian splats and pack them (dimension 7).
    ///
    /// Each Gaussian is encoded as `[cx, cy, cz, r, g, b, opacity]`.
    /// Returns both the generated splat cloud and the ingest result.
    pub fn pack_gaussians(
        &mut self,
        cloud: &PointCloud,
        config: &GaussianConfig,
    ) -> Result<(GaussianSplatCloud, IngestResult)> {
        self.check_dim(7)?;
        let splat_cloud = gaussians_from_cloud(cloud, config);
        if splat_cloud.is_empty() {
            return Err(RvfPackError::EmptyData("no Gaussian splats produced"));
        }
        let vectors: Vec<Vec<f32>> = splat_cloud
            .gaussians
            .iter()
            .map(|g| {
                vec![
                    g.center[0] as f32,
                    g.center[1] as f32,
                    g.center[2] as f32,
                    g.color[0],
                    g.color[1],
                    g.color[2],
                    g.opacity,
                ]
            })
            .collect();
        let result = self.ingest_vectors(&vectors)?;
        Ok((splat_cloud, result))
    }

    /// Pack obstacles into the RVF store (dimension must be 6).
    ///
    /// Each obstacle is encoded as `[px, py, pz, distance, radius, confidence]`.
    pub fn pack_obstacles(&mut self, obstacles: &[Obstacle]) -> Result<IngestResult> {
        self.check_dim(6)?;
        if obstacles.is_empty() {
            return Err(RvfPackError::EmptyData("no obstacles"));
        }
        let vectors: Vec<Vec<f32>> = obstacles
            .iter()
            .map(|o| {
                vec![
                    o.position[0] as f32,
                    o.position[1] as f32,
                    o.position[2] as f32,
                    o.distance as f32,
                    o.radius as f32,
                    o.confidence,
                ]
            })
            .collect();
        self.ingest_vectors(&vectors)
    }

    // -- querying ---------------------------------------------------------

    /// Query the store for the `k` nearest vectors to `query`.
    ///
    /// # Errors
    /// [`RvfPackError::DimensionMismatch`] if `query.len()` differs from
    /// the store dimension.
    pub fn query_nearest(
        &self,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        if query.len() != self.dimension as usize {
            return Err(RvfPackError::DimensionMismatch {
                expected: self.dimension as usize,
                got: query.len(),
            });
        }
        Ok(self.store.query(query, k, &QueryOptions::default())?)
    }

    // -- lifecycle --------------------------------------------------------

    /// Compact the store to reclaim dead space.
    pub fn compact(&mut self) -> Result<rvf_runtime::CompactionResult> {
        Ok(self.store.compact()?)
    }

    /// Close the store, flushing all data.
    pub fn close(self) -> Result<()> {
        Ok(self.store.close()?)
    }

    // -- internals --------------------------------------------------------

    /// Allocate `n` fresh, sequential vector ids from the internal counter.
    fn take_ids(&mut self, n: usize) -> Vec<u64> {
        let first = self.next_id;
        self.next_id += n as u64;
        (first..self.next_id).collect()
    }

    /// Ingest `vectors` as one batch under freshly allocated ids.
    ///
    /// Shared tail of every `pack_*` method; keeps the slice-ref building
    /// and id-allocation logic in one place instead of five copies.
    fn ingest_vectors(&mut self, vectors: &[Vec<f32>]) -> Result<IngestResult> {
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
        let ids = self.take_ids(vectors.len());
        Ok(self.store.ingest_batch(&refs, &ids, None)?)
    }

    /// Ensure the store dimension matches `required`, else
    /// [`RvfPackError::DimensionMismatch`].
    fn check_dim(&self, required: u16) -> Result<()> {
        if self.dimension != required {
            return Err(RvfPackError::DimensionMismatch {
                expected: required as usize,
                got: self.dimension as usize,
            });
        }
        Ok(())
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::Point3D;
    use tempfile::NamedTempFile;

    /// Reserve a unique temp path with an `.rvf` extension. The
    /// `NamedTempFile` is dropped immediately, so only the (now free)
    /// path survives for the store to create.
    fn tmp_path() -> std::path::PathBuf {
        let f = NamedTempFile::new().unwrap();
        let p = f.path().with_extension("rvf");
        drop(f);
        p
    }

    #[test]
    fn test_pack_point_cloud_and_query() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
        assert_eq!(rvf.dimension(), 3);
        let cloud = PointCloud::new(
            vec![
                Point3D::new(1.0, 0.0, 0.0),
                Point3D::new(2.0, 0.0, 0.0),
                Point3D::new(10.0, 0.0, 0.0),
            ],
            1000,
        );
        let result = rvf.pack_point_cloud(&cloud).unwrap();
        assert_eq!(result.accepted, 3);
        let hits = rvf.query_nearest(&[1.5, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        // Nearest should be one of the first two points.
        assert!(hits[0].distance < 1.0);
        rvf.close().unwrap();
        // Verify file was created.
        assert!(path.exists());
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_pack_scene_objects() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 9).unwrap();
        let objects = vec![
            SceneObject::new(0, [1.0, 2.0, 0.0], [0.5, 0.5, 1.8]),
            SceneObject::new(1, [5.0, 0.0, 0.0], [1.0, 1.0, 2.0]),
        ];
        let result = rvf.pack_scene_objects(&objects).unwrap();
        assert_eq!(result.accepted, 2);
        // Query vector mirrors object 0's 9-dim encoding
        // [cx, cy, cz, ex, ey, ez, conf, vx, vy].
        let hits = rvf.query_nearest(&[1.0, 2.0, 0.0, 0.5, 0.5, 1.8, 1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(hits.len(), 1);
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_pack_trajectory() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
        let traj = Trajectory::new(
            vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
            vec![100, 200, 300],
            0.95,
        );
        let result = rvf.pack_trajectory(&traj).unwrap();
        assert_eq!(result.accepted, 3);
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_pack_gaussians() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 7).unwrap();
        // Two tight point clusters => at least one splat each.
        let cloud = PointCloud::new(
            vec![
                Point3D::new(1.0, 0.0, 0.0),
                Point3D::new(1.1, 0.0, 0.0),
                Point3D::new(1.0, 0.1, 0.0),
                Point3D::new(5.0, 5.0, 0.0),
                Point3D::new(5.1, 5.0, 0.0),
                Point3D::new(5.0, 5.1, 0.0),
            ],
            1000,
        );
        let config = GaussianConfig { min_cluster_size: 3, ..Default::default() };
        let (splat_cloud, result) = rvf.pack_gaussians(&cloud, &config).unwrap();
        assert!(!splat_cloud.is_empty());
        assert!(result.accepted > 0);
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_pack_obstacles() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 6).unwrap();
        let obstacles = vec![
            Obstacle {
                id: 0,
                position: [2.0, 0.0, 0.0],
                distance: 2.0,
                radius: 0.5,
                label: "person".into(),
                confidence: 0.9,
            },
        ];
        let result = rvf.pack_obstacles(&obstacles).unwrap();
        assert_eq!(result.accepted, 1);
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_dimension_mismatch() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
        // Trying to pack scene objects (dim 9) into a dim-3 store.
        let objects = vec![SceneObject::new(0, [1.0, 0.0, 0.0], [0.5, 0.5, 0.5])];
        let result = rvf.pack_scene_objects(&objects);
        assert!(result.is_err());
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_empty_data_rejected() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
        let empty_cloud = PointCloud::default();
        assert!(rvf.pack_point_cloud(&empty_cloud).is_err());
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_open_and_requery() {
        let path = tmp_path();
        // Write two vectors, then close to flush everything to disk.
        {
            let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
            let cloud = PointCloud::new(
                vec![Point3D::new(1.0, 0.0, 0.0), Point3D::new(2.0, 0.0, 0.0)],
                1000,
            );
            rvf.pack_point_cloud(&cloud).unwrap();
            rvf.close().unwrap();
        }
        // Reopen read-only and query.
        let rvf = RoboticsRvf::open_readonly(&path).unwrap();
        let status = rvf.status();
        assert_eq!(status.total_vectors, 2);
        let hits = rvf.query_nearest(&[1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].distance < 0.01);
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_query_dimension_mismatch() {
        let path = tmp_path();
        let mut rvf = RoboticsRvf::create(&path, 3).unwrap();
        let cloud = PointCloud::new(vec![Point3D::new(1.0, 0.0, 0.0)], 0);
        rvf.pack_point_cloud(&cloud).unwrap();
        // Query with wrong dimension.
        let result = rvf.query_nearest(&[1.0, 0.0], 1);
        assert!(result.is_err());
        rvf.close().unwrap();
        std::fs::remove_file(&path).ok();
    }
}