Files

606 lines
19 KiB
Rust

//! Input Validation Utilities
//!
//! Provides comprehensive validation for all external inputs to prevent
//! security issues like path traversal, resource exhaustion, and invalid data.
use super::limits::{SecurityConfig, DEFAULT_MAX_NODE_ID_LEN, DEFAULT_MAX_STATE_DIM};
use std::path::{Component, Path};
use thiserror::Error;
/// Validation error types
#[derive(Debug, Error, Clone, PartialEq)]
pub enum ValidationError {
/// Node ID is too long
#[error("Node ID too long: {len} bytes (max: {max})")]
NodeIdTooLong { len: usize, max: usize },
/// Node ID contains invalid characters
#[error("Node ID contains invalid characters: {0}")]
InvalidNodeIdChars(String),
/// Node ID is empty
#[error("Node ID cannot be empty")]
EmptyNodeId,
/// State vector is too large
#[error("State dimension too large: {dim} (max: {max})")]
StateDimensionTooLarge { dim: usize, max: usize },
/// State vector is empty
#[error("State vector cannot be empty")]
EmptyState,
/// State contains invalid float value (NaN or Infinity)
#[error("State contains invalid float at index {index}: {value}")]
InvalidFloat { index: usize, value: String },
/// Matrix dimension too large
#[error("Matrix dimension too large: {dim} (max: {max})")]
MatrixDimensionTooLarge { dim: usize, max: usize },
/// Dimension mismatch
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// Path traversal attempt detected
#[error("Path traversal detected in: {0}")]
PathTraversal(String),
/// Path contains invalid characters
#[error("Path contains invalid characters: {0}")]
InvalidPathChars(String),
/// Payload too large
#[error("Payload too large: {size} bytes (max: {max})")]
PayloadTooLarge { size: usize, max: usize },
/// Resource limit exceeded
#[error("Resource limit exceeded: {0}")]
ResourceLimitExceeded(String),
/// Custom validation error
#[error("{0}")]
Custom(String),
}
/// Result type for validation operations
pub type ValidationResult<T> = Result<T, ValidationError>;
/// Input validator with configurable limits
#[derive(Debug, Clone)]
pub struct InputValidator {
config: SecurityConfig,
}
impl Default for InputValidator {
fn default() -> Self {
Self::new(SecurityConfig::default())
}
}
impl InputValidator {
/// Create a new validator with the given configuration
#[must_use]
pub fn new(config: SecurityConfig) -> Self {
Self { config }
}
/// Create a validator with strict settings
#[must_use]
pub fn strict() -> Self {
Self::new(SecurityConfig::strict())
}
/// Validate a node ID
///
/// Checks:
/// - Non-empty
/// - Length within limits
/// - Contains only allowed characters (alphanumeric, dash, underscore, dot)
pub fn validate_node_id(&self, id: &str) -> ValidationResult<()> {
if id.is_empty() {
return Err(ValidationError::EmptyNodeId);
}
if id.len() > self.config.max_node_id_len {
return Err(ValidationError::NodeIdTooLong {
len: id.len(),
max: self.config.max_node_id_len,
});
}
if !is_valid_identifier(id) {
return Err(ValidationError::InvalidNodeIdChars(id.to_string()));
}
Ok(())
}
/// Validate a state vector
///
/// Checks:
/// - Non-empty
/// - Dimension within limits
/// - No NaN or Infinity values
pub fn validate_state(&self, state: &[f32]) -> ValidationResult<()> {
if state.is_empty() {
return Err(ValidationError::EmptyState);
}
if state.len() > self.config.graph_limits.max_state_dim {
return Err(ValidationError::StateDimensionTooLarge {
dim: state.len(),
max: self.config.graph_limits.max_state_dim,
});
}
// Check for NaN/Infinity
for (i, &val) in state.iter().enumerate() {
if val.is_nan() {
return Err(ValidationError::InvalidFloat {
index: i,
value: "NaN".to_string(),
});
}
if val.is_infinite() {
return Err(ValidationError::InvalidFloat {
index: i,
value: if val.is_sign_positive() {
"+Infinity"
} else {
"-Infinity"
}
.to_string(),
});
}
}
Ok(())
}
/// Validate matrix dimensions
pub fn validate_matrix_dims(&self, rows: usize, cols: usize) -> ValidationResult<()> {
let max = self.config.resource_limits.max_matrix_dim;
if rows > max {
return Err(ValidationError::MatrixDimensionTooLarge { dim: rows, max });
}
if cols > max {
return Err(ValidationError::MatrixDimensionTooLarge { dim: cols, max });
}
// Also check total elements to prevent memory exhaustion
let total = rows.saturating_mul(cols);
let max_elements = self.config.resource_limits.max_matrix_elements();
if total > max_elements {
return Err(ValidationError::ResourceLimitExceeded(format!(
"Matrix elements: {} (max: {})",
total, max_elements
)));
}
Ok(())
}
/// Validate payload size
pub fn validate_payload_size(&self, size: usize) -> ValidationResult<()> {
if size > self.config.resource_limits.max_payload_size {
return Err(ValidationError::PayloadTooLarge {
size,
max: self.config.resource_limits.max_payload_size,
});
}
Ok(())
}
/// Check if graph can accept more nodes
pub fn check_node_limit(&self, current_count: usize) -> ValidationResult<()> {
if !self.config.graph_limits.can_add_node(current_count) {
return Err(ValidationError::ResourceLimitExceeded(format!(
"Maximum nodes: {}",
self.config.graph_limits.max_nodes
)));
}
Ok(())
}
/// Check if graph can accept more edges
pub fn check_edge_limit(&self, current_count: usize) -> ValidationResult<()> {
if !self.config.graph_limits.can_add_edge(current_count) {
return Err(ValidationError::ResourceLimitExceeded(format!(
"Maximum edges: {}",
self.config.graph_limits.max_edges
)));
}
Ok(())
}
}
/// Path validator for file storage operations
#[derive(Debug, Clone, Default)]
pub struct PathValidator;
impl PathValidator {
/// Validate a path component to prevent traversal attacks
///
/// Rejects:
/// - Empty components
/// - "." or ".." components
/// - Absolute paths or drive letters
/// - Components with path separators
/// - Components starting with "~"
pub fn validate_path_component(component: &str) -> ValidationResult<()> {
if component.is_empty() {
return Err(ValidationError::InvalidPathChars(
"empty component".to_string(),
));
}
// Check for traversal attempts
if component == "." || component == ".." {
return Err(ValidationError::PathTraversal(component.to_string()));
}
// Check for absolute paths
if component.starts_with('/') || component.starts_with('\\') {
return Err(ValidationError::PathTraversal(component.to_string()));
}
// Check for Windows drive letters (C:, D:, etc.)
if component.len() >= 2 && component.chars().nth(1) == Some(':') {
return Err(ValidationError::PathTraversal(component.to_string()));
}
// Check for home directory reference
if component.starts_with('~') {
return Err(ValidationError::PathTraversal(component.to_string()));
}
// Check for path separators within the component
if component.contains('/') || component.contains('\\') {
return Err(ValidationError::PathTraversal(component.to_string()));
}
// Check for null bytes
if component.contains('\0') {
return Err(ValidationError::InvalidPathChars("null byte".to_string()));
}
Ok(())
}
/// Validate a complete path stays within a base directory
pub fn validate_path_within_base(base: &Path, path: &Path) -> ValidationResult<()> {
// Normalize both paths
let base_canonical = match base.canonicalize() {
Ok(p) => p,
Err(_) => base.to_path_buf(),
};
// Build the full path
let full_path = base.join(path);
// Check each component
for component in path.components() {
match component {
Component::ParentDir => {
return Err(ValidationError::PathTraversal(path.display().to_string()));
}
Component::Normal(s) => {
if let Some(s_str) = s.to_str() {
Self::validate_path_component(s_str)?;
}
}
Component::Prefix(_) | Component::RootDir => {
return Err(ValidationError::PathTraversal(path.display().to_string()));
}
Component::CurDir => {}
}
}
// Final check: resolved path should start with base
if let Ok(resolved) = full_path.canonicalize() {
if !resolved.starts_with(&base_canonical) {
return Err(ValidationError::PathTraversal(path.display().to_string()));
}
}
Ok(())
}
}
/// State vector validator
#[derive(Debug, Clone)]
pub struct StateValidator {
max_dim: usize,
}
impl Default for StateValidator {
fn default() -> Self {
Self {
max_dim: DEFAULT_MAX_STATE_DIM,
}
}
}
impl StateValidator {
/// Create a validator with custom max dimension
#[must_use]
pub fn new(max_dim: usize) -> Self {
Self { max_dim }
}
/// Validate state vector and return validated copy
pub fn validate(&self, state: &[f32]) -> ValidationResult<Vec<f32>> {
if state.is_empty() {
return Err(ValidationError::EmptyState);
}
if state.len() > self.max_dim {
return Err(ValidationError::StateDimensionTooLarge {
dim: state.len(),
max: self.max_dim,
});
}
// Check for and handle invalid floats
let mut validated = Vec::with_capacity(state.len());
for (i, &val) in state.iter().enumerate() {
if val.is_nan() {
return Err(ValidationError::InvalidFloat {
index: i,
value: "NaN".to_string(),
});
}
if val.is_infinite() {
return Err(ValidationError::InvalidFloat {
index: i,
value: format!("{}", val),
});
}
validated.push(val);
}
Ok(validated)
}
/// Validate and clamp state values to a range
pub fn validate_and_clamp(
&self,
state: &[f32],
min: f32,
max: f32,
) -> ValidationResult<Vec<f32>> {
if state.is_empty() {
return Err(ValidationError::EmptyState);
}
if state.len() > self.max_dim {
return Err(ValidationError::StateDimensionTooLarge {
dim: state.len(),
max: self.max_dim,
});
}
let mut result = Vec::with_capacity(state.len());
for (i, &val) in state.iter().enumerate() {
if val.is_nan() {
return Err(ValidationError::InvalidFloat {
index: i,
value: "NaN".to_string(),
});
}
// Clamp infinite values to min/max
let clamped = if val.is_infinite() {
if val.is_sign_positive() {
max
} else {
min
}
} else {
val.clamp(min, max)
};
result.push(clamped);
}
Ok(result)
}
}
// ============================================================================
// Standalone validation functions
// ============================================================================
/// Check if a string is a valid identifier (alphanumeric, dash, underscore, dot)
#[must_use]
pub fn is_valid_identifier(s: &str) -> bool {
if s.is_empty() {
return false;
}
// First character must be alphanumeric
let first_char = s.chars().next().unwrap();
if !first_char.is_ascii_alphanumeric() {
return false;
}
// Rest can be alphanumeric, dash, underscore, or dot
s.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.')
}
/// Check if a state vector is valid (no NaN/Infinity)
#[must_use]
pub fn is_valid_state(state: &[f32]) -> bool {
!state.is_empty() && state.iter().all(|&x| x.is_finite())
}
/// Sanitize a path component by removing unsafe characters
///
/// Returns None if the component cannot be sanitized safely
pub fn sanitize_path_component(component: &str) -> Option<String> {
if component.is_empty() || component == "." || component == ".." {
return None;
}
// Filter to only safe characters
let sanitized: String = component
.chars()
.filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_' || *c == '.')
.collect();
if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
return None;
}
Some(sanitized)
}
/// Validate a dimension value
pub fn validate_dimension(dim: usize, max: usize) -> ValidationResult<()> {
if dim == 0 {
return Err(ValidationError::Custom(
"Dimension cannot be zero".to_string(),
));
}
if dim > max {
return Err(ValidationError::MatrixDimensionTooLarge { dim, max });
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_valid_identifier() {
assert!(is_valid_identifier("node1"));
assert!(is_valid_identifier("my-node"));
assert!(is_valid_identifier("my_node"));
assert!(is_valid_identifier("node.v1"));
assert!(is_valid_identifier("Node123"));
assert!(!is_valid_identifier(""));
assert!(!is_valid_identifier("-node"));
assert!(!is_valid_identifier("_node"));
assert!(!is_valid_identifier(".node"));
assert!(!is_valid_identifier("node/path"));
assert!(!is_valid_identifier("node\\path"));
assert!(!is_valid_identifier("node with space"));
}
#[test]
fn test_valid_state() {
assert!(is_valid_state(&[1.0, 2.0, 3.0]));
assert!(is_valid_state(&[0.0]));
assert!(is_valid_state(&[-1.0, 0.0, 1.0]));
assert!(!is_valid_state(&[]));
assert!(!is_valid_state(&[f32::NAN]));
assert!(!is_valid_state(&[f32::INFINITY]));
assert!(!is_valid_state(&[f32::NEG_INFINITY]));
assert!(!is_valid_state(&[1.0, f32::NAN, 3.0]));
}
#[test]
fn test_input_validator_node_id() {
let validator = InputValidator::default();
assert!(validator.validate_node_id("valid-node").is_ok());
assert!(validator.validate_node_id("node123").is_ok());
assert!(validator.validate_node_id("").is_err());
assert!(validator.validate_node_id("../traversal").is_err());
assert!(validator.validate_node_id("with space").is_err());
}
#[test]
fn test_input_validator_state() {
let validator = InputValidator::default();
assert!(validator.validate_state(&[1.0, 2.0, 3.0]).is_ok());
assert!(validator.validate_state(&[]).is_err());
assert!(validator.validate_state(&[f32::NAN]).is_err());
assert!(validator.validate_state(&[f32::INFINITY]).is_err());
}
#[test]
fn test_path_validator() {
assert!(PathValidator::validate_path_component("valid_name").is_ok());
assert!(PathValidator::validate_path_component("file.txt").is_ok());
assert!(PathValidator::validate_path_component("").is_err());
assert!(PathValidator::validate_path_component(".").is_err());
assert!(PathValidator::validate_path_component("..").is_err());
assert!(PathValidator::validate_path_component("../etc").is_err());
assert!(PathValidator::validate_path_component("/etc").is_err());
assert!(PathValidator::validate_path_component("C:\\").is_err());
assert!(PathValidator::validate_path_component("~user").is_err());
}
#[test]
fn test_sanitize_path() {
assert_eq!(
sanitize_path_component("valid_name"),
Some("valid_name".to_string())
);
assert_eq!(
sanitize_path_component("file.txt"),
Some("file.txt".to_string())
);
assert_eq!(
sanitize_path_component("bad/path"),
Some("badpath".to_string())
);
assert_eq!(
sanitize_path_component("bad\\path"),
Some("badpath".to_string())
);
assert_eq!(sanitize_path_component(""), None);
assert_eq!(sanitize_path_component("."), None);
assert_eq!(sanitize_path_component(".."), None);
assert_eq!(sanitize_path_component("///"), None);
}
#[test]
fn test_state_validator() {
let validator = StateValidator::new(100);
assert!(validator.validate(&[1.0, 2.0]).is_ok());
assert!(validator.validate(&[]).is_err());
assert!(validator.validate(&[f32::NAN]).is_err());
let large: Vec<f32> = (0..101).map(|x| x as f32).collect();
assert!(validator.validate(&large).is_err());
}
#[test]
fn test_state_validator_clamp() {
let validator = StateValidator::new(100);
let result = validator.validate_and_clamp(&[f32::INFINITY, -1.0, 0.5], -1.0, 1.0);
assert!(result.is_ok());
let clamped = result.unwrap();
assert_eq!(clamped, vec![1.0, -1.0, 0.5]);
}
#[test]
fn test_matrix_validation() {
let validator = InputValidator::default();
assert!(validator.validate_matrix_dims(100, 100).is_ok());
assert!(validator.validate_matrix_dims(8192, 8192).is_ok());
assert!(validator.validate_matrix_dims(10000, 10000).is_err());
}
#[test]
fn test_dimension_validation() {
assert!(validate_dimension(100, 1000).is_ok());
assert!(validate_dimension(0, 1000).is_err());
assert!(validate_dimension(1001, 1000).is_err());
}
}